11from ast import *
2- from utils import CProgram , Goto , trace , Bottom , Let
2+ from utils import CProgram , Goto , trace , Bottom , IntType , BoolType , Begin
33import copy
44
55class TypeCheckCif :
@@ -11,14 +11,23 @@ def check_type_equal(self, t1, t2, e):
1111 raise Exception ('error: ' + repr (t1 ) + ' != ' + repr (t2 ) \
1212 + ' in ' + repr (e ))
1313
14+ def combine_types (self , t1 , t2 ):
15+ match (t1 , t2 ):
16+ case (Bottom (), _):
17+ return t2
18+ case (_, Bottom ()):
19+ return t1
20+ case _:
21+ return t1
22+
1423 def type_check_atm (self , e , env ):
1524 match e :
1625 case Name (id ):
1726 return env .get (id , Bottom ())
1827 case Constant (value ) if isinstance (value , bool ):
19- return bool
28+ return BoolType ()
2029 case Constant (value ) if isinstance (value , int ):
21- return int
30+ return IntType ()
2231 case _:
2332 raise Exception ('error in type_check_atm, unexpected ' + repr (e ))
2433
@@ -30,43 +39,47 @@ def type_check_exp(self, e, env):
3039 return self .type_check_atm (e , env )
3140 case IfExp (test , body , orelse ):
3241 test_t = self .type_check_exp (test , env )
33- self .check_type_equal (bool , test_t , test )
42+ self .check_type_equal (BoolType () , test_t , test )
3443 body_t = self .type_check_exp (body , env )
3544 orelse_t = self .type_check_exp (orelse , env )
3645 self .check_type_equal (body_t , orelse_t , e )
3746 return body_t
3847 case BinOp (left , op , right ) if isinstance (op , Add ) or isinstance (op , Sub ):
3948 l = self .type_check_atm (left , env )
40- self .check_type_equal (l , int , e )
49+ self .check_type_equal (l , IntType () , e )
4150 r = self .type_check_atm (right , env )
42- self .check_type_equal (r , int , e )
43- return int
51+ self .check_type_equal (r , IntType () , e )
52+ return IntType ()
4453 case UnaryOp (USub (), v ):
4554 t = self .type_check_atm (v , env )
46- self .check_type_equal (t , int , e )
47- return int
55+ self .check_type_equal (t , IntType () , e )
56+ return IntType ()
4857 case UnaryOp (Not (), v ):
4958 t = self .type_check_exp (v , env )
50- self .check_type_equal (t , bool , e )
51- return bool
52- case Compare (left , [cmp ], [right ]) if isinstance (cmp , Eq ) or isinstance (cmp , NotEq ):
59+ self .check_type_equal (t , BoolType (), e )
60+ return BoolType ()
61+ case Compare (left , [cmp ], [right ]) if isinstance (cmp , Eq ) \
62+ or isinstance (cmp , NotEq ):
5363 l = self .type_check_atm (left , env )
5464 r = self .type_check_atm (right , env )
5565 self .check_type_equal (l , r , e )
56- return bool
66+ return BoolType ()
5767 case Compare (left , [cmp ], [right ]):
5868 l = self .type_check_atm (left , env )
59- self .check_type_equal (l , int , left )
69+ self .check_type_equal (l , IntType () , left )
6070 r = self .type_check_atm (right , env )
61- self .check_type_equal (r , int , right )
62- return bool
71+ self .check_type_equal (r , IntType () , right )
72+ return BoolType ()
6373 case Call (Name ('input_int' ), []):
64- return int
65- case Let (Name (x ), rhs , body ):
66- t = self .type_check_exp (rhs , env )
67- new_env = dict (env )
68- new_env [x ] = t
69- return self .type_check_exp (body , new_env )
74+ return IntType ()
75+ # case Let(Name(x), rhs, body):
76+ # t = self.type_check_exp(rhs, env)
77+ # new_env = dict(env)
78+ # new_env[x] = t
79+ # return self.type_check_exp(body, new_env)
80+ case Begin (ss , e ):
81+ self .type_check_stmts (ss , env )
82+ return self .type_check_exp (e , env )
7083 case _:
7184 raise Exception ('error in type_check_exp, unexpected ' + repr (e ))
7285
@@ -79,12 +92,14 @@ def type_check_stmt(self, s, env):
7992 case Assign ([lhs ], value ):
8093 t = self .type_check_exp (value , env )
8194 if lhs .id in env :
82- self .check_type_equal (env .get (lhs .id , Bottom ()), t , s )
95+ lhs_ty = env .get (lhs .id , Bottom ())
96+ self .check_type_equal (lhs_ty , t , s )
97+ env [lhs .id ] = self .combine_types (t , lhs_ty )
8398 else :
8499 env [lhs .id ] = t
85100 case Expr (Call (Name ('print' ), [arg ])):
86101 t = self .type_check_exp (arg , env )
87- self .check_type_equal (t , int , s )
102+ self .check_type_equal (t , IntType () , s )
88103 case Expr (value ):
89104 self .type_check_exp (value , env )
90105 case If (Compare (left , [cmp ], [right ]), body , orelse ):
0 commit comments