Skip to content

Commit 5ff27a9

Browse files
committed
refactor: subclass DerefOp as ResolvedDerefOp
1 parent bb98178 commit 5ff27a9

File tree

5 files changed, with 34 additions and 29 deletions.

bigframes/core/compile/polars/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,7 @@ def _(
120120
@compile_expression.register
121121
def _(
122122
self,
123-
expression: ex.SchemaFieldRefExpression,
123+
expression: ex.ResolvedDerefOp,
124124
) -> pl.Expr:
125125
return pl.col(expression.field.id.sql)
126126

bigframes/core/compile/sqlglot/scalar_compiler.py

Lines changed: 0 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -42,13 +42,6 @@ def compile_deref_expression(expr: expression.DerefOp) -> sge.Expression:
4242
return sge.ColumnDef(this=sge.to_identifier(expr.id.sql, quoted=True))
4343

4444

45-
@compile_scalar_expression.register
46-
def compile_field_ref_expression(
47-
expr: expression.SchemaFieldRefExpression,
48-
) -> sge.Expression:
49-
return sge.ColumnDef(this=sge.to_identifier(expr.field.id.sql, quoted=True))
50-
51-
5245
@compile_scalar_expression.register
5346
def compile_constant_expression(
5447
expr: expression.ScalarConstantExpression,

bigframes/core/expression.py

Lines changed: 12 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -429,19 +429,23 @@ def transform_children(self, t: Callable[[Expression], Expression]) -> Expressio
429429

430430

431431
@dataclasses.dataclass(frozen=True)
432-
class SchemaFieldRefExpression(Expression):
433-
"""An expression representing a schema field. This is essentially a DerefOp with input schema bound."""
432+
class ResolvedDerefOp(DerefOp):
433+
"""An expression that refers to a column by ID and resolved with schema bound."""
434434

435435
field: field.Field
436436

437+
# Re-declare 'id' from the parent to remove it from the __init__ method
438+
id: ids.ColumnId = dataclasses.field(init=False)
439+
440+
def __post_init__(self):
441+
# Initialize the parent's 'id' field after the object is created.
442+
# We must use object.__setattr__ because the dataclass is frozen.
443+
object.__setattr__(self, "id", self.field.id)
444+
437445
@property
438446
def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
439447
return (self.field.id,)
440448

441-
@property
442-
def is_const(self) -> bool:
443-
return False
444-
445449
@property
446450
def nullable(self) -> bool:
447451
return self.field.nullable
@@ -464,21 +468,11 @@ def bind_refs(
464468
bindings: Mapping[ids.ColumnId, Expression],
465469
allow_partial_bindings: bool = False,
466470
) -> Expression:
471+
# TODO: Check if we can remove.
467472
if self.field.id in bindings.keys():
468473
return bindings[self.field.id]
469474
return self
470475

471-
@property
472-
def is_bijective(self) -> bool:
473-
return True
474-
475-
@property
476-
def is_identity(self) -> bool:
477-
return True
478-
479-
def transform_children(self, t: Callable[[Expression], Expression]) -> Expression:
480-
return self
481-
482476

483477
@dataclasses.dataclass(frozen=True)
484478
class OpExpression(Expression):
@@ -588,9 +582,7 @@ def bind_schema_fields(
588582
if expr.is_resolved:
589583
return expr
590584

591-
expr_by_id = {
592-
id: SchemaFieldRefExpression(field) for id, field in field_by_id.items()
593-
}
585+
expr_by_id = {id: ResolvedDerefOp(field) for id, field in field_by_id.items()}
594586
return expr.bind_refs(expr_by_id)
595587

596588

bigframes/core/rewrite/schema_binding.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import dataclasses
16+
import typing
1617

1718
from bigframes.core import bigframe_node
1819
from bigframes.core import expression as ex
@@ -52,4 +53,23 @@ def bind_schema_to_node(
5253

5354
return dataclasses.replace(node, by=tuple(bound_bys))
5455

56+
if isinstance(node, nodes.JoinNode):
57+
conditions = tuple(
58+
(
59+
typing.cast(
60+
ex.ResolvedDerefOp,
61+
ex.bind_schema_fields(left, node.left_child.field_by_id),
62+
),
63+
typing.cast(
64+
ex.ResolvedDerefOp,
65+
ex.bind_schema_fields(right, node.right_child.field_by_id),
66+
),
67+
)
68+
for left, right in node.conditions
69+
)
70+
return dataclasses.replace(
71+
node,
72+
conditions=conditions,
73+
)
74+
5575
return node

tests/unit/core/test_expression.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,7 @@ def test_deref_op_dtype_resolution():
7777

7878

7979
def test_field_ref_expr_dtype_resolution_short_circuit():
80-
expression = ex.SchemaFieldRefExpression(
80+
expression = ex.ResolvedDerefOp(
8181
field.Field(ids.ColumnId("mycol"), dtype=dtypes.INT_DTYPE)
8282
)
8383
field_bindings = _create_field_bindings({"anotherCol": dtypes.STRING_DTYPE})

0 commit comments

Comments
 (0)