Skip to content

Commit 5814727

Browse files
committed
chore: implement ai.generate_bool in SQLGlot compiler
1 parent 9dc9695 commit 5814727

File tree

4 files changed

+125
-0
lines changed

4 files changed

+125
-0
lines changed

bigframes/core/compile/sqlglot/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import annotations
1515

1616
from bigframes.core.compile.sqlglot.compiler import SQLGlotCompiler
17+
import bigframes.core.compile.sqlglot.expressions.ai_ops # noqa: F401
1718
import bigframes.core.compile.sqlglot.expressions.array_ops # noqa: F401
1819
import bigframes.core.compile.sqlglot.expressions.binary_compiler # noqa: F401
1920
import bigframes.core.compile.sqlglot.expressions.blob_ops # noqa: F401
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import Sequence
18+
19+
import sqlglot.expressions as sge
20+
21+
from bigframes import operations as ops
22+
from bigframes.core.compile.sqlglot import scalar_compiler
23+
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
24+
25+
register_nary_op = scalar_compiler.scalar_op_compiler.register_nary_op
26+
27+
28+
@register_nary_op(ops.AIGenerateBool, pass_op=True)
29+
def _(*exprs: TypedExpr, op: ops.AIGenerateBool) -> sge.Expression:
30+
31+
prompt: list[str | sge.Expression] = []
32+
column_ref_idx = 0
33+
34+
for elem in op.prompt_context:
35+
if elem is None:
36+
prompt.append(exprs[column_ref_idx].expr)
37+
else:
38+
prompt.append(sge.Literal.string(elem))
39+
40+
args = [sge.Kwarg(this="prompt", expression=sge.Tuple(expressions=prompt))]
41+
42+
args.append(
43+
sge.Kwarg(this="connection_id", expression=sge.Literal.string(op.connection_id))
44+
)
45+
46+
if op.endpoint is not None:
47+
args.append(
48+
sge.Kwarg(this="endpoint", expression=sge.Literal.string(op.endpoint))
49+
)
50+
51+
args.append(
52+
sge.Kwarg(
53+
this="request_type", expression=sge.Literal.string(op.request_type.upper())
54+
)
55+
)
56+
57+
if op.model_params is not None:
58+
args.append(
59+
sge.Kwarg(
60+
this="model_params",
61+
expression=sge.JSON(this=sge.Literal.string(op.model_params)),
62+
)
63+
)
64+
65+
return sge.func("AI.GENERATE_BOOL", *args)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
AI.GENERATE_BOOL(
9+
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
10+
connection_id => 'test_connection_id',
11+
request_type => 'SHARED',
12+
model_params => JSON '{}'
13+
) AS `bfcol_1`
14+
FROM `bfcte_0`
15+
)
16+
SELECT
17+
`bfcol_1` AS `result`
18+
FROM `bfcte_1`
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
17+
import pytest
18+
19+
from bigframes import dataframe
20+
from bigframes import operations as ops
21+
from bigframes.testing import utils
22+
23+
pytest.importorskip("pytest_snapshot")
24+
25+
26+
def test_ai_generate_bool(scalar_types_df: dataframe.DataFrame, snapshot):
27+
col_name = "string_col"
28+
29+
op = ops.AIGenerateBool(
30+
prompt_context=(None, " is the same as ", None),
31+
connection_id="test_connection_id",
32+
endpoint=None,
33+
request_type="shared",
34+
model_params=json.dumps(dict()),
35+
)
36+
37+
sql = utils._apply_unary_ops(
38+
scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
39+
)
40+
41+
snapshot.assert_match(sql, "out.sql")

0 commit comments

Comments
 (0)