From 7e7b0cde880c7933c8b80dbd201884d37a89a5ef Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 11 Sep 2025 22:20:27 +0000 Subject: [PATCH 1/3] refactor: reorganize the sqlglot scalar compiler layout - part 1 --- bigframes/core/compile/sqlglot/__init__.py | 2 + .../compile/sqlglot/aggregate_compiler.py | 8 +- .../compile/sqlglot/aggregations/windows.py | 7 +- bigframes/core/compile/sqlglot/compiler.py | 24 +- .../compile/sqlglot/expressions/__init__.py | 8 + .../sqlglot/expressions/binary_compiler.py | 65 ++- .../sqlglot/expressions/nary_compiler.py | 27 -- .../sqlglot/expressions/op_registration.py | 54 --- .../sqlglot/expressions/ternary_compiler.py | 29 -- .../sqlglot/expressions/unary_compiler.py | 448 +++++++++--------- .../core/compile/sqlglot/scalar_compiler.py | 207 ++++++-- .../expressions/test_op_registration.py | 43 -- 12 files changed, 445 insertions(+), 477 deletions(-) delete mode 100644 bigframes/core/compile/sqlglot/expressions/nary_compiler.py delete mode 100644 bigframes/core/compile/sqlglot/expressions/op_registration.py delete mode 100644 bigframes/core/compile/sqlglot/expressions/ternary_compiler.py delete mode 100644 tests/unit/core/compile/sqlglot/expressions/test_op_registration.py diff --git a/bigframes/core/compile/sqlglot/__init__.py b/bigframes/core/compile/sqlglot/__init__.py index 2f40894975..8a1172b704 100644 --- a/bigframes/core/compile/sqlglot/__init__.py +++ b/bigframes/core/compile/sqlglot/__init__.py @@ -14,5 +14,7 @@ from __future__ import annotations from bigframes.core.compile.sqlglot.compiler import SQLGlotCompiler +import bigframes.core.compile.sqlglot.expressions.binary_compiler # noqa: F401 +import bigframes.core.compile.sqlglot.expressions.unary_compiler # noqa: F401 __all__ = ["SQLGlotCompiler"] diff --git a/bigframes/core/compile/sqlglot/aggregate_compiler.py b/bigframes/core/compile/sqlglot/aggregate_compiler.py index ccfba1ce0f..08bca535a8 100644 --- a/bigframes/core/compile/sqlglot/aggregate_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregate_compiler.py @@ -35,7 +35,7 @@ def compile_aggregate( return nullary_compiler.compile(aggregate.op) if isinstance(aggregate, agg_expressions.UnaryAggregation): column = typed_expr.TypedExpr( - scalar_compiler.compile_scalar_expression(aggregate.arg), + scalar_compiler.scalar_op_compiler.compile_expression(aggregate.arg), aggregate.arg.output_type, ) if not aggregate.op.order_independent: @@ -46,11 +46,11 @@ def compile_aggregate( return unary_compiler.compile(aggregate.op, column) elif isinstance(aggregate, agg_expressions.BinaryAggregation): left = typed_expr.TypedExpr( - scalar_compiler.compile_scalar_expression(aggregate.left), + scalar_compiler.scalar_op_compiler.compile_expression(aggregate.left), aggregate.left.output_type, ) right = typed_expr.TypedExpr( - scalar_compiler.compile_scalar_expression(aggregate.right), + scalar_compiler.scalar_op_compiler.compile_expression(aggregate.right), aggregate.right.output_type, ) return binary_compiler.compile(aggregate.op, left, right) @@ -66,7 +66,7 @@ def compile_analytic( return nullary_compiler.compile(aggregate.op) if isinstance(aggregate, agg_expressions.UnaryAggregation): column = typed_expr.TypedExpr( - scalar_compiler.compile_scalar_expression(aggregate.arg), + scalar_compiler.scalar_op_compiler.compile_expression(aggregate.arg), aggregate.arg.output_type, ) return unary_compiler.compile(aggregate.op, column, window) diff --git a/bigframes/core/compile/sqlglot/aggregations/windows.py b/bigframes/core/compile/sqlglot/aggregations/windows.py index 47fd43bd08..4d7a3f7406 100644 --- a/bigframes/core/compile/sqlglot/aggregations/windows.py +++ b/bigframes/core/compile/sqlglot/aggregations/windows.py @@ -51,7 +51,10 @@ def apply_window_if_present( order = sge.Order(expressions=order_by) if order_by else None group_by = ( - [scalar_compiler.compile_scalar_expression(key) for key in window.grouping_keys] + [ + scalar_compiler.scalar_op_compiler.compile_expression(key) + for key in window.grouping_keys + ] if window.grouping_keys else None ) @@ -101,7 +104,7 @@ def get_window_order_by( order_by = [] for ordering_spec_item in ordering: - expr = scalar_compiler.compile_scalar_expression( + expr = scalar_compiler.scalar_op_compiler.compile_expression( ordering_spec_item.scalar_expression ) desc = not ordering_spec_item.direction.is_ascending diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py index 8364e757a1..ba2c644689 100644 --- a/bigframes/core/compile/sqlglot/compiler.py +++ b/bigframes/core/compile/sqlglot/compiler.py @@ -131,7 +131,7 @@ def _compile_result_node(self, root: nodes.ResultNode) -> str: # Have to bind schema as the final step before compilation. root = typing.cast(nodes.ResultNode, schema_binding.bind_schema_to_tree(root)) selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple( - (name, scalar_compiler.compile_scalar_expression(ref)) + (name, scalar_compiler.scalar_op_compiler.compile_expression(ref)) for ref, name in root.output_cols ) sqlglot_ir = self.compile_node(root.child).select(selected_cols) @@ -139,7 +139,7 @@ def _compile_result_node(self, root: nodes.ResultNode) -> str: if root.order_by is not None: ordering_cols = tuple( sge.Ordered( - this=scalar_compiler.compile_scalar_expression( + this=scalar_compiler.scalar_op_compiler.compile_expression( ordering.scalar_expression ), desc=ordering.direction.is_ascending is False, @@ -199,7 +199,7 @@ def compile_selection( self, node: nodes.SelectionNode, child: ir.SQLGlotIR ) -> ir.SQLGlotIR: selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple( - (id.sql, scalar_compiler.compile_scalar_expression(expr)) + (id.sql, scalar_compiler.scalar_op_compiler.compile_expression(expr)) for expr, id in node.input_output_pairs ) return child.select(selected_cols) @@ -209,7 +209,7 @@ def compile_projection( self, node: nodes.ProjectionNode, child: ir.SQLGlotIR ) -> ir.SQLGlotIR: projected_cols: tuple[tuple[str, sge.Expression], ...] = tuple( - (id.sql, scalar_compiler.compile_scalar_expression(expr)) + (id.sql, scalar_compiler.scalar_op_compiler.compile_expression(expr)) for expr, id in node.assignments ) return child.project(projected_cols) @@ -218,7 +218,9 @@ def compile_projection( def compile_filter( self, node: nodes.FilterNode, child: ir.SQLGlotIR ) -> ir.SQLGlotIR: - condition = scalar_compiler.compile_scalar_expression(node.predicate) + condition = scalar_compiler.scalar_op_compiler.compile_expression( + node.predicate + ) return child.filter(tuple([condition])) @_compile_node.register @@ -228,10 +230,12 @@ def compile_join( conditions = tuple( ( typed_expr.TypedExpr( - scalar_compiler.compile_scalar_expression(left), left.output_type + scalar_compiler.scalar_op_compiler.compile_expression(left), + left.output_type, ), typed_expr.TypedExpr( - scalar_compiler.compile_scalar_expression(right), right.output_type + scalar_compiler.scalar_op_compiler.compile_expression(right), + right.output_type, ), ) for left, right in node.conditions @@ -308,7 +312,7 @@ def compile_aggregate( for agg, id in node.aggregations ) by_cols: tuple[sge.Expression, ...] = tuple( - scalar_compiler.compile_scalar_expression(by_col) + scalar_compiler.scalar_op_compiler.compile_expression(by_col) for by_col in node.by_column_ids ) @@ -332,7 +336,9 @@ def compile_window( window_op = aggregate_compiler.compile_analytic(node.expression, window_spec) inputs: tuple[sge.Expression, ...] = tuple( - scalar_compiler.compile_scalar_expression(expression.DerefOp(column)) + scalar_compiler.scalar_op_compiler.compile_expression( + expression.DerefOp(column) + ) for column in node.expression.column_references ) diff --git a/bigframes/core/compile/sqlglot/expressions/__init__.py b/bigframes/core/compile/sqlglot/expressions/__init__.py index 0a2669d7a2..f42d5c7d99 100644 --- a/bigframes/core/compile/sqlglot/expressions/__init__.py +++ b/bigframes/core/compile/sqlglot/expressions/__init__.py @@ -11,3 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""Expression implementations for the SQLGlot-based compiler. + +This directory structure should reflect the same layout as the +`bigframes/operations` directory where the expressions are defined. + +Prefer a few ops per file to keep file sizes manageable for text editors and LLMs. +""" diff --git a/bigframes/core/compile/sqlglot/expressions/binary_compiler.py b/bigframes/core/compile/sqlglot/expressions/binary_compiler.py index 3fcba04cfd..b18d15cae6 100644 --- a/bigframes/core/compile/sqlglot/expressions/binary_compiler.py +++ b/bigframes/core/compile/sqlglot/expressions/binary_compiler.py @@ -20,19 +20,16 @@ from bigframes import dtypes from bigframes import operations as ops import bigframes.core.compile.sqlglot.expressions.constants as constants -from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr +import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler -BINARY_OP_REGISTRATION = OpRegistration() +register_binary_op = scalar_compiler.scalar_op_compiler.register_binary_op - -def compile(op: ops.BinaryOp, left: TypedExpr, right: TypedExpr) -> sge.Expression: - return BINARY_OP_REGISTRATION[op](op, left, right) +# TODO: add parenthesize for operators -# TODO: add parenthesize for operators -@BINARY_OP_REGISTRATION.register(ops.add_op) -def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: +@register_binary_op(ops.add_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: if left.dtype == dtypes.STRING_DTYPE and right.dtype == dtypes.STRING_DTYPE: # String addition return sge.Concat(expressions=[left.expr, right.expr]) @@ -66,15 +63,15 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: ) -@BINARY_OP_REGISTRATION.register(ops.eq_op) -def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: +@register_binary_op(ops.eq_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) return sge.EQ(this=left_expr, expression=right_expr) -@BINARY_OP_REGISTRATION.register(ops.eq_null_match_op) -def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: +@register_binary_op(ops.eq_null_match_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: left_expr = left.expr if right.dtype != dtypes.BOOL_DTYPE: left_expr = _coerce_bool_to_int(left) @@ -93,8 +90,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: return sge.EQ(this=left_coalesce, expression=right_coalesce) -@BINARY_OP_REGISTRATION.register(ops.div_op) -def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: +@register_binary_op(ops.div_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) @@ -105,8 +102,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: return result -@BINARY_OP_REGISTRATION.register(ops.floordiv_op) -def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: +@register_binary_op(ops.floordiv_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) @@ -138,41 +135,41 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: return result -@BINARY_OP_REGISTRATION.register(ops.ge_op) -def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: +@register_binary_op(ops.ge_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) return sge.GTE(this=left_expr, expression=right_expr) -@BINARY_OP_REGISTRATION.register(ops.gt_op) -def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: +@register_binary_op(ops.gt_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) return sge.GT(this=left_expr, expression=right_expr) -@BINARY_OP_REGISTRATION.register(ops.JSONSet) -def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: +@register_binary_op(ops.JSONSet, pass_op=True) +def _(left: TypedExpr, right: TypedExpr, op) -> sge.Expression: return sge.func("JSON_SET", left.expr, sge.convert(op.json_path), right.expr) -@BINARY_OP_REGISTRATION.register(ops.lt_op) -def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: +@register_binary_op(ops.lt_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) return sge.LT(this=left_expr, expression=right_expr) -@BINARY_OP_REGISTRATION.register(ops.le_op) -def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: +@register_binary_op(ops.le_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) return sge.LTE(this=left_expr, expression=right_expr) -@BINARY_OP_REGISTRATION.register(ops.mul_op) -def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: +@register_binary_op(ops.mul_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) @@ -186,20 +183,20 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: return result -@BINARY_OP_REGISTRATION.register(ops.ne_op) -def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: +@register_binary_op(ops.ne_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) return sge.NEQ(this=left_expr, expression=right_expr) -@BINARY_OP_REGISTRATION.register(ops.obj_make_ref_op) -def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: +@register_binary_op(ops.obj_make_ref_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: return sge.func("OBJ.MAKE_REF", left.expr, right.expr) -@BINARY_OP_REGISTRATION.register(ops.sub_op) -def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: +@register_binary_op(ops.sub_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype): left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) diff --git a/bigframes/core/compile/sqlglot/expressions/nary_compiler.py b/bigframes/core/compile/sqlglot/expressions/nary_compiler.py deleted file mode 100644 index 12f68613d7..0000000000 --- a/bigframes/core/compile/sqlglot/expressions/nary_compiler.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import sqlglot.expressions as sge - -from bigframes import operations as ops -from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration -from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr - -NARY_OP_REGISTRATION = OpRegistration() - - -def compile(op: ops.NaryOp, *args: TypedExpr) -> sge.Expression: - return NARY_OP_REGISTRATION[op](op, *args) diff --git a/bigframes/core/compile/sqlglot/expressions/op_registration.py b/bigframes/core/compile/sqlglot/expressions/op_registration.py deleted file mode 100644 index d5e4853a45..0000000000 --- a/bigframes/core/compile/sqlglot/expressions/op_registration.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import typing - -from sqlglot import expressions as sge - -from bigframes import operations as ops - -# We should've been more specific about input types. Unfortunately, -# MyPy doesn't support more rigorous checks. -CompilationFunc = typing.Callable[..., sge.Expression] - - -class OpRegistration: - def __init__(self) -> None: - self._registered_ops: dict[str, CompilationFunc] = {} - - def register( - self, op: ops.ScalarOp | type[ops.ScalarOp] - ) -> typing.Callable[[CompilationFunc], CompilationFunc]: - def decorator(item: CompilationFunc): - def arg_checker(*args, **kwargs): - if not isinstance(args[0], ops.ScalarOp): - raise ValueError( - f"The first parameter must be an operator. Got {type(args[0])}" - ) - return item(*args, **kwargs) - - key = typing.cast(str, op.name) - if key in self._registered_ops: - raise ValueError(f"{key} is already registered") - self._registered_ops[key] = item - return arg_checker - - return decorator - - def __getitem__(self, op: str | ops.ScalarOp) -> CompilationFunc: - if isinstance(op, ops.ScalarOp): - return self._registered_ops[op.name] - return self._registered_ops[op] diff --git a/bigframes/core/compile/sqlglot/expressions/ternary_compiler.py b/bigframes/core/compile/sqlglot/expressions/ternary_compiler.py deleted file mode 100644 index 9b00771f7d..0000000000 --- a/bigframes/core/compile/sqlglot/expressions/ternary_compiler.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import sqlglot.expressions as sge - -from bigframes import operations as ops -from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration -from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr - -TERNATRY_OP_REGISTRATION = OpRegistration() - - -def compile( - op: ops.TernaryOp, expr1: TypedExpr, expr2: TypedExpr, expr3: TypedExpr -) -> sge.Expression: - return TERNATRY_OP_REGISTRATION[op](op, expr1, expr2, expr3) diff --git a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py index f519aef70d..d93b1e681c 100644 --- a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py @@ -25,24 +25,20 @@ from bigframes import operations as ops from bigframes.core.compile.constants import UNIT_TO_US_CONVERSION_FACTORS import bigframes.core.compile.sqlglot.expressions.constants as constants -from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr +import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler import bigframes.dtypes as dtypes -UNARY_OP_REGISTRATION = OpRegistration() +register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op -def compile(op: ops.UnaryOp, expr: TypedExpr) -> sge.Expression: - return UNARY_OP_REGISTRATION[op](op, expr) - - -@UNARY_OP_REGISTRATION.register(ops.abs_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.abs_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Abs(this=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.arccosh_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.arccosh_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Case( ifs=[ sge.If( @@ -54,8 +50,8 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.arccos_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.arccos_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Case( ifs=[ sge.If( @@ -67,8 +63,8 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.arcsin_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.arcsin_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Case( ifs=[ sge.If( @@ -80,18 +76,18 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.arcsinh_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.arcsinh_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("ASINH", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.arctan_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.arctan_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("ATAN", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.arctanh_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.arctanh_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Case( ifs=[ sge.If( @@ -103,19 +99,19 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.AsTypeOp) -def _(op: ops.AsTypeOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.AsTypeOp, pass_op=True) +def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression: # TODO: Support more types for casting, such as JSON, etc. return sge.Cast(this=expr.expr, to=op.to_type) -@UNARY_OP_REGISTRATION.register(ops.ArrayToStringOp) -def _(op: ops.ArrayToStringOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.ArrayToStringOp, pass_op=True) +def _(expr: TypedExpr, op: ops.ArrayToStringOp) -> sge.Expression: return sge.ArrayToString(this=expr.expr, expression=f"'{op.delimiter}'") -@UNARY_OP_REGISTRATION.register(ops.ArrayIndexOp) -def _(op: ops.ArrayIndexOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.ArrayIndexOp, pass_op=True) +def _(expr: TypedExpr, op: ops.ArrayIndexOp) -> sge.Expression: return sge.Bracket( this=expr.expr, expressions=[sge.Literal.number(op.index)], @@ -124,8 +120,8 @@ def _(op: ops.ArrayIndexOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.ArraySliceOp) -def _(op: ops.ArraySliceOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.ArraySliceOp, pass_op=True) +def _(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression: slice_idx = sqlglot.to_identifier("slice_idx") conditions: typing.List[sge.Predicate] = [slice_idx >= op.start] @@ -151,23 +147,23 @@ def _(op: ops.ArraySliceOp, expr: TypedExpr) -> sge.Expression: return sge.array(selected_elements) -@UNARY_OP_REGISTRATION.register(ops.capitalize_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.capitalize_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Initcap(this=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.ceil_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.ceil_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Ceil(this=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.cos_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.cos_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("COS", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.cosh_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.cosh_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Case( ifs=[ sge.If( @@ -179,25 +175,25 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.StrContainsOp) -def _(op: ops.StrContainsOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.StrContainsOp, pass_op=True) +def _(expr: TypedExpr, op: ops.StrContainsOp) -> sge.Expression: return sge.Like(this=expr.expr, expression=sge.convert(f"%{op.pat}%")) -@UNARY_OP_REGISTRATION.register(ops.StrContainsRegexOp) -def _(op: ops.StrContainsRegexOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.StrContainsRegexOp, pass_op=True) +def _(expr: TypedExpr, op: ops.StrContainsRegexOp) -> sge.Expression: return sge.RegexpLike(this=expr.expr, expression=sge.convert(op.pat)) -@UNARY_OP_REGISTRATION.register(ops.StrExtractOp) -def _(op: ops.StrExtractOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.StrExtractOp, pass_op=True) +def _(expr: TypedExpr, op: ops.StrExtractOp) -> sge.Expression: return sge.RegexpExtract( this=expr.expr, expression=sge.convert(op.pat), group=sge.convert(op.n) ) -@UNARY_OP_REGISTRATION.register(ops.StrFindOp) -def _(op: ops.StrFindOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.StrFindOp, pass_op=True) +def _(expr: TypedExpr, op: ops.StrFindOp) -> sge.Expression: # INSTR is 1-based, so we need to adjust the start position. start = sge.convert(op.start + 1) if op.start is not None else sge.convert(1) if op.end is not None: @@ -220,13 +216,13 @@ def _(op: ops.StrFindOp, expr: TypedExpr) -> sge.Expression: ) - sge.convert(1) -@UNARY_OP_REGISTRATION.register(ops.StrLstripOp) -def _(op: ops.StrLstripOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.StrLstripOp, pass_op=True) +def _(expr: TypedExpr, op: ops.StrLstripOp) -> sge.Expression: return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip), side="LEFT") -@UNARY_OP_REGISTRATION.register(ops.StrPadOp) -def _(op: ops.StrPadOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.StrPadOp, pass_op=True) +def _(expr: TypedExpr, op: ops.StrPadOp) -> sge.Expression: pad_length = sge.func( "GREATEST", sge.Length(this=expr.expr), sge.convert(op.length) ) @@ -266,36 +262,36 @@ def _(op: ops.StrPadOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.StrRepeatOp) -def _(op: ops.StrRepeatOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.StrRepeatOp, pass_op=True) +def _(expr: TypedExpr, op: ops.StrRepeatOp) -> sge.Expression: return sge.Repeat(this=expr.expr, times=sge.convert(op.repeats)) -@UNARY_OP_REGISTRATION.register(ops.date_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.date_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Date(this=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.day_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.day_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="DAY"), expression=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.dayofweek_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.dayofweek_op) +def _(expr: TypedExpr) -> sge.Expression: # Adjust the 1-based day-of-week index (from SQL) to a 0-based index. return sge.Extract( this=sge.Identifier(this="DAYOFWEEK"), expression=expr.expr ) - sge.convert(1) -@UNARY_OP_REGISTRATION.register(ops.dayofyear_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.dayofyear_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="DAYOFYEAR"), expression=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.EndsWithOp) -def _(op: ops.EndsWithOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.EndsWithOp, pass_op=True) +def _(expr: TypedExpr, op: ops.EndsWithOp) -> sge.Expression: if not op.pat: return sge.false() @@ -306,8 +302,8 @@ def to_endswith(pat: str) -> sge.Expression: return functools.reduce(lambda x, y: sge.Or(this=x, expression=y), conditions) -@UNARY_OP_REGISTRATION.register(ops.exp_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.exp_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Case( ifs=[ sge.If( @@ -319,8 +315,8 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.expm1_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.expm1_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Case( ifs=[ sge.If( @@ -332,34 +328,34 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ) - sge.convert(1) -@UNARY_OP_REGISTRATION.register(ops.FloorDtOp) -def _(op: ops.FloorDtOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.FloorDtOp, pass_op=True) +def _(expr: TypedExpr, op: ops.FloorDtOp) -> sge.Expression: # TODO: Remove this method when it is covered by ops.FloorOp return sge.TimestampTrunc(this=expr.expr, unit=sge.Identifier(this=op.freq)) -@UNARY_OP_REGISTRATION.register(ops.floor_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.floor_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Floor(this=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.geo_area_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.geo_area_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("ST_AREA", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.geo_st_astext_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.geo_st_astext_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("ST_ASTEXT", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.geo_st_boundary_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.geo_st_boundary_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("ST_BOUNDARY", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.GeoStBufferOp) -def _(op: ops.GeoStBufferOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.GeoStBufferOp, pass_op=True) +def _(expr: TypedExpr, op: ops.GeoStBufferOp) -> sge.Expression: return sge.func( "ST_BUFFER", expr.expr, @@ -369,58 +365,58 @@ def _(op: ops.GeoStBufferOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.geo_st_centroid_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.geo_st_centroid_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("ST_CENTROID", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.geo_st_convexhull_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.geo_st_convexhull_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("ST_CONVEXHULL", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.geo_st_geogfromtext_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.geo_st_geogfromtext_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("SAFE.ST_GEOGFROMTEXT", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.geo_st_isclosed_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.geo_st_isclosed_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("ST_ISCLOSED", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.GeoStLengthOp) -def _(op: ops.GeoStLengthOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.GeoStLengthOp, pass_op=True) +def _(expr: TypedExpr, op: ops.GeoStLengthOp) -> sge.Expression: return sge.func("ST_LENGTH", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.geo_x_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.geo_x_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("SAFE.ST_X", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.geo_y_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.geo_y_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("SAFE.ST_Y", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.hash_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.hash_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("FARM_FINGERPRINT", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.hour_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.hour_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="HOUR"), expression=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.invert_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.invert_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.BitwiseNot(this=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.IsInOp) -def _(op: ops.IsInOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.IsInOp, pass_op=True) +def _(expr: TypedExpr, op: ops.IsInOp) -> sge.Expression: values = [] is_numeric_expr = dtypes.is_numeric(expr.dtype) for value in op.values: @@ -445,28 +441,28 @@ def _(op: ops.IsInOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.isalnum_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.isalnum_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.RegexpLike(this=expr.expr, expression=sge.convert(r"^(\p{N}|\p{L})+$")) -@UNARY_OP_REGISTRATION.register(ops.isalpha_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.isalpha_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.RegexpLike(this=expr.expr, expression=sge.convert(r"^\p{L}+$")) -@UNARY_OP_REGISTRATION.register(ops.isdecimal_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.isdecimal_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.RegexpLike(this=expr.expr, expression=sge.convert(r"^\d+$")) -@UNARY_OP_REGISTRATION.register(ops.isdigit_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.isdigit_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.RegexpLike(this=expr.expr, expression=sge.convert(r"^\p{Nd}+$")) -@UNARY_OP_REGISTRATION.register(ops.islower_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.islower_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.And( this=sge.EQ( this=sge.Lower(this=expr.expr), @@ -479,38 +475,38 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.iso_day_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.iso_day_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="DAYOFWEEK"), expression=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.iso_week_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.iso_week_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="ISOWEEK"), expression=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.iso_year_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.iso_year_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="ISOYEAR"), expression=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.isnull_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.isnull_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Is(this=expr.expr, expression=sge.Null()) -@UNARY_OP_REGISTRATION.register(ops.isnumeric_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.isnumeric_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.RegexpLike(this=expr.expr, expression=sge.convert(r"^\pN+$")) -@UNARY_OP_REGISTRATION.register(ops.isspace_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.isspace_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.RegexpLike(this=expr.expr, expression=sge.convert(r"^\s+$")) -@UNARY_OP_REGISTRATION.register(ops.isupper_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.isupper_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.And( this=sge.EQ( this=sge.Upper(this=expr.expr), @@ -523,13 +519,13 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.len_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.len_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Length(this=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.ln_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.ln_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Case( ifs=[ sge.If( @@ -541,8 +537,8 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.log10_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.log10_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Case( ifs=[ sge.If( @@ -554,8 +550,8 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.log1p_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.log1p_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Case( ifs=[ sge.If( @@ -567,13 +563,13 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.lower_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.lower_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Lower(this=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.MapOp) -def _(op: ops.MapOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.MapOp, pass_op=True) +def _(expr: TypedExpr, op: ops.MapOp) -> sge.Expression: return sge.Case( this=expr.expr, ifs=[ @@ -583,80 +579,80 @@ def _(op: ops.MapOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.minute_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.minute_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="MINUTE"), expression=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.month_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.month_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="MONTH"), expression=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.neg_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.neg_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Neg(this=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.normalize_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.normalize_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.TimestampTrunc(this=expr.expr, unit=sge.Identifier(this="DAY")) -@UNARY_OP_REGISTRATION.register(ops.notnull_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.notnull_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Not(this=sge.Is(this=expr.expr, expression=sge.Null())) -@UNARY_OP_REGISTRATION.register(ops.obj_fetch_metadata_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.obj_fetch_metadata_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("OBJ.FETCH_METADATA", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.ObjGetAccessUrl) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.ObjGetAccessUrl) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("OBJ.GET_ACCESS_URL", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.pos_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.pos_op) +def _(expr: TypedExpr) -> sge.Expression: return expr.expr -@UNARY_OP_REGISTRATION.register(ops.quarter_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.quarter_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="QUARTER"), expression=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.ReplaceStrOp) -def _(op: ops.ReplaceStrOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.ReplaceStrOp, pass_op=True) +def _(expr: TypedExpr, op: ops.ReplaceStrOp) -> sge.Expression: return sge.func("REPLACE", expr.expr, sge.convert(op.pat), sge.convert(op.repl)) -@UNARY_OP_REGISTRATION.register(ops.RegexReplaceStrOp) -def _(op: ops.RegexReplaceStrOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.RegexReplaceStrOp, pass_op=True) +def _(expr: TypedExpr, op: ops.RegexReplaceStrOp) -> sge.Expression: return sge.func( "REGEXP_REPLACE", expr.expr, sge.convert(op.pat), sge.convert(op.repl) ) -@UNARY_OP_REGISTRATION.register(ops.reverse_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.reverse_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("REVERSE", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.second_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.second_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="SECOND"), expression=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.StrRstripOp) -def _(op: ops.StrRstripOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.StrRstripOp, pass_op=True) +def _(expr: TypedExpr, op: ops.StrRstripOp) -> sge.Expression: return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip), side="RIGHT") -@UNARY_OP_REGISTRATION.register(ops.sqrt_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.sqrt_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Case( ifs=[ sge.If( @@ -668,8 +664,8 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.StartsWithOp) -def _(op: ops.StartsWithOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.StartsWithOp, pass_op=True) +def _(expr: TypedExpr, op: ops.StartsWithOp) -> sge.Expression: if not op.pat: return sge.false() @@ -680,18 +676,18 @@ def to_startswith(pat: str) -> sge.Expression: return functools.reduce(lambda x, y: sge.Or(this=x, expression=y), conditions) -@UNARY_OP_REGISTRATION.register(ops.StrStripOp) -def _(op: ops.StrStripOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.StrStripOp, pass_op=True) +def _(expr: TypedExpr, op: ops.StrStripOp) -> sge.Expression: return sge.Trim(this=sge.convert(op.to_strip), expression=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.sin_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.sin_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("SIN", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.sinh_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.sinh_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Case( ifs=[ sge.If( @@ -703,13 +699,13 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.StringSplitOp) -def _(op: ops.StringSplitOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.StringSplitOp, pass_op=True) +def _(expr: TypedExpr, op: ops.StringSplitOp) -> sge.Expression: return sge.Split(this=expr.expr, expression=sge.convert(op.pat)) -@UNARY_OP_REGISTRATION.register(ops.StrGetOp) -def _(op: ops.StrGetOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.StrGetOp, pass_op=True) +def _(expr: TypedExpr, op: ops.StrGetOp) -> sge.Expression: return sge.Substring( this=expr.expr, start=sge.convert(op.i + 1), @@ -717,8 +713,8 @@ def _(op: ops.StrGetOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.StrSliceOp) -def _(op: ops.StrSliceOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.StrSliceOp, pass_op=True) +def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression: start = op.start + 1 if op.start is not None else None if op.end is None: length = None @@ -733,13 +729,13 @@ def _(op: ops.StrSliceOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.StrftimeOp) -def _(op: ops.StrftimeOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.StrftimeOp, pass_op=True) +def _(expr: TypedExpr, op: ops.StrftimeOp) -> sge.Expression: return sge.func("FORMAT_TIMESTAMP", sge.convert(op.date_format), expr.expr) -@UNARY_OP_REGISTRATION.register(ops.StructFieldOp) -def _(op: ops.StructFieldOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.StructFieldOp, pass_op=True) +def _(expr: TypedExpr, op: ops.StructFieldOp) -> sge.Expression: if isinstance(op.name_or_index, str): name = op.name_or_index else: @@ -753,38 +749,38 @@ def _(op: ops.StructFieldOp, expr: TypedExpr) -> sge.Expression: ) -@UNARY_OP_REGISTRATION.register(ops.tan_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.tan_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("TAN", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.tanh_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.tanh_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("TANH", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.time_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.time_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("TIME", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.timedelta_floor_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.timedelta_floor_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Floor(this=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.ToDatetimeOp) -def _(op: ops.ToDatetimeOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.ToDatetimeOp) +def _(expr: TypedExpr) -> sge.Expression: return sge.Cast(this=sge.func("TIMESTAMP_SECONDS", expr.expr), to="DATETIME") -@UNARY_OP_REGISTRATION.register(ops.ToTimestampOp) -def _(op: ops.ToTimestampOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.ToTimestampOp) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("TIMESTAMP_SECONDS", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.ToTimedeltaOp) -def _(op: ops.ToTimedeltaOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.ToTimedeltaOp, pass_op=True) +def _(expr: TypedExpr, op: ops.ToTimedeltaOp) -> sge.Expression: value = expr.expr factor = UNIT_TO_US_CONVERSION_FACTORS[op.unit] if factor != 1: @@ -792,78 +788,78 @@ def _(op: ops.ToTimedeltaOp, expr: TypedExpr) -> sge.Expression: return value -@UNARY_OP_REGISTRATION.register(ops.UnixMicros) -def _(op: ops.UnixMicros, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.UnixMicros) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("UNIX_MICROS", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.UnixMillis) -def _(op: ops.UnixMillis, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.UnixMillis) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("UNIX_MILLIS", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.UnixSeconds) -def _(op: ops.UnixSeconds, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.UnixSeconds, pass_op=True) +def _(expr: TypedExpr, op: ops.UnixSeconds) -> sge.Expression: return sge.func("UNIX_SECONDS", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.JSONExtract) -def _(op: ops.JSONExtract, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.JSONExtract, pass_op=True) +def _(expr: TypedExpr, op: ops.JSONExtract) -> sge.Expression: return sge.func("JSON_EXTRACT", expr.expr, sge.convert(op.json_path)) -@UNARY_OP_REGISTRATION.register(ops.JSONExtractArray) -def _(op: ops.JSONExtractArray, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.JSONExtractArray, pass_op=True) +def _(expr: TypedExpr, op: ops.JSONExtractArray) -> sge.Expression: return sge.func("JSON_EXTRACT_ARRAY", expr.expr, sge.convert(op.json_path)) -@UNARY_OP_REGISTRATION.register(ops.JSONExtractStringArray) -def _(op: ops.JSONExtractStringArray, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.JSONExtractStringArray, pass_op=True) +def _(expr: TypedExpr, op: ops.JSONExtractStringArray) -> sge.Expression: return sge.func("JSON_EXTRACT_STRING_ARRAY", expr.expr, sge.convert(op.json_path)) -@UNARY_OP_REGISTRATION.register(ops.JSONQuery) -def _(op: ops.JSONQuery, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.JSONQuery, pass_op=True) +def _(expr: TypedExpr, op: ops.JSONQuery) -> sge.Expression: return sge.func("JSON_QUERY", expr.expr, sge.convert(op.json_path)) -@UNARY_OP_REGISTRATION.register(ops.JSONQueryArray) -def _(op: ops.JSONQueryArray, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.JSONQueryArray, pass_op=True) +def _(expr: TypedExpr, op: ops.JSONQueryArray) -> sge.Expression: return sge.func("JSON_QUERY_ARRAY", expr.expr, sge.convert(op.json_path)) -@UNARY_OP_REGISTRATION.register(ops.JSONValue) -def _(op: ops.JSONValue, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.JSONValue, pass_op=True) +def _(expr: TypedExpr, op: ops.JSONValue) -> sge.Expression: return sge.func("JSON_VALUE", expr.expr, sge.convert(op.json_path)) -@UNARY_OP_REGISTRATION.register(ops.JSONValueArray) -def _(op: ops.JSONValueArray, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.JSONValueArray, pass_op=True) +def _(expr: TypedExpr, op: ops.JSONValueArray) -> sge.Expression: return sge.func("JSON_VALUE_ARRAY", expr.expr, sge.convert(op.json_path)) -@UNARY_OP_REGISTRATION.register(ops.ParseJSON) -def _(op: ops.ParseJSON, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.ParseJSON) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("PARSE_JSON", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.ToJSONString) -def _(op: ops.ToJSONString, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.ToJSONString) +def _(expr: TypedExpr) -> sge.Expression: return sge.func("TO_JSON_STRING", expr.expr) -@UNARY_OP_REGISTRATION.register(ops.upper_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.upper_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Upper(this=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.year_op) -def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.year_op) +def _(expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="YEAR"), expression=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.ZfillOp) -def _(op: ops.ZfillOp, expr: TypedExpr) -> sge.Expression: +@register_unary_op(ops.ZfillOp, pass_op=True) +def _(expr: TypedExpr, op: ops.ZfillOp) -> sge.Expression: return sge.Case( ifs=[ sge.If( diff --git a/bigframes/core/compile/sqlglot/scalar_compiler.py b/bigframes/core/compile/sqlglot/scalar_compiler.py index 65c2501b71..3e12da6d92 100644 --- a/bigframes/core/compile/sqlglot/scalar_compiler.py +++ b/bigframes/core/compile/sqlglot/scalar_compiler.py @@ -14,60 +14,169 @@ from __future__ import annotations import functools +import typing import sqlglot.expressions as sge -from bigframes.core import expression -from bigframes.core.compile.sqlglot.expressions import ( - binary_compiler, - nary_compiler, - ternary_compiler, - typed_expr, - unary_compiler, -) +from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr import bigframes.core.compile.sqlglot.sqlglot_ir as ir +import bigframes.core.expression as ex import bigframes.operations as ops -@functools.singledispatch -def compile_scalar_expression( - expr: expression.Expression, -) -> sge.Expression: - """Compiles BigFrames scalar expression into SQLGlot expression.""" - raise ValueError(f"Can't compile unrecognized node: {expression}") - - -@compile_scalar_expression.register -def compile_deref_expression(expr: expression.DerefOp) -> sge.Expression: - return sge.Column(this=sge.to_identifier(expr.id.sql, quoted=True)) - - -@compile_scalar_expression.register -def compile_constant_expression( - expr: expression.ScalarConstantExpression, -) -> sge.Expression: - return ir._literal(expr.value, expr.dtype) - - -@compile_scalar_expression.register -def compile_op_expression(expr: expression.OpExpression) -> sge.Expression: - # Non-recursively compiles the children scalar expressions. - args = tuple( - typed_expr.TypedExpr(compile_scalar_expression(input), input.output_type) - for input in expr.inputs - ) - - op = expr.op - if isinstance(op, ops.UnaryOp): - return unary_compiler.compile(op, args[0]) - elif isinstance(op, ops.BinaryOp): - return binary_compiler.compile(op, args[0], args[1]) - elif isinstance(op, ops.TernaryOp): - return ternary_compiler.compile(op, args[0], args[1], args[2]) - elif isinstance(op, ops.NaryOp): - return nary_compiler.compile(op, *args) - else: - raise TypeError( - f"Operator '{op.name}' has an unrecognized arity or type " - "and cannot be compiled." +class ScalarOpCompiler: + # Mapping of operation name to implemenations + _registry: dict[ + str, + typing.Callable[[typing.Sequence[TypedExpr], ops.RowOp], sge.Expression], + ] = {} + + @functools.singledispatchmethod + def compile_expression( + self, + expression: ex.Expression, + ) -> sge.Expression: + """Compiles BigFrames scalar expression into SQLGlot expression.""" + raise NotImplementedError(f"Unrecognized expression: {expression}") + + @compile_expression.register + def _(self, expr: ex.DerefOp) -> sge.Expression: + return sge.Column(this=sge.to_identifier(expr.id.sql, quoted=True)) + + @compile_expression.register + def _(self, expr: ex.ScalarConstantExpression) -> sge.Expression: + return ir._literal(expr.value, expr.dtype) + + @compile_expression.register + def _(self, expr: ex.OpExpression) -> sge.Expression: + # Non-recursively compiles the children scalar expressions. + inputs = tuple( + TypedExpr(self.compile_expression(sub_expr), sub_expr.output_type) + for sub_expr in expr.inputs ) + return self.compile_row_op(expr.op, inputs) + + def compile_row_op( + self, op: ops.RowOp, inputs: typing.Sequence[TypedExpr] + ) -> sge.Expression: + impl = self._registry[op.name] + return impl(inputs, op) + + def register_unary_op( + self, + op_ref: typing.Union[ops.UnaryOp, type[ops.UnaryOp]], + pass_op: bool = False, + ): + """ + Decorator to register a unary op implementation. + + Args: + op_ref (UnaryOp or UnaryOp type): + Class or instance of operator that is implemented by the decorated function. + pass_op (bool): + Set to true if implementation takes the operator object as the last argument. + This is needed for parameterized ops where parameters are part of op object. + """ + key = typing.cast(str, op_ref.name) + + def decorator(impl: typing.Callable[..., TypedExpr]): + def normalized_impl(args: typing.Sequence[TypedExpr], op: ops.RowOp): + if pass_op: + return impl(args[0], op) + else: + return impl(args[0]) + + self._register(key, normalized_impl) + return impl + + return decorator + + def register_binary_op( + self, + op_ref: typing.Union[ops.BinaryOp, type[ops.BinaryOp]], + pass_op: bool = False, + ): + """ + Decorator to register a binary op implementation. + + Args: + op_ref (BinaryOp or BinaryOp type): + Class or instance of operator that is implemented by the decorated function. + pass_op (bool): + Set to true if implementation takes the operator object as the last argument. + This is needed for parameterized ops where parameters are part of op object. + """ + key = typing.cast(str, op_ref.name) + + def decorator(impl: typing.Callable[..., TypedExpr]): + def normalized_impl(args: typing.Sequence[TypedExpr], op: ops.RowOp): + if pass_op: + return impl(args[0], args[1], op) + else: + return impl(args[0], args[1]) + + self._register(key, normalized_impl) + return impl + + return decorator + + def register_ternary_op( + self, op_ref: typing.Union[ops.TernaryOp, type[ops.TernaryOp]] + ): + """ + Decorator to register a ternary op implementation. + + Args: + op_ref (TernaryOp or TernaryOp type): + Class or instance of operator that is implemented by the decorated function. + """ + key = typing.cast(str, op_ref.name) + + def decorator(impl: typing.Callable[..., TypedExpr]): + def normalized_impl(args: typing.Sequence[TypedExpr], op: ops.RowOp): + return impl(args[0], args[1], args[2]) + + self._register(key, normalized_impl) + return impl + + return decorator + + def register_nary_op( + self, op_ref: typing.Union[ops.NaryOp, type[ops.NaryOp]], pass_op: bool = False + ): + """ + Decorator to register a nary op implementation. + + Args: + op_ref (NaryOp or NaryOp type): + Class or instance of operator that is implemented by the decorated function. + pass_op (bool): + Set to true if implementation takes the operator object as the last argument. + This is needed for parameterized ops where parameters are part of op object. + """ + key = typing.cast(str, op_ref.name) + + def decorator(impl: typing.Callable[..., TypedExpr]): + def normalized_impl(args: typing.Sequence[TypedExpr], op: ops.RowOp): + if pass_op: + return impl(*args, op=op) + else: + return impl(*args) + + self._register(key, normalized_impl) + return impl + + return decorator + + def _register( + self, + op_name: str, + impl: typing.Callable[[typing.Sequence[TypedExpr], ops.RowOp], sge.Expression], + ): + if op_name in self._registry: + raise ValueError(f"Operation name {op_name} already registered") + self._registry[op_name] = impl + + +# Singleton compiler +scalar_op_compiler = ScalarOpCompiler() diff --git a/tests/unit/core/compile/sqlglot/expressions/test_op_registration.py b/tests/unit/core/compile/sqlglot/expressions/test_op_registration.py deleted file mode 100644 index 1c49dde6ca..0000000000 --- a/tests/unit/core/compile/sqlglot/expressions/test_op_registration.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -from sqlglot import expressions as sge - -from bigframes.core.compile.sqlglot.expressions import op_registration -from bigframes.operations import numeric_ops - - -def test_register_then_get(): - reg = op_registration.OpRegistration() - input = sge.to_identifier("A") - op = numeric_ops.add_op - - @reg.register(numeric_ops.AddOp) - def test_func(op: numeric_ops.AddOp, input: sge.Expression) -> sge.Expression: - return input - - assert reg[numeric_ops.add_op](op, input) == test_func(op, input) - assert reg[numeric_ops.add_op.name](op, input) == test_func(op, input) - - -def test_register_function_first_argument_is_not_scalar_op_raise_error(): - reg = op_registration.OpRegistration() - - @reg.register(numeric_ops.AddOp) - def test_func(input: sge.Expression) -> sge.Expression: - return input - - with pytest.raises(ValueError, match=r".*first parameter must be an operator.*"): - test_func(sge.to_identifier("A")) From d32d8d211473811685103cb650f25c829b34989b Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Fri, 12 Sep 2025 18:54:00 +0000 Subject: [PATCH 2/3] add unit tests --- .../compile/sqlglot/test_scalar_compiler.py | 189 ++++++++++++++++++ 1 file changed, 189 insertions(+) create mode 100644 tests/unit/core/compile/sqlglot/test_scalar_compiler.py diff --git a/tests/unit/core/compile/sqlglot/test_scalar_compiler.py b/tests/unit/core/compile/sqlglot/test_scalar_compiler.py new file mode 100644 index 0000000000..a2ee2c6331 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/test_scalar_compiler.py @@ -0,0 +1,189 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest.mock as mock + +import pytest +import sqlglot.expressions as sge + +from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr +import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler +import bigframes.operations as ops + + +def test_register_unary_op(): + compiler = scalar_compiler.ScalarOpCompiler() + + class MockUnaryOp(ops.UnaryOp): + name = "mock_unary_op" + + mock_op = MockUnaryOp() + mock_impl = mock.Mock() + + @compiler.register_unary_op(mock_op) + def _(expr: TypedExpr) -> sge.Expression: + mock_impl(expr) + return sge.Identifier(this="output") + + arg = TypedExpr(sge.Identifier(this="input"), "string") + result = compiler.compile_row_op(mock_op, [arg]) + assert result == sge.Identifier(this="output") + mock_impl.assert_called_once_with(arg) + + +def test_register_unary_op_pass_op(): + compiler = scalar_compiler.ScalarOpCompiler() + + class MockUnaryOp(ops.UnaryOp): + name = "mock_unary_op_pass_op" + + mock_op = MockUnaryOp() + mock_impl = mock.Mock() + + @compiler.register_unary_op(mock_op, pass_op=True) + def _(expr: TypedExpr, op: ops.UnaryOp) -> sge.Expression: + mock_impl(expr, op) + return sge.Identifier(this="output") + + arg = TypedExpr(sge.Identifier(this="input"), "string") + result = compiler.compile_row_op(mock_op, [arg]) + assert result == sge.Identifier(this="output") + mock_impl.assert_called_once_with(arg, mock_op) + + +def test_register_binary_op(): + compiler = scalar_compiler.ScalarOpCompiler() + + class MockBinaryOp(ops.BinaryOp): + name = "mock_binary_op" + + mock_op = MockBinaryOp() + mock_impl = mock.Mock() + + @compiler.register_binary_op(mock_op) + def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + mock_impl(left, right) + return sge.Identifier(this="output") + + arg1 = TypedExpr(sge.Identifier(this="input1"), "string") + arg2 = TypedExpr(sge.Identifier(this="input2"), "string") + result = compiler.compile_row_op(mock_op, [arg1, arg2]) + assert result == sge.Identifier(this="output") + mock_impl.assert_called_once_with(arg1, arg2) + + +def test_register_binary_op_pass_on(): + compiler = scalar_compiler.ScalarOpCompiler() + + class MockBinaryOp(ops.BinaryOp): + name = "mock_binary_op_pass_op" + + mock_op = MockBinaryOp() + mock_impl = mock.Mock() + + @compiler.register_binary_op(mock_op, pass_op=True) + def _(left: TypedExpr, right: TypedExpr, op: ops.BinaryOp) -> sge.Expression: + mock_impl(left, right, op) + return sge.Identifier(this="output") + + arg1 = TypedExpr(sge.Identifier(this="input1"), "string") + arg2 = TypedExpr(sge.Identifier(this="input2"), "string") + result = compiler.compile_row_op(mock_op, [arg1, arg2]) + assert result == sge.Identifier(this="output") + mock_impl.assert_called_once_with(arg1, arg2, mock_op) + + +def test_register_ternary_op(): + compiler = scalar_compiler.ScalarOpCompiler() + + class MockTernaryOp(ops.TernaryOp): + name = "mock_ternary_op" + + mock_op = MockTernaryOp() + mock_impl = mock.Mock() + + @compiler.register_ternary_op(mock_op) + def _(arg1: TypedExpr, arg2: TypedExpr, arg3: TypedExpr) -> sge.Expression: + mock_impl(arg1, arg2, arg3) + return sge.Identifier(this="output") + + arg1 = TypedExpr(sge.Identifier(this="input1"), "string") + arg2 = TypedExpr(sge.Identifier(this="input2"), "string") + arg3 = TypedExpr(sge.Identifier(this="input3"), "string") + result = compiler.compile_row_op(mock_op, [arg1, arg2, arg3]) + assert result == sge.Identifier(this="output") + mock_impl.assert_called_once_with(arg1, arg2, arg3) + + +def test_register_nary_op(): + compiler = scalar_compiler.ScalarOpCompiler() + + class MockNaryOp(ops.NaryOp): + name = "mock_nary_op" + + mock_op = MockNaryOp() + mock_impl = mock.Mock() + + @compiler.register_nary_op(mock_op) + def _(*args: TypedExpr) -> sge.Expression: + mock_impl(*args) + return sge.Identifier(this="output") + + arg1 = TypedExpr(sge.Identifier(this="input1"), "string") + arg2 = TypedExpr(sge.Identifier(this="input2"), "string") + result = compiler.compile_row_op(mock_op, [arg1, arg2]) + assert result == sge.Identifier(this="output") + mock_impl.assert_called_once_with(arg1, arg2) + + +def test_register_nary_op_pass_on(): + compiler = scalar_compiler.ScalarOpCompiler() + + class MockNaryOp(ops.NaryOp): + name = "mock_nary_op_pass_op" + + mock_op = MockNaryOp() + mock_impl = mock.Mock() + + @compiler.register_nary_op(mock_op, pass_op=True) + def _(*args: TypedExpr, op: ops.NaryOp) -> sge.Expression: + mock_impl(*args, op=op) + return sge.Identifier(this="output") + + arg1 = TypedExpr(sge.Identifier(this="input1"), "string") + arg2 = TypedExpr(sge.Identifier(this="input2"), "string") + arg3 = TypedExpr(sge.Identifier(this="input3"), "string") + arg4 = TypedExpr(sge.Identifier(this="input4"), "string") + result = compiler.compile_row_op(mock_op, [arg1, arg2, arg3, arg4]) + assert result == sge.Identifier(this="output") + mock_impl.assert_called_once_with(arg1, arg2, arg3, arg4, op=mock_op) + + +def test_register_duplicate_op_raises(): + compiler = scalar_compiler.ScalarOpCompiler() + + class MockUnaryOp(ops.UnaryOp): + name = "mock_unary_op_duplicate" + + mock_op = MockUnaryOp() + + @compiler.register_unary_op(mock_op) + def _(expr: TypedExpr) -> sge.Expression: + return sge.Identifier(this="output") + + with pytest.raises(ValueError): + + @compiler.register_unary_op(mock_op) + def _(expr: TypedExpr) -> sge.Expression: + return sge.Identifier(this="output2") From edd54c93ac0bec506ba97ef79f2d736030e55d36 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Mon, 15 Sep 2025 21:23:51 +0000 Subject: [PATCH 3/3] fix merge conflicts --- bigframes/core/compile/sqlglot/compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py index ba2c644689..40795bbb48 100644 --- a/bigframes/core/compile/sqlglot/compiler.py +++ b/bigframes/core/compile/sqlglot/compiler.py @@ -254,11 +254,11 @@ def compile_isin_join( ) -> ir.SQLGlotIR: conditions = ( typed_expr.TypedExpr( - scalar_compiler.compile_scalar_expression(node.left_col), + scalar_compiler.scalar_op_compiler.compile_expression(node.left_col), node.left_col.output_type, ), typed_expr.TypedExpr( - scalar_compiler.compile_scalar_expression(node.right_col), + scalar_compiler.scalar_op_compiler.compile_expression(node.right_col), node.right_col.output_type, ), )