diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py index b6483689dc..0d20509877 100644 --- a/bigframes/core/nodes.py +++ b/bigframes/core/nodes.py @@ -300,7 +300,15 @@ def remap_vars( def remap_refs( self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId] ) -> InNode: - return dataclasses.replace(self, left_col=self.left_col.remap_column_refs(mappings, allow_partial_bindings=True), right_col=self.right_col.remap_column_refs(mappings, allow_partial_bindings=True)) # type: ignore + return dataclasses.replace( + self, + left_col=self.left_col.remap_column_refs( + mappings, allow_partial_bindings=True + ), + right_col=self.right_col.remap_column_refs( + mappings, allow_partial_bindings=True + ), + ) # type: ignore @dataclasses.dataclass(frozen=True, eq=False) diff --git a/bigframes/core/rewrite/identifiers.py b/bigframes/core/rewrite/identifiers.py index 0093e183b4..e911d81895 100644 --- a/bigframes/core/rewrite/identifiers.py +++ b/bigframes/core/rewrite/identifiers.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import annotations +import dataclasses import typing from bigframes.core import identifiers, nodes @@ -26,32 +27,68 @@ def remap_variables( nodes.BigFrameNode, dict[identifiers.ColumnId, identifiers.ColumnId], ]: - """Remaps `ColumnId`s in the BFET to produce deterministic and sequential UIDs. + """Remaps `ColumnId`s in the expression tree to be deterministic and sequential. - Note: this will convert a DAG to a tree. + This function performs a post-order traversal. It recursively remaps children + nodes first, then remaps the current node's references and definitions. + + Note: this will convert a DAG to a tree by duplicating shared nodes. + + Args: + root: The root node of the expression tree. + id_generator: An iterator that yields new column IDs. + + Returns: + A tuple of the new root node and a mapping from old to new column IDs + visible to the parent node. """ - child_replacement_map = dict() - ref_mapping = dict() - # Sequential ids are assigned bottom-up left-to-right + # Step 1: Recursively remap children to get their new nodes and ID mappings. + new_child_nodes: list[nodes.BigFrameNode] = [] + new_child_mappings: list[dict[identifiers.ColumnId, identifiers.ColumnId]] = [] for child in root.child_nodes: - new_child, child_var_mapping = remap_variables(child, id_generator=id_generator) - child_replacement_map[child] = new_child - ref_mapping.update(child_var_mapping) - - # This is actually invalid until we've replaced all of children, refs and var defs - with_new_children = root.transform_children( - lambda node: child_replacement_map[node] - ) - - with_new_refs = with_new_children.remap_refs(ref_mapping) - - node_var_mapping = {old_id: next(id_generator) for old_id in root.node_defined_ids} - with_new_vars = with_new_refs.remap_vars(node_var_mapping) - with_new_vars._validate() - - return ( - with_new_vars, - node_var_mapping - if root.defines_namespace - else (ref_mapping | node_var_mapping), - ) + new_child, child_mappings = remap_variables(child, id_generator=id_generator) + new_child_nodes.append(new_child) + new_child_mappings.append(child_mappings) + + # Step 2: Transform children to use their new nodes. + remapped_children: dict[nodes.BigFrameNode, nodes.BigFrameNode] = { + child: new_child for child, new_child in zip(root.child_nodes, new_child_nodes) + } + new_root = root.transform_children(lambda node: remapped_children[node]) + + # Step 3: Transform the current node using the mappings from its children. + downstream_mappings: dict[identifiers.ColumnId, identifiers.ColumnId] = { + k: v for mapping in new_child_mappings for k, v in mapping.items() + } + if isinstance(new_root, nodes.InNode): + new_root = typing.cast(nodes.InNode, new_root) + new_root = dataclasses.replace( + new_root, + left_col=new_root.left_col.remap_column_refs( + new_child_mappings[0], allow_partial_bindings=True + ), + right_col=new_root.right_col.remap_column_refs( + new_child_mappings[1], allow_partial_bindings=True + ), + ) + else: + new_root = new_root.remap_refs(downstream_mappings) + + # Step 4: Create new IDs for columns defined by the current node. + node_defined_mappings = { + old_id: next(id_generator) for old_id in root.node_defined_ids + } + new_root = new_root.remap_vars(node_defined_mappings) + + new_root._validate() + + # Step 5: Determine which mappings to propagate up to the parent. + if root.defines_namespace: + # If a node defines a new namespace (e.g., a join), mappings from its + # children are not visible to its parents. + mappings_for_parent = node_defined_mappings + else: + # Otherwise, pass up the combined mappings from children and the current node. + mappings_for_parent = downstream_mappings | node_defined_mappings + + return new_root, mappings_for_parent diff --git a/tests/unit/core/rewrite/conftest.py b/tests/unit/core/rewrite/conftest.py index 22b897f3bf..bbfbde46f3 100644 --- a/tests/unit/core/rewrite/conftest.py +++ b/tests/unit/core/rewrite/conftest.py @@ -34,7 +34,32 @@ @pytest.fixture def table(): - return TABLE + table_ref = google.cloud.bigquery.TableReference.from_string( + "project.dataset.table" + ) + schema = ( + google.cloud.bigquery.SchemaField("col_a", "INTEGER"), + google.cloud.bigquery.SchemaField("col_b", "INTEGER"), + ) + return google.cloud.bigquery.Table( + table_ref=table_ref, + schema=schema, + ) + + +@pytest.fixture +def table_too(): + table_ref = google.cloud.bigquery.TableReference.from_string( + "project.dataset.table_too" + ) + schema = ( + google.cloud.bigquery.SchemaField("col_a", "INTEGER"), + google.cloud.bigquery.SchemaField("col_c", "INTEGER"), + ) + return google.cloud.bigquery.Table( + table_ref=table_ref, + schema=schema, + ) @pytest.fixture @@ -49,3 +74,12 @@ def leaf(fake_session, table): table=table, schema=bigframes.core.schema.ArraySchema.from_bq_table(table), ).node + + +@pytest.fixture +def leaf_too(fake_session, table_too): + return core.ArrayValue.from_table( + session=fake_session, + table=table_too, + schema=bigframes.core.schema.ArraySchema.from_bq_table(table_too), + ).node diff --git a/tests/unit/core/rewrite/test_identifiers.py b/tests/unit/core/rewrite/test_identifiers.py index fd12df60a8..f95cd696d0 100644 --- a/tests/unit/core/rewrite/test_identifiers.py +++ b/tests/unit/core/rewrite/test_identifiers.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import typing import bigframes.core as core +import bigframes.core.expression as ex import bigframes.core.identifiers as identifiers import bigframes.core.nodes as nodes import bigframes.core.rewrite.identifiers as id_rewrite @@ -130,3 +132,24 @@ def test_remap_variables_concat_self_stability(leaf): assert new_node1 == new_node2 assert mapping1 == mapping2 + + +def test_remap_variables_in_node_converts_dag_to_tree(leaf, leaf_too): + # Create an InNode with the same child twice, should create a tree from a DAG + node = nodes.InNode( + left_child=leaf, + right_child=leaf_too, + left_col=ex.DerefOp(identifiers.ColumnId("col_a")), + right_col=ex.DerefOp(identifiers.ColumnId("col_a")), + indicator_col=identifiers.ColumnId("indicator"), + ) + + id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100)) + new_node, _ = id_rewrite.remap_variables(node, id_generator) + new_node = typing.cast(nodes.InNode, new_node) + + left_col_id = new_node.left_col.id.name + right_col_id = new_node.right_col.id.name + assert left_col_id.startswith("id_") + assert right_col_id.startswith("id_") + assert left_col_id != right_col_id