diff --git a/bigframes/core/compile/sqlglot/scalar_compiler.py b/bigframes/core/compile/sqlglot/scalar_compiler.py index 3e12da6d92..8167f40fc3 100644 --- a/bigframes/core/compile/sqlglot/scalar_compiler.py +++ b/bigframes/core/compile/sqlglot/scalar_compiler.py @@ -79,7 +79,7 @@ def register_unary_op( """ key = typing.cast(str, op_ref.name) - def decorator(impl: typing.Callable[..., TypedExpr]): + def decorator(impl: typing.Callable[..., sge.Expression]): def normalized_impl(args: typing.Sequence[TypedExpr], op: ops.RowOp): if pass_op: return impl(args[0], op) @@ -108,7 +108,7 @@ def register_binary_op( """ key = typing.cast(str, op_ref.name) - def decorator(impl: typing.Callable[..., TypedExpr]): + def decorator(impl: typing.Callable[..., sge.Expression]): def normalized_impl(args: typing.Sequence[TypedExpr], op: ops.RowOp): if pass_op: return impl(args[0], args[1], op) @@ -132,7 +132,7 @@ def register_ternary_op( """ key = typing.cast(str, op_ref.name) - def decorator(impl: typing.Callable[..., TypedExpr]): + def decorator(impl: typing.Callable[..., sge.Expression]): def normalized_impl(args: typing.Sequence[TypedExpr], op: ops.RowOp): return impl(args[0], args[1], args[2]) @@ -156,7 +156,7 @@ def register_nary_op( """ key = typing.cast(str, op_ref.name) - def decorator(impl: typing.Callable[..., TypedExpr]): + def decorator(impl: typing.Callable[..., sge.Expression]): def normalized_impl(args: typing.Sequence[TypedExpr], op: ops.RowOp): if pass_op: return impl(*args, op=op)