|
19 | 19 | import functools |
20 | 20 | import itertools |
21 | 21 | import typing |
22 | | -from typing import Generator, Mapping, TypeVar, Union |
| 22 | +from typing import Callable, Generator, Mapping, TypeVar, Union |
23 | 23 |
|
24 | 24 | import pandas as pd |
25 | 25 |
|
@@ -249,6 +249,10 @@ def is_identity(self) -> bool: |
249 | 249 | """True for identity operation that does not transform input.""" |
250 | 250 | return False |
251 | 251 |
|
| 252 | + @abc.abstractmethod |
| 253 | + def transform_children(self, t: Callable[[Expression], Expression]) -> Expression: |
| 254 | + ... |
| 255 | + |
252 | 256 | def walk(self) -> Generator[Expression, None, None]: |
253 | 257 | yield self |
254 | 258 | for child in self.children: |
@@ -311,6 +315,9 @@ def __eq__(self, other): |
311 | 315 |
|
312 | 316 | return self.value == other.value and self.dtype == other.dtype |
313 | 317 |
|
| 318 | + def transform_children(self, t: Callable[[Expression], Expression]) -> Expression: |
| 319 | + return self |
| 320 | + |
314 | 321 |
|
315 | 322 | @dataclasses.dataclass(frozen=True) |
316 | 323 | class UnboundVariableExpression(Expression): |
@@ -362,6 +369,9 @@ def is_bijective(self) -> bool: |
362 | 369 | def is_identity(self) -> bool: |
363 | 370 | return True |
364 | 371 |
|
| 372 | + def transform_children(self, t: Callable[[Expression], Expression]) -> Expression: |
| 373 | + return self |
| 374 | + |
365 | 375 |
|
366 | 376 | @dataclasses.dataclass(frozen=True) |
367 | 377 | class DerefOp(Expression): |
@@ -414,6 +424,9 @@ def is_bijective(self) -> bool: |
414 | 424 | def is_identity(self) -> bool: |
415 | 425 | return True |
416 | 426 |
|
| 427 | + def transform_children(self, t: Callable[[Expression], Expression]) -> Expression: |
| 428 | + return self |
| 429 | + |
417 | 430 |
|
418 | 431 | @dataclasses.dataclass(frozen=True) |
419 | 432 | class SchemaFieldRefExpression(Expression): |
@@ -463,12 +476,15 @@ def is_bijective(self) -> bool: |
463 | 476 | def is_identity(self) -> bool: |
464 | 477 | return True |
465 | 478 |
|
| 479 | + def transform_children(self, t: Callable[[Expression], Expression]) -> Expression: |
| 480 | + return self |
| 481 | + |
466 | 482 |
|
467 | 483 | @dataclasses.dataclass(frozen=True) |
468 | 484 | class OpExpression(Expression): |
469 | 485 | """An expression representing a scalar operation applied to 1 or more argument sub-expressions.""" |
470 | 486 |
|
471 | | - op: bigframes.operations.RowOp |
| 487 | + op: bigframes.operations.ScalarOp |
472 | 488 | inputs: typing.Tuple[Expression, ...] |
473 | 489 |
|
474 | 490 | @property |
@@ -553,6 +569,12 @@ def deterministic(self) -> bool: |
553 | 569 | all(input.deterministic for input in self.inputs) and self.op.deterministic |
554 | 570 | ) |
555 | 571 |
|
| 572 | + def transform_children(self, t: Callable[[Expression], Expression]) -> Expression: |
| 573 | + new_inputs = tuple(t(input) for input in self.inputs) |
| 574 | + if new_inputs != self.inputs: |
| 575 | + return dataclasses.replace(self, inputs=new_inputs) |
| 576 | + return self |
| 577 | + |
556 | 578 |
|
557 | 579 | def bind_schema_fields( |
558 | 580 | expr: Expression, field_by_id: Mapping[ids.ColumnId, field.Field] |
|
0 commit comments