diff --git a/bigframes/core/compile/polars/compiler.py b/bigframes/core/compile/polars/compiler.py index 6b76f3f53d..40037735d4 100644 --- a/bigframes/core/compile/polars/compiler.py +++ b/bigframes/core/compile/polars/compiler.py @@ -120,9 +120,9 @@ def _( @compile_expression.register def _( self, - expression: ex.SchemaFieldRefExpression, + expression: ex.ResolvedDerefOp, ) -> pl.Expr: - return pl.col(expression.field.id.sql) + return pl.col(expression.id.sql) @compile_expression.register def _( diff --git a/bigframes/core/compile/sqlglot/scalar_compiler.py b/bigframes/core/compile/sqlglot/scalar_compiler.py index f553518300..0db507b0fa 100644 --- a/bigframes/core/compile/sqlglot/scalar_compiler.py +++ b/bigframes/core/compile/sqlglot/scalar_compiler.py @@ -42,13 +42,6 @@ def compile_deref_expression(expr: expression.DerefOp) -> sge.Expression: return sge.ColumnDef(this=sge.to_identifier(expr.id.sql, quoted=True)) -@compile_scalar_expression.register -def compile_field_ref_expression( - expr: expression.SchemaFieldRefExpression, -) -> sge.Expression: - return sge.ColumnDef(this=sge.to_identifier(expr.field.id.sql, quoted=True)) - - @compile_scalar_expression.register def compile_constant_expression( expr: expression.ScalarConstantExpression, diff --git a/bigframes/core/expression.py b/bigframes/core/expression.py index 40ba70c555..7b20e430ff 100644 --- a/bigframes/core/expression.py +++ b/bigframes/core/expression.py @@ -429,55 +429,27 @@ def transform_children(self, t: Callable[[Expression], Expression]) -> Expressio @dataclasses.dataclass(frozen=True) -class SchemaFieldRefExpression(Expression): - """An expression representing a schema field. This is essentially a DerefOp with input schema bound.""" +class ResolvedDerefOp(DerefOp): + """An expression that refers to a column by ID and resolved with schema bound.""" - field: field.Field + dtype: dtypes.Dtype + is_nullable: bool - @property - def column_references(self) -> typing.Tuple[ids.ColumnId, ...]: - return (self.field.id,) - - @property - def is_const(self) -> bool: - return False - - @property - def nullable(self) -> bool: - return self.field.nullable + @classmethod + def from_field(cls, f: field.Field): + return cls(id=f.id, dtype=f.dtype, is_nullable=f.nullable) @property def is_resolved(self) -> bool: return True @property - def output_type(self) -> dtypes.ExpressionType: - return self.field.dtype - - def bind_variables( - self, bindings: Mapping[str, Expression], allow_partial_bindings: bool = False - ) -> Expression: - return self - - def bind_refs( - self, - bindings: Mapping[ids.ColumnId, Expression], - allow_partial_bindings: bool = False, - ) -> Expression: - if self.field.id in bindings.keys(): - return bindings[self.field.id] - return self - - @property - def is_bijective(self) -> bool: - return True + def nullable(self) -> bool: + return self.is_nullable @property - def is_identity(self) -> bool: - return True - - def transform_children(self, t: Callable[[Expression], Expression]) -> Expression: - return self + def output_type(self) -> dtypes.ExpressionType: + return self.dtype @dataclasses.dataclass(frozen=True) @@ -589,7 +561,7 @@ def bind_schema_fields( return expr expr_by_id = { - id: SchemaFieldRefExpression(field) for id, field in field_by_id.items() + id: ResolvedDerefOp.from_field(field) for id, field in field_by_id.items() } return expr.bind_refs(expr_by_id) diff --git a/bigframes/core/rewrite/schema_binding.py b/bigframes/core/rewrite/schema_binding.py index aa5cb986b9..af0593211c 100644 --- a/bigframes/core/rewrite/schema_binding.py +++ b/bigframes/core/rewrite/schema_binding.py @@ -52,4 +52,17 @@ def bind_schema_to_node( return dataclasses.replace(node, by=tuple(bound_bys)) + if isinstance(node, nodes.JoinNode): + conditions = tuple( + ( + ex.ResolvedDerefOp.from_field(node.left_child.field_by_id[left.id]), + ex.ResolvedDerefOp.from_field(node.right_child.field_by_id[right.id]), + ) + for left, right in node.conditions + ) + return dataclasses.replace( + node, + conditions=conditions, + ) + return node diff --git a/tests/unit/core/test_expression.py b/tests/unit/core/test_expression.py index 9534c8605a..4c3d233879 100644 --- a/tests/unit/core/test_expression.py +++ b/tests/unit/core/test_expression.py @@ -77,8 +77,8 @@ def test_deref_op_dtype_resolution(): def test_field_ref_expr_dtype_resolution_short_circuit(): - expression = ex.SchemaFieldRefExpression( - field.Field(ids.ColumnId("mycol"), dtype=dtypes.INT_DTYPE) + expression = ex.ResolvedDerefOp( + id=ids.ColumnId("mycol"), dtype=dtypes.INT_DTYPE, is_nullable=True ) field_bindings = _create_field_bindings({"anotherCol": dtypes.STRING_DTYPE})