@@ -23,7 +23,9 @@ def combine_types(self, t1, t2):
2323 def type_check_atm (self , e , env ):
2424 match e :
2525 case Name (id ):
26- return env .get (id , Bottom ())
26+ t = env .get (id , Bottom ())
27+ env [id ] = t # make sure this gets into the environment for later definedness checking
28+ return t
2729 case Constant (value ) if isinstance (value , bool ):
2830 return BoolType ()
2931 case Constant (value ) if isinstance (value , int ):
@@ -37,13 +39,6 @@ def type_check_exp(self, e, env):
3739 return self .type_check_atm (e , env )
3840 case Constant (value ):
3941 return self .type_check_atm (e , env )
40- case IfExp (test , body , orelse ):
41- test_t = self .type_check_exp (test , env )
42- self .check_type_equal (BoolType (), test_t , test )
43- body_t = self .type_check_exp (body , env )
44- orelse_t = self .type_check_exp (orelse , env )
45- self .check_type_equal (body_t , orelse_t , e )
46- return body_t
4742 case BinOp (left , op , right ) if isinstance (op , Add ) or isinstance (op , Sub ):
4843 l = self .type_check_atm (left , env )
4944 self .check_type_equal (l , IntType (), e )
@@ -72,11 +67,6 @@ def type_check_exp(self, e, env):
7267 return BoolType ()
7368 case Call (Name ('input_int' ), []):
7469 return IntType ()
75- # case Let(Name(x), rhs, body):
76- # t = self.type_check_exp(rhs, env)
77- # new_env = dict(env)
78- # new_env[x] = t
79- # return self.type_check_exp(body, new_env)
8070 case Begin (ss , e ):
8171 self .type_check_stmts (ss , env )
8272 return self .type_check_exp (e , env )
@@ -91,40 +81,44 @@ def type_check_stmt(self, s, env):
9181 match s :
9282 case Assign ([lhs ], value ):
9383 t = self .type_check_exp (value , env )
94- if lhs .id in env :
95- lhs_ty = env .get (lhs .id , Bottom ())
96- self .check_type_equal (lhs_ty , t , s )
97- env [lhs .id ] = self .combine_types (t , lhs_ty )
98- else :
99- env [lhs .id ] = t
84+ lhs_ty = env .get (lhs .id , Bottom ())
85+ self .check_type_equal (lhs_ty , t , s )
86+ env [lhs .id ] = self .combine_types (t , lhs_ty )
10087 case Expr (Call (Name ('print' ), [arg ])):
10188 t = self .type_check_exp (arg , env )
10289 self .check_type_equal (t , IntType (), s )
10390 case Expr (value ):
10491 self .type_check_exp (value , env )
105- case If (Compare (left , [cmp ], [right ]), body , orelse ):
92+ case _:
93+ raise Exception ('error in type_check_stmt, unexpected ' + repr (s ))
94+
95+ def type_check_tail (self , s , env ):
96+ match s :
97+ case If (Compare (left , [cmp ], [right ]), [Goto (_)], [Goto (_)]):
10698 left_t = self .type_check_atm (left , env )
10799 right_t = self .type_check_atm (right , env )
108100 self .check_type_equal (left_t , right_t , s ) # not quite strict enough
109- self .type_check_stmts (body , env )
110- self .type_check_stmts (orelse , env )
111101 case Goto (label ):
112102 pass
113103 case Return (value ):
114104 value_t = self .type_check_exp (value , env )
115105 case _:
116- raise Exception ('error in type_check_stmt , unexpected ' + repr (s ))
117-
106+ raise Exception ('error in type_check_tail , unexpected' + repr (s ))
107+
118108 def type_check (self , p ):
119109 match p :
120110 case CProgram (body ):
121111 env = {}
122112 while True :
123113 old_env = copy .deepcopy (env )
124114 for (l , ss ) in body .items ():
125- self .type_check_stmts (ss , env )
115+ self .type_check_stmts (ss [:- 1 ], env )
116+ self .type_check_tail (ss [- 1 ], env )
126117 if env == old_env :
127118 break
119+ undefs = [x for x ,t in env .items () if t == Bottom ()]
120+ if undefs :
121+ raise Exception ('error: undefined type for ' + str (undefs ))
128122 p .var_types = env
129123 case _:
130124 raise Exception ('error in type_check, unexpected ' + repr (p ))
0 commit comments