@@ -79,7 +79,7 @@ def register_unary_op(
7979 """
8080 key = typing .cast (str , op_ref .name )
8181
82- def decorator (impl : typing .Callable [..., TypedExpr ]):
82+ def decorator (impl : typing .Callable [..., sge . Expression ]):
8383 def normalized_impl (args : typing .Sequence [TypedExpr ], op : ops .RowOp ):
8484 if pass_op :
8585 return impl (args [0 ], op )
@@ -108,7 +108,7 @@ def register_binary_op(
108108 """
109109 key = typing .cast (str , op_ref .name )
110110
111- def decorator (impl : typing .Callable [..., TypedExpr ]):
111+ def decorator (impl : typing .Callable [..., sge . Expression ]):
112112 def normalized_impl (args : typing .Sequence [TypedExpr ], op : ops .RowOp ):
113113 if pass_op :
114114 return impl (args [0 ], args [1 ], op )
@@ -132,7 +132,7 @@ def register_ternary_op(
132132 """
133133 key = typing .cast (str , op_ref .name )
134134
135- def decorator (impl : typing .Callable [..., TypedExpr ]):
135+ def decorator (impl : typing .Callable [..., sge . Expression ]):
136136 def normalized_impl (args : typing .Sequence [TypedExpr ], op : ops .RowOp ):
137137 return impl (args [0 ], args [1 ], args [2 ])
138138
@@ -156,7 +156,7 @@ def register_nary_op(
156156 """
157157 key = typing .cast (str , op_ref .name )
158158
159- def decorator (impl : typing .Callable [..., TypedExpr ]):
159+ def decorator (impl : typing .Callable [..., sge . Expression ]):
160160 def normalized_impl (args : typing .Sequence [TypedExpr ], op : ops .RowOp ):
161161 if pass_op :
162162 return impl (* args , op = op )
0 commit comments