Skip to content

Commit eab92f9

Browse files
committed
fix test
1 parent 1c47302 commit eab92f9

File tree

3 files changed

+46
-2
lines changed
  • tests/unit/core/compile/sqlglot/expressions

3 files changed

+46
-2
lines changed

tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ WITH `bfcte_0` AS (
88
AI.GENERATE_BOOL(
99
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
1010
connection_id => 'test_connection_id',
11-
request_type => 'SHARED',
12-
model_params => JSON '{}'
11+
endpoint => 'gemini-2.5-flash',
12+
request_type => 'SHARED'
1313
) AS `bfcol_1`
1414
FROM `bfcte_0`
1515
)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
AI.GENERATE_BOOL(
9+
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
10+
connection_id => 'test_connection_id',
11+
request_type => 'SHARED',
12+
model_params => JSON '{}'
13+
) AS `bfcol_1`
14+
FROM `bfcte_0`
15+
)
16+
SELECT
17+
`bfcol_1` AS `result`
18+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import json
16+
import sys
1617

1718
import pytest
1819

@@ -26,6 +27,31 @@
2627
def test_ai_generate_bool(scalar_types_df: dataframe.DataFrame, snapshot):
2728
col_name = "string_col"
2829

30+
op = ops.AIGenerateBool(
31+
prompt_context=(None, " is the same as ", None),
32+
connection_id="test_connection_id",
33+
endpoint="gemini-2.5-flash",
34+
request_type="shared",
35+
model_params=None,
36+
)
37+
38+
sql = utils._apply_unary_ops(
39+
scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
40+
)
41+
42+
snapshot.assert_match(sql, "out.sql")
43+
44+
45+
def test_ai_generate_bool_with_model_param(
46+
scalar_types_df: dataframe.DataFrame, snapshot
47+
):
48+
if sys.version_info < (3, 10):
49+
pytest.skip(
50+
"Skip test because SQLGLot cannot compile model params to JSON at this env."
51+
)
52+
53+
col_name = "string_col"
54+
2955
op = ops.AIGenerateBool(
3056
prompt_context=(None, " is the same as ", None),
3157
connection_id="test_connection_id",

0 commit comments

Comments
 (0)