From ef6d580249deffca2998b763d3b4572ad271a630 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Mon, 30 Jun 2025 21:46:41 +0000 Subject: [PATCH 1/3] refactor: subclass DerefOp as ResolvedDerefOp --- bigframes/core/compile/polars/compiler.py | 2 +- .../core/compile/sqlglot/scalar_compiler.py | 7 ---- bigframes/core/expression.py | 32 +++++++------------ bigframes/core/rewrite/schema_binding.py | 20 ++++++++++++ tests/unit/core/test_expression.py | 2 +- 5 files changed, 34 insertions(+), 29 deletions(-) diff --git a/bigframes/core/compile/polars/compiler.py b/bigframes/core/compile/polars/compiler.py index 6b76f3f53d..78fa0d2c48 100644 --- a/bigframes/core/compile/polars/compiler.py +++ b/bigframes/core/compile/polars/compiler.py @@ -120,7 +120,7 @@ def _( @compile_expression.register def _( self, - expression: ex.SchemaFieldRefExpression, + expression: ex.ResolvedDerefOp, ) -> pl.Expr: return pl.col(expression.field.id.sql) diff --git a/bigframes/core/compile/sqlglot/scalar_compiler.py b/bigframes/core/compile/sqlglot/scalar_compiler.py index f553518300..0db507b0fa 100644 --- a/bigframes/core/compile/sqlglot/scalar_compiler.py +++ b/bigframes/core/compile/sqlglot/scalar_compiler.py @@ -42,13 +42,6 @@ def compile_deref_expression(expr: expression.DerefOp) -> sge.Expression: return sge.ColumnDef(this=sge.to_identifier(expr.id.sql, quoted=True)) -@compile_scalar_expression.register -def compile_field_ref_expression( - expr: expression.SchemaFieldRefExpression, -) -> sge.Expression: - return sge.ColumnDef(this=sge.to_identifier(expr.field.id.sql, quoted=True)) - - @compile_scalar_expression.register def compile_constant_expression( expr: expression.ScalarConstantExpression, diff --git a/bigframes/core/expression.py b/bigframes/core/expression.py index 40ba70c555..642f413aac 100644 --- a/bigframes/core/expression.py +++ b/bigframes/core/expression.py @@ -429,19 +429,23 @@ def transform_children(self, t: Callable[[Expression], Expression]) -> Expressio @dataclasses.dataclass(frozen=True) -class SchemaFieldRefExpression(Expression): - """An expression representing a schema field. This is essentially a DerefOp with input schema bound.""" +class ResolvedDerefOp(DerefOp): + """An expression that refers to a column by ID and resolved with schema bound.""" field: field.Field + # Re-declare 'id' from the parent to remove it from the __init__ method + id: ids.ColumnId = dataclasses.field(init=False) + + def __post_init__(self): + # Initialize the parent's 'id' field after the object is created. + # We must use object.__setattr__ because the dataclass is frozen. + object.__setattr__(self, "id", self.field.id) + @property def column_references(self) -> typing.Tuple[ids.ColumnId, ...]: return (self.field.id,) - @property - def is_const(self) -> bool: - return False - @property def nullable(self) -> bool: return self.field.nullable @@ -464,21 +468,11 @@ def bind_refs( bindings: Mapping[ids.ColumnId, Expression], allow_partial_bindings: bool = False, ) -> Expression: + # TODO: Check if we can remove. if self.field.id in bindings.keys(): return bindings[self.field.id] return self - @property - def is_bijective(self) -> bool: - return True - - @property - def is_identity(self) -> bool: - return True - - def transform_children(self, t: Callable[[Expression], Expression]) -> Expression: - return self - @dataclasses.dataclass(frozen=True) class OpExpression(Expression): @@ -588,9 +582,7 @@ def bind_schema_fields( if expr.is_resolved: return expr - expr_by_id = { - id: SchemaFieldRefExpression(field) for id, field in field_by_id.items() - } + expr_by_id = {id: ResolvedDerefOp(field) for id, field in field_by_id.items()} return expr.bind_refs(expr_by_id) diff --git a/bigframes/core/rewrite/schema_binding.py b/bigframes/core/rewrite/schema_binding.py index aa5cb986b9..a713df234f 100644 --- a/bigframes/core/rewrite/schema_binding.py +++ b/bigframes/core/rewrite/schema_binding.py @@ -13,6 +13,7 @@ # limitations under the License. import dataclasses +import typing from bigframes.core import bigframe_node from bigframes.core import expression as ex @@ -52,4 +53,23 @@ def bind_schema_to_node( return dataclasses.replace(node, by=tuple(bound_bys)) + if isinstance(node, nodes.JoinNode): + conditions = tuple( + ( + typing.cast( + ex.ResolvedDerefOp, + ex.bind_schema_fields(left, node.left_child.field_by_id), + ), + typing.cast( + ex.ResolvedDerefOp, + ex.bind_schema_fields(right, node.right_child.field_by_id), + ), + ) + for left, right in node.conditions + ) + return dataclasses.replace( + node, + conditions=conditions, + ) + return node diff --git a/tests/unit/core/test_expression.py b/tests/unit/core/test_expression.py index 9534c8605a..c3f91cbb59 100644 --- a/tests/unit/core/test_expression.py +++ b/tests/unit/core/test_expression.py @@ -77,7 +77,7 @@ def test_deref_op_dtype_resolution(): def test_field_ref_expr_dtype_resolution_short_circuit(): - expression = ex.SchemaFieldRefExpression( + expression = ex.ResolvedDerefOp( field.Field(ids.ColumnId("mycol"), dtype=dtypes.INT_DTYPE) ) field_bindings = _create_field_bindings({"anotherCol": dtypes.STRING_DTYPE}) From 294510f0cbff3a9abadcc0ef7d24454c5eb643be Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Mon, 30 Jun 2025 22:28:35 +0000 Subject: [PATCH 2/3] replace the `field` attribute by id, dtype, nullable --- bigframes/core/compile/polars/compiler.py | 2 +- bigframes/core/expression.py | 31 ++++++++--------------- tests/unit/core/test_expression.py | 4 +-- 3 files changed, 13 insertions(+), 24 deletions(-) diff --git a/bigframes/core/compile/polars/compiler.py b/bigframes/core/compile/polars/compiler.py index 78fa0d2c48..40037735d4 100644 --- a/bigframes/core/compile/polars/compiler.py +++ b/bigframes/core/compile/polars/compiler.py @@ -122,7 +122,7 @@ def _( self, expression: ex.ResolvedDerefOp, ) -> pl.Expr: - return pl.col(expression.field.id.sql) + return pl.col(expression.id.sql) @compile_expression.register def _( diff --git a/bigframes/core/expression.py b/bigframes/core/expression.py index 642f413aac..726d0fbc65 100644 --- a/bigframes/core/expression.py +++ b/bigframes/core/expression.py @@ -432,23 +432,12 @@ def transform_children(self, t: Callable[[Expression], Expression]) -> Expressio class ResolvedDerefOp(DerefOp): """An expression that refers to a column by ID and resolved with schema bound.""" - field: field.Field + dtype: dtypes.Dtype + nullable: bool = True - # Re-declare 'id' from the parent to remove it from the __init__ method - id: ids.ColumnId = dataclasses.field(init=False) - - def __post_init__(self): - # Initialize the parent's 'id' field after the object is created. - # We must use object.__setattr__ because the dataclass is frozen. - object.__setattr__(self, "id", self.field.id) - - @property - def column_references(self) -> typing.Tuple[ids.ColumnId, ...]: - return (self.field.id,) - - @property - def nullable(self) -> bool: - return self.field.nullable + @classmethod + def from_field(cls, f: field.Field): + return cls(f.id, f.dtype, f.nullable) @property def is_resolved(self) -> bool: @@ -456,7 +445,7 @@ def is_resolved(self) -> bool: @property def output_type(self) -> dtypes.ExpressionType: - return self.field.dtype + return self.dtype def bind_variables( self, bindings: Mapping[str, Expression], allow_partial_bindings: bool = False @@ -469,8 +458,8 @@ def bind_refs( allow_partial_bindings: bool = False, ) -> Expression: # TODO: Check if we can remove. - if self.field.id in bindings.keys(): - return bindings[self.field.id] + if self.id in bindings.keys(): + return bindings[self.id] return self @@ -582,7 +571,9 @@ def bind_schema_fields( if expr.is_resolved: return expr - expr_by_id = {id: ResolvedDerefOp(field) for id, field in field_by_id.items()} + expr_by_id = { + id: ResolvedDerefOp.from_field(field) for id, field in field_by_id.items() + } return expr.bind_refs(expr_by_id) diff --git a/tests/unit/core/test_expression.py b/tests/unit/core/test_expression.py index c3f91cbb59..7e18475b37 100644 --- a/tests/unit/core/test_expression.py +++ b/tests/unit/core/test_expression.py @@ -77,9 +77,7 @@ def test_deref_op_dtype_resolution(): def test_field_ref_expr_dtype_resolution_short_circuit(): - expression = ex.ResolvedDerefOp( - field.Field(ids.ColumnId("mycol"), dtype=dtypes.INT_DTYPE) - ) + expression = ex.ResolvedDerefOp(id=ids.ColumnId("mycol"), dtype=dtypes.INT_DTYPE) field_bindings = _create_field_bindings({"anotherCol": dtypes.STRING_DTYPE}) result = ex.bind_schema_fields(expression, field_bindings) From 0f20a6a180b71ca1afd598ca4767785cb5207763 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 1 Jul 2025 17:13:06 +0000 Subject: [PATCH 3/3] final cleanup --- bigframes/core/expression.py | 23 ++++++----------------- bigframes/core/rewrite/schema_binding.py | 11 ++--------- tests/unit/core/test_expression.py | 4 +++- 3 files changed, 11 insertions(+), 27 deletions(-) diff --git a/bigframes/core/expression.py b/bigframes/core/expression.py index 726d0fbc65..7b20e430ff 100644 --- a/bigframes/core/expression.py +++ b/bigframes/core/expression.py @@ -433,35 +433,24 @@ class ResolvedDerefOp(DerefOp): """An expression that refers to a column by ID and resolved with schema bound.""" dtype: dtypes.Dtype - nullable: bool = True + is_nullable: bool @classmethod def from_field(cls, f: field.Field): - return cls(f.id, f.dtype, f.nullable) + return cls(id=f.id, dtype=f.dtype, is_nullable=f.nullable) @property def is_resolved(self) -> bool: return True + @property + def nullable(self) -> bool: + return self.is_nullable + @property def output_type(self) -> dtypes.ExpressionType: return self.dtype - def bind_variables( - self, bindings: Mapping[str, Expression], allow_partial_bindings: bool = False - ) -> Expression: - return self - - def bind_refs( - self, - bindings: Mapping[ids.ColumnId, Expression], - allow_partial_bindings: bool = False, - ) -> Expression: - # TODO: Check if we can remove. - if self.id in bindings.keys(): - return bindings[self.id] - return self - @dataclasses.dataclass(frozen=True) class OpExpression(Expression): diff --git a/bigframes/core/rewrite/schema_binding.py b/bigframes/core/rewrite/schema_binding.py index a713df234f..af0593211c 100644 --- a/bigframes/core/rewrite/schema_binding.py +++ b/bigframes/core/rewrite/schema_binding.py @@ -13,7 +13,6 @@ # limitations under the License. import dataclasses -import typing from bigframes.core import bigframe_node from bigframes.core import expression as ex @@ -56,14 +55,8 @@ def bind_schema_to_node( if isinstance(node, nodes.JoinNode): conditions = tuple( ( - typing.cast( - ex.ResolvedDerefOp, - ex.bind_schema_fields(left, node.left_child.field_by_id), - ), - typing.cast( - ex.ResolvedDerefOp, - ex.bind_schema_fields(right, node.right_child.field_by_id), - ), + ex.ResolvedDerefOp.from_field(node.left_child.field_by_id[left.id]), + ex.ResolvedDerefOp.from_field(node.right_child.field_by_id[right.id]), ) for left, right in node.conditions ) diff --git a/tests/unit/core/test_expression.py b/tests/unit/core/test_expression.py index 7e18475b37..4c3d233879 100644 --- a/tests/unit/core/test_expression.py +++ b/tests/unit/core/test_expression.py @@ -77,7 +77,9 @@ def test_deref_op_dtype_resolution(): def test_field_ref_expr_dtype_resolution_short_circuit(): - expression = ex.ResolvedDerefOp(id=ids.ColumnId("mycol"), dtype=dtypes.INT_DTYPE) + expression = ex.ResolvedDerefOp( + id=ids.ColumnId("mycol"), dtype=dtypes.INT_DTYPE, is_nullable=True + ) field_bindings = _create_field_bindings({"anotherCol": dtypes.STRING_DTYPE}) result = ex.bind_schema_fields(expression, field_bindings)