From 58147277a754d2cdbc80a13d0a8c31defc9e03fc Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Thu, 18 Sep 2025 22:45:16 +0000 Subject: [PATCH 1/4] chore: implement ai.generate_bool in SQLGlot compiler --- bigframes/core/compile/sqlglot/__init__.py | 1 + .../compile/sqlglot/expressions/ai_ops.py | 65 +++++++++++++++++++ .../test_ai_ops/test_ai_generate_bool/out.sql | 18 +++++ .../sqlglot/expressions/test_ai_ops.py | 41 ++++++++++++ 4 files changed, 125 insertions(+) create mode 100644 bigframes/core/compile/sqlglot/expressions/ai_ops.py create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py diff --git a/bigframes/core/compile/sqlglot/__init__.py b/bigframes/core/compile/sqlglot/__init__.py index 5fe8099043..50b9663aaa 100644 --- a/bigframes/core/compile/sqlglot/__init__.py +++ b/bigframes/core/compile/sqlglot/__init__.py @@ -14,6 +14,7 @@ from __future__ import annotations from bigframes.core.compile.sqlglot.compiler import SQLGlotCompiler +import bigframes.core.compile.sqlglot.expressions.ai_ops # noqa: F401 import bigframes.core.compile.sqlglot.expressions.array_ops # noqa: F401 import bigframes.core.compile.sqlglot.expressions.binary_compiler # noqa: F401 import bigframes.core.compile.sqlglot.expressions.blob_ops # noqa: F401 diff --git a/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/bigframes/core/compile/sqlglot/expressions/ai_ops.py new file mode 100644 index 0000000000..8621bed1a8 --- /dev/null +++ b/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -0,0 +1,65 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Sequence
+
+import sqlglot.expressions as sge
+
+from bigframes import operations as ops
+from bigframes.core.compile.sqlglot import scalar_compiler
+from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
+
+register_nary_op = scalar_compiler.scalar_op_compiler.register_nary_op
+
+
+@register_nary_op(ops.AIGenerateBool, pass_op=True)
+def _(*exprs: TypedExpr, op: ops.AIGenerateBool) -> sge.Expression:
+    """Compile an AIGenerateBool op into an AI.GENERATE_BOOL(...) SQL call."""
+    prompt: list[str | sge.Expression] = []
+    pending = [e.expr for e in exprs]
+    # Each None in prompt_context is a placeholder for the next column input.
+    for elem in op.prompt_context:
+        if elem is None:
+            prompt.append(pending.pop(0))
+        else:
+            prompt.append(sge.Literal.string(elem))
+
+    args = [sge.Kwarg(this="prompt", expression=sge.Tuple(expressions=prompt))]
+
+    args.append(
+        sge.Kwarg(this="connection_id", expression=sge.Literal.string(op.connection_id))
+    )
+
+    if op.endpoint is not None:
+        args.append(
+            sge.Kwarg(this="endpoint", expression=sge.Literal.string(op.endpoint))
+        )
+
+    args.append(
+        sge.Kwarg(
+            this="request_type", expression=sge.Literal.string(op.request_type.upper())
+        )
+    )
+
+    if op.model_params is not None:
+        args.append(
+            sge.Kwarg(
+                this="model_params",
+                expression=sge.JSON(this=sge.Literal.string(op.model_params)),
+            )
+        )
+
+    return sge.func("AI.GENERATE_BOOL", *args)
diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql
new file mode 100644
index 
0000000000..fca2b965bf --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql @@ -0,0 +1,18 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + AI.GENERATE_BOOL( + prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`), + connection_id => 'test_connection_id', + request_type => 'SHARED', + model_params => JSON '{}' + ) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `result` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py new file mode 100644 index 0000000000..009c9315d5 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py @@ -0,0 +1,41 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json + +import pytest + +from bigframes import dataframe +from bigframes import operations as ops +from bigframes.testing import utils + +pytest.importorskip("pytest_snapshot") + + +def test_ai_generate_bool(scalar_types_df: dataframe.DataFrame, snapshot): + col_name = "string_col" + + op = ops.AIGenerateBool( + prompt_context=(None, " is the same as ", None), + connection_id="test_connection_id", + endpoint=None, + request_type="shared", + model_params=json.dumps(dict()), + ) + + sql = utils._apply_unary_ops( + scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] + ) + + snapshot.assert_match(sql, "out.sql") From 1c473021fa2403706f3425ae85e679fd1c17cba9 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Thu, 18 Sep 2025 22:48:44 +0000 Subject: [PATCH 2/4] fix lint --- bigframes/core/compile/sqlglot/expressions/ai_ops.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/bigframes/core/compile/sqlglot/expressions/ai_ops.py index 8621bed1a8..e666f6b763 100644 --- a/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -14,8 +14,6 @@ from __future__ import annotations -from typing import Sequence - import sqlglot.expressions as sge from bigframes import operations as ops From eab92f90c09a64c3fce15bed7be0d52a47fc6a04 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Fri, 19 Sep 2025 17:18:50 +0000 Subject: [PATCH 3/4] fix test --- .../test_ai_ops/test_ai_generate_bool/out.sql | 4 +-- .../out.sql | 18 +++++++++++++ .../sqlglot/expressions/test_ai_ops.py | 26 +++++++++++++++++++ 3 files changed, 46 insertions(+), 2 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_model_param/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql 
b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql index fca2b965bf..584ccd9ce1 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql @@ -8,8 +8,8 @@ WITH `bfcte_0` AS ( AI.GENERATE_BOOL( prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`), connection_id => 'test_connection_id', - request_type => 'SHARED', - model_params => JSON '{}' + endpoint => 'gemini-2.5-flash', + request_type => 'SHARED' ) AS `bfcol_1` FROM `bfcte_0` ) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_model_param/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_model_param/out.sql new file mode 100644 index 0000000000..fca2b965bf --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_model_param/out.sql @@ -0,0 +1,18 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + AI.GENERATE_BOOL( + prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`), + connection_id => 'test_connection_id', + request_type => 'SHARED', + model_params => JSON '{}' + ) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `result` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py index 009c9315d5..15b9ae516b 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py @@ -13,6 +13,7 @@ # limitations under the License. 
import json +import sys import pytest @@ -26,6 +27,31 @@ def test_ai_generate_bool(scalar_types_df: dataframe.DataFrame, snapshot): col_name = "string_col" + op = ops.AIGenerateBool( + prompt_context=(None, " is the same as ", None), + connection_id="test_connection_id", + endpoint="gemini-2.5-flash", + request_type="shared", + model_params=None, + ) + + sql = utils._apply_unary_ops( + scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] + ) + + snapshot.assert_match(sql, "out.sql") + + +def test_ai_generate_bool_with_model_param( + scalar_types_df: dataframe.DataFrame, snapshot +): + if sys.version_info < (3, 10): + pytest.skip( + "Skip test because SQLGLot cannot compile model params to JSON at this env." + ) + + col_name = "string_col" + op = ops.AIGenerateBool( prompt_context=(None, " is the same as ", None), connection_id="test_connection_id", From fb6c364b8ac6cdb210ef9762de6eb3a7292f23e7 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Mon, 22 Sep 2025 20:25:35 +0000 Subject: [PATCH 4/4] add comment on sge.JSON --- bigframes/core/compile/sqlglot/expressions/ai_ops.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/bigframes/core/compile/sqlglot/expressions/ai_ops.py index e666f6b763..8395461575 100644 --- a/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -56,6 +56,8 @@ def _(*exprs: TypedExpr, op: ops.AIGenerateBool) -> sge.Expression: args.append( sge.Kwarg( this="model_params", + # sge.JSON requires a newer SQLGlot version than 23.6.3. + # PARSE_JSON won't work as the function requires a JSON literal. expression=sge.JSON(this=sge.Literal.string(op.model_params)), ) )