From 943c0ad1d267914ec3cb680612acf2d6d4d35e7c Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Fri, 11 Jul 2025 19:41:39 +0000 Subject: [PATCH] test: Run some engine tests on sqlglot compiler --- bigframes/session/direct_gbq_execution.py | 15 ++++++++++++--- tests/system/small/engines/conftest.py | 6 +++++- tests/system/small/engines/test_join.py | 6 +++--- tests/system/small/engines/test_read_local.py | 2 ++ tests/system/small/engines/test_sorting.py | 8 ++++---- 5 files changed, 26 insertions(+), 11 deletions(-) diff --git a/bigframes/session/direct_gbq_execution.py b/bigframes/session/direct_gbq_execution.py index 1d46192ac3..ff91747a62 100644 --- a/bigframes/session/direct_gbq_execution.py +++ b/bigframes/session/direct_gbq_execution.py @@ -13,13 +13,14 @@ # limitations under the License. from __future__ import annotations -from typing import Optional, Tuple +from typing import Literal, Optional, Tuple from google.cloud import bigquery import google.cloud.bigquery.job as bq_job import google.cloud.bigquery.table as bq_table from bigframes.core import compile, nodes +from bigframes.core.compile import sqlglot from bigframes.session import executor, semi_executor import bigframes.session._io.bigquery as bq_io @@ -29,8 +30,15 @@ # or record metrics. Also avoids caching, and most pre-compile rewrites, to better serve as a # reference for validating more complex executors. class DirectGbqExecutor(semi_executor.SemiExecutor): - def __init__(self, bqclient: bigquery.Client): + def __init__( + self, bqclient: bigquery.Client, compiler: Literal["ibis", "sqlglot"] = "ibis" + ): self.bqclient = bqclient + self._compile_fn = ( + compile.compile_sql + if compiler == "ibis" + else sqlglot.SQLGlotCompiler()._compile_sql + ) def execute( self, @@ -42,9 +50,10 @@ def execute( # TODO(swast): plumb through the api_name of the user-facing api that # caused this query. 
- compiled = compile.compile_sql( + compiled = self._compile_fn( compile.CompileRequest(plan, sort_rows=ordered, peek_count=peek) ) + iterator, query_job = self._run_execute_query( sql=compiled.sql, ) diff --git a/tests/system/small/engines/conftest.py b/tests/system/small/engines/conftest.py index 249bd59260..4f0f875b34 100644 --- a/tests/system/small/engines/conftest.py +++ b/tests/system/small/engines/conftest.py @@ -44,7 +44,7 @@ def fake_session() -> Generator[bigframes.Session, None, None]: yield session -@pytest.fixture(scope="session", params=["pyarrow", "polars", "bq"]) +@pytest.fixture(scope="session", params=["pyarrow", "polars", "bq", "bq-sqlglot"]) def engine(request, bigquery_client: bigquery.Client) -> semi_executor.SemiExecutor: if request.param == "pyarrow": return local_scan_executor.LocalScanExecutor() @@ -52,6 +52,10 @@ def engine(request, bigquery_client: bigquery.Client) -> semi_executor.SemiExecu return polars_executor.PolarsExecutor() if request.param == "bq": return direct_gbq_execution.DirectGbqExecutor(bigquery_client) + if request.param == "bq-sqlglot": + return direct_gbq_execution.DirectGbqExecutor( + bigquery_client, compiler="sqlglot" + ) raise ValueError(f"Unrecognized param: {request.param}") diff --git a/tests/system/small/engines/test_join.py b/tests/system/small/engines/test_join.py index e1f9fe6070..402a41134b 100644 --- a/tests/system/small/engines/test_join.py +++ b/tests/system/small/engines/test_join.py @@ -27,7 +27,7 @@ REFERENCE_ENGINE = polars_executor.PolarsExecutor() -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) @pytest.mark.parametrize("join_type", ["left", "inner", "right", "outer"]) def test_engines_join_on_key( scalars_array_value: array_value.ArrayValue, @@ -41,7 +41,7 @@ def test_engines_join_on_key( assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", 
["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) @pytest.mark.parametrize("join_type", ["left", "inner", "right", "outer"]) def test_engines_join_on_coerced_key( scalars_array_value: array_value.ArrayValue, @@ -80,7 +80,7 @@ def test_engines_join_multi_key( assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_cross_join( scalars_array_value: array_value.ArrayValue, engine, diff --git a/tests/system/small/engines/test_read_local.py b/tests/system/small/engines/test_read_local.py index 82af7c984d..bf1a10beec 100644 --- a/tests/system/small/engines/test_read_local.py +++ b/tests/system/small/engines/test_read_local.py @@ -88,6 +88,8 @@ def test_engines_read_local_w_zero_row_source( assert_equivalence_execution(local_node, REFERENCE_ENGINE, engine) +# TODO: Fix the sqlglot impl for this nested-source case, then add "bq-sqlglot" to the engine list below. +@pytest.mark.parametrize("engine", ["polars", "bq", "pyarrow"], indirect=True) def test_engines_read_local_w_nested_source( fake_session: bigframes.Session, nested_data_source: local_data.ManagedArrowTable, diff --git a/tests/system/small/engines/test_sorting.py b/tests/system/small/engines/test_sorting.py index d1929afa44..ec1c0d95ee 100644 --- a/tests/system/small/engines/test_sorting.py +++ b/tests/system/small/engines/test_sorting.py @@ -25,7 +25,7 @@ REFERENCE_ENGINE = polars_executor.PolarsExecutor() -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_reverse( scalars_array_value: array_value.ArrayValue, engine, @@ -34,7 +34,7 @@ def test_engines_reverse( assert_equivalence_execution(node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars",
"bq", "bq-sqlglot"], indirect=True) def test_engines_double_reverse( scalars_array_value: array_value.ArrayValue, engine, @@ -43,7 +43,7 @@ def test_engines_double_reverse( assert_equivalence_execution(node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) @pytest.mark.parametrize( "sort_col", [ @@ -70,7 +70,7 @@ def test_engines_sort_over_column( assert_equivalence_execution(node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_sort_multi_column_refs( scalars_array_value: array_value.ArrayValue, engine,