From 660ecc89854532313a68ac25e9ca8fac28c780dd Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Tue, 8 Jul 2025 23:42:32 +0000 Subject: [PATCH 1/2] feat: Add concat pushdown for hybrid engine --- bigframes/core/compile/polars/compiler.py | 5 +++ bigframes/core/nodes.py | 2 +- bigframes/session/polars_executor.py | 1 + tests/system/small/engines/test_concat.py | 49 +++++++++++++++++++++++ 4 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 tests/system/small/engines/test_concat.py diff --git a/bigframes/core/compile/polars/compiler.py b/bigframes/core/compile/polars/compiler.py index 40037735d4..dfa2ebc818 100644 --- a/bigframes/core/compile/polars/compiler.py +++ b/bigframes/core/compile/polars/compiler.py @@ -547,6 +547,11 @@ def compile_concat(self, node: nodes.ConcatNode): child_frames = [ frame.rename( {col: id.sql for col, id in zip(frame.columns, node.output_ids)} + ).cast( + { + field.id.sql: _bigframes_dtype_to_polars_dtype(field.dtype) + for field in node.fields + } ) for frame in child_frames ] diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py index 205621fee2..690ed545d6 100644 --- a/bigframes/core/nodes.py +++ b/bigframes/core/nodes.py @@ -424,7 +424,7 @@ def remap_refs( @dataclasses.dataclass(frozen=True, eq=False) class ConcatNode(BigFrameNode): - # TODO: Explcitly map column ids from each child + # TODO: Explcitly map column ids from each child? children: Tuple[BigFrameNode, ...] output_ids: Tuple[identifiers.ColumnId, ...] diff --git a/bigframes/session/polars_executor.py b/bigframes/session/polars_executor.py index 28ab421905..8f669901a4 100644 --- a/bigframes/session/polars_executor.py +++ b/bigframes/session/polars_executor.py @@ -36,6 +36,7 @@ nodes.SliceNode, nodes.AggregateNode, nodes.FilterNode, + nodes.ConcatNode, ) _COMPATIBLE_SCALAR_OPS = ( diff --git a/tests/system/small/engines/test_concat.py b/tests/system/small/engines/test_concat.py new file mode 100644 index 0000000000..10b54471aa --- /dev/null +++ b/tests/system/small/engines/test_concat.py @@ -0,0 +1,49 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from bigframes.core import array_value, ordering +from bigframes.session import polars_executor +from bigframes.testing.engine_utils import assert_equivalence_execution + +pytest.importorskip("polars") + +# Polars used as reference as its fast and local. Generally though, prefer gbq engine where they disagree. +REFERENCE_ENGINE = polars_executor.PolarsExecutor() + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_concat_self( + scalars_array_value: array_value.ArrayValue, + engine, +): + catted = scalars_array_value.concat([scalars_array_value, scalars_array_value]) + assert_equivalence_execution(catted.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_concat_filtered_sorted( + scalars_array_value: array_value.ArrayValue, + engine, +): + input_1 = scalars_array_value.select_columns(["float64_col", "int64_col"]).order_by( + [ordering.ascending_over("int64_col")] + ) + input_2 = scalars_array_value.filter_by_id("bool_col").select_columns( + ["float64_col", "int64_too"] + ) + + catted = input_1.concat([input_2, input_1, input_2]) + assert_equivalence_execution(catted.node, REFERENCE_ENGINE, engine) From 69563beb86b14e3a0f374b4bcfa9d79cb1171644 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Wed, 9 Jul 2025 21:35:08 +0000 Subject: [PATCH 2/2] style --- tests/system/small/engines/test_concat.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/system/small/engines/test_concat.py b/tests/system/small/engines/test_concat.py index 10b54471aa..e10570fab2 100644 --- a/tests/system/small/engines/test_concat.py +++ b/tests/system/small/engines/test_concat.py @@ -29,8 +29,9 @@ def test_engines_concat_self( scalars_array_value: array_value.ArrayValue, engine, ): - catted = scalars_array_value.concat([scalars_array_value, scalars_array_value]) - assert_equivalence_execution(catted.node, REFERENCE_ENGINE, engine) + result = scalars_array_value.concat([scalars_array_value, scalars_array_value]) + + assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine) @pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) @@ -45,5 +46,6 @@ def test_engines_concat_filtered_sorted( ["float64_col", "int64_too"] ) - catted = input_1.concat([input_2, input_1, input_2]) - assert_equivalence_execution(catted.node, REFERENCE_ENGINE, engine) + result = input_1.concat([input_2, input_1, input_2]) + + assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine)