Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion bigframes/core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,15 @@ def remap_vars(
def remap_refs(
    self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
) -> InNode:
    """Return a copy of this node with its column references remapped.

    Both ``left_col`` and ``right_col`` are rewritten through their
    ``remap_column_refs`` method, called with ``allow_partial_bindings=True``.

    Args:
        mappings: Old-to-new column id mapping applied to both references.

    Returns:
        A new node built with ``dataclasses.replace``; ``self`` is left as is.
    """
    # The same mapping is applied to both sides of the node.
    left = self.left_col.remap_column_refs(mappings, allow_partial_bindings=True)
    right = self.right_col.remap_column_refs(mappings, allow_partial_bindings=True)
    return dataclasses.replace(self, left_col=left, right_col=right)  # type: ignore


@dataclasses.dataclass(frozen=True, eq=False)
Expand Down
89 changes: 63 additions & 26 deletions bigframes/core/rewrite/identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import annotations

import dataclasses
import typing

from bigframes.core import identifiers, nodes
Expand All @@ -26,32 +27,68 @@ def remap_variables(
nodes.BigFrameNode,
dict[identifiers.ColumnId, identifiers.ColumnId],
]:
"""Remaps `ColumnId`s in the BFET to produce deterministic and sequential UIDs.
"""Remaps `ColumnId`s in the expression tree to be deterministic and sequential.

Note: this will convert a DAG to a tree.
This function performs a post-order traversal. It recursively remaps children
nodes first, then remaps the current node's references and definitions.

Note: this will convert a DAG to a tree by duplicating shared nodes.

Args:
root: The root node of the expression tree.
id_generator: An iterator that yields new column IDs.

Returns:
A tuple of the new root node and a mapping from old to new column IDs
visible to the parent node.
"""
child_replacement_map = dict()
ref_mapping = dict()
# Sequential ids are assigned bottom-up left-to-right
# Step 1: Recursively remap children to get their new nodes and ID mappings.
new_child_nodes: list[nodes.BigFrameNode] = []
new_child_mappings: list[dict[identifiers.ColumnId, identifiers.ColumnId]] = []
for child in root.child_nodes:
new_child, child_var_mapping = remap_variables(child, id_generator=id_generator)
child_replacement_map[child] = new_child
ref_mapping.update(child_var_mapping)

# This is actually invalid until we've replaced all of children, refs and var defs
with_new_children = root.transform_children(
lambda node: child_replacement_map[node]
)

with_new_refs = with_new_children.remap_refs(ref_mapping)

node_var_mapping = {old_id: next(id_generator) for old_id in root.node_defined_ids}
with_new_vars = with_new_refs.remap_vars(node_var_mapping)
with_new_vars._validate()

return (
with_new_vars,
node_var_mapping
if root.defines_namespace
else (ref_mapping | node_var_mapping),
)
new_child, child_mappings = remap_variables(child, id_generator=id_generator)
new_child_nodes.append(new_child)
new_child_mappings.append(child_mappings)

# Step 2: Transform children to use their new nodes.
remapped_children: dict[nodes.BigFrameNode, nodes.BigFrameNode] = {
child: new_child for child, new_child in zip(root.child_nodes, new_child_nodes)
}
new_root = root.transform_children(lambda node: remapped_children[node])

# Step 3: Transform the current node using the mappings from its children.
downstream_mappings: dict[identifiers.ColumnId, identifiers.ColumnId] = {
k: v for mapping in new_child_mappings for k, v in mapping.items()
}
if isinstance(new_root, nodes.InNode):
new_root = typing.cast(nodes.InNode, new_root)
new_root = dataclasses.replace(
new_root,
left_col=new_root.left_col.remap_column_refs(
new_child_mappings[0], allow_partial_bindings=True
),
right_col=new_root.right_col.remap_column_refs(
new_child_mappings[1], allow_partial_bindings=True
),
)
else:
new_root = new_root.remap_refs(downstream_mappings)

# Step 4: Create new IDs for columns defined by the current node.
node_defined_mappings = {
old_id: next(id_generator) for old_id in root.node_defined_ids
}
new_root = new_root.remap_vars(node_defined_mappings)

new_root._validate()

# Step 5: Determine which mappings to propagate up to the parent.
if root.defines_namespace:
# If a node defines a new namespace (e.g., a join), mappings from its
# children are not visible to its parents.
mappings_for_parent = node_defined_mappings
else:
# Otherwise, pass up the combined mappings from children and the current node.
mappings_for_parent = downstream_mappings | node_defined_mappings

return new_root, mappings_for_parent
36 changes: 35 additions & 1 deletion tests/unit/core/rewrite/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,32 @@

@pytest.fixture
def table():
    """Build the BigQuery table fixture ``project.dataset.table``.

    The table is constructed explicitly with two INTEGER columns, ``col_a``
    and ``col_b``.
    """
    table_ref = google.cloud.bigquery.TableReference.from_string(
        "project.dataset.table"
    )
    schema = (
        google.cloud.bigquery.SchemaField("col_a", "INTEGER"),
        google.cloud.bigquery.SchemaField("col_b", "INTEGER"),
    )
    return google.cloud.bigquery.Table(
        table_ref=table_ref,
        schema=schema,
    )


@pytest.fixture
def table_too():
    """Build a second BigQuery table fixture, ``project.dataset.table_too``.

    Its schema has two INTEGER columns, ``col_a`` and ``col_c``.
    """
    bq = google.cloud.bigquery
    reference = bq.TableReference.from_string("project.dataset.table_too")
    columns = tuple(bq.SchemaField(name, "INTEGER") for name in ("col_a", "col_c"))
    return bq.Table(table_ref=reference, schema=columns)


@pytest.fixture
Expand All @@ -49,3 +74,12 @@ def leaf(fake_session, table):
table=table,
schema=bigframes.core.schema.ArraySchema.from_bq_table(table),
).node


@pytest.fixture
def leaf_too(fake_session, table_too):
    """Return the node of an ``ArrayValue`` built from the ``table_too`` fixture.

    The array schema is derived from the same BigQuery table object that is
    passed to ``ArrayValue.from_table``.
    """
    array_schema = bigframes.core.schema.ArraySchema.from_bq_table(table_too)
    array_value = core.ArrayValue.from_table(
        session=fake_session,
        table=table_too,
        schema=array_schema,
    )
    return array_value.node
23 changes: 23 additions & 0 deletions tests/unit/core/rewrite/test_identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing

import bigframes.core as core
import bigframes.core.expression as ex
import bigframes.core.identifiers as identifiers
import bigframes.core.nodes as nodes
import bigframes.core.rewrite.identifiers as id_rewrite
Expand Down Expand Up @@ -130,3 +132,24 @@ def test_remap_variables_concat_self_stability(leaf):

assert new_node1 == new_node2
assert mapping1 == mapping2


def test_remap_variables_in_node_converts_dag_to_tree(leaf, leaf_too):
    # Build an InNode over two distinct leaves. Both underlying table fixtures
    # declare a column named "col_a", so the left and right references start
    # out with the same ColumnId even though they point at different children.
    node = nodes.InNode(
        left_child=leaf,
        right_child=leaf_too,
        left_col=ex.DerefOp(identifiers.ColumnId("col_a")),
        right_col=ex.DerefOp(identifiers.ColumnId("col_a")),
        indicator_col=identifiers.ColumnId("indicator"),
    )

    id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100))
    new_node, _ = id_rewrite.remap_variables(node, id_generator)
    new_node = typing.cast(nodes.InNode, new_node)

    # After remapping, each side must reference a generated "id_*" column and
    # the two sides must no longer share an id.
    left_col_id = new_node.left_col.id.name
    right_col_id = new_node.right_col.id.name
    assert left_col_id.startswith("id_")
    assert right_col_id.startswith("id_")
    assert left_col_id != right_col_id