
Commit 244ff0d

Merge branch 'main' into sycai_ai_doc_fix
2 parents 090270b + e0aa9cc commit 244ff0d

File tree: 9 files changed, +159 −6 lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 38 additions & 0 deletions
@@ -151,6 +151,23 @@ def _(
     )


+@UNARY_OP_REGISTRATION.register(agg_ops.DiffOp)
+def _(
+    op: agg_ops.DiffOp,
+    column: typed_expr.TypedExpr,
+    window: typing.Optional[window_spec.WindowSpec] = None,
+) -> sge.Expression:
+    shift_op_impl = UNARY_OP_REGISTRATION[agg_ops.ShiftOp(0)]
+    shifted = shift_op_impl(agg_ops.ShiftOp(op.periods), column, window)
+    if column.dtype in (dtypes.BOOL_DTYPE, dtypes.INT_DTYPE, dtypes.FLOAT_DTYPE):
+        if column.dtype == dtypes.BOOL_DTYPE:
+            return sge.NEQ(this=column.expr, expression=shifted)
+        else:
+            return sge.Sub(this=column.expr, expression=shifted)
+    else:
+        raise TypeError(f"Cannot perform diff on type {column.dtype}")
+
+
 @UNARY_OP_REGISTRATION.register(agg_ops.MaxOp)
 def _(
     op: agg_ops.MaxOp,

@@ -240,6 +257,27 @@ def _(
     return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window)


+@UNARY_OP_REGISTRATION.register(agg_ops.ShiftOp)
+def _(
+    op: agg_ops.ShiftOp,
+    column: typed_expr.TypedExpr,
+    window: typing.Optional[window_spec.WindowSpec] = None,
+) -> sge.Expression:
+    if op.periods == 0:  # No-op
+        return column.expr
+    if op.periods > 0:
+        return apply_window_if_present(
+            sge.func("LAG", column.expr, sge.convert(op.periods)),
+            window,
+            include_framing_clauses=False,
+        )
+    return apply_window_if_present(
+        sge.func("LEAD", column.expr, sge.convert(-op.periods)),
+        window,
+        include_framing_clauses=False,
+    )
+
+
 @UNARY_OP_REGISTRATION.register(agg_ops.SumOp)
 def _(
     op: agg_ops.SumOp,
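
For intuition, these two registrations mirror pandas semantics: a positive `periods` looks backwards (the LAG branch), a negative one looks forwards (the LEAD branch), and `diff` subtracts the shifted column, falling back to an inequality for booleans. A minimal pandas-only sketch of those semantics (illustrative only, not part of this change):

import pandas as pd

s = pd.Series([10, 20, 40, 80])

# ShiftOp: positive periods look backwards (LAG), negative look forwards (LEAD).
lagged = s.shift(1)   # [NaN, 10.0, 20.0, 40.0]
led = s.shift(-1)     # [20.0, 40.0, 80.0, NaN]

# DiffOp on numeric columns compiles to the column minus its shifted self
# (the sge.Sub branch above).
assert s.diff(1).equals(s - s.shift(1))

# For boolean columns the compiler instead emits an inequality (sge.NEQ),
# i.e. "did the value change?", since BigQuery has no subtraction for BOOL.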

notebooks/generative_ai/ai_functions.ipynb

Lines changed: 6 additions & 6 deletions
@@ -56,7 +56,7 @@
 "id": "aee05821",
 "metadata": {},
 "source": [
-"This notebook provides a brief introduction to how to use BigFrames AI functions"
+"This notebook provides a brief introduction to AI functions in BigQuery DataFrames."
 ]
 },
 {
@@ -145,7 +145,7 @@
 "id": "b606c51f",
 "metadata": {},
 "source": [
-"You can also include additional model parameters into your function call, as long as they satisfy the structure of `generateContent` [request body format](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/generateContent#request-body). In the next example, you use `maxOutputTokens` to limite the length of the generated content."
+"You can also include additional model parameters into your function call, as long as they conform to the structure of `generateContent` [request body format](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/generateContent#request-body). In the next example, you use `maxOutputTokens` to limit the length of the generated content."
 ]
 },
 {
@@ -186,7 +186,7 @@
 "source": [
 "The answers are cut short as expected.\n",
 "\n",
-"In addition to `ai.generate`, you can use `ai.generate_bool`, `ai.generate_int`, and `ai.generate_double` for other type of outputs."
+"In addition to `ai.generate`, you can use `ai.generate_bool`, `ai.generate_int`, and `ai.generate_double` for other output types."
 ]
 },
 {
@@ -196,7 +196,7 @@
 "source": [
 "## ai.if_\n",
 "\n",
-"`ai.if_` generates a series of booleans, unlike `ai.generate_bool` where you get a series of structs. It's a handy tool for filtering your data. not only because it directly returns a boolean, but also because it provides more optimization during data processing. Here is an example of using `ai.if_`:"
+"`ai.if_` generates a series of booleans. It's a handy tool for joining and filtering your data, not only because it directly returns boolean values, but also because it provides more optimization during data processing. Here is an example of using `ai.if_`:"
 ]
 },
 {
@@ -284,7 +284,7 @@
 "id": "63b5a59f",
 "metadata": {},
 "source": [
-"`ai.score` ranks your input based on the prompt. You can then sort your data based on their ranks. For example:"
+"`ai.score` ranks your input based on the prompt and assigns a double value (i.e. a score) to each item. You can then sort your data based on their scores. For example:"
 ]
 },
 {
@@ -460,7 +460,7 @@
 "id": "9e4037bc",
 "metadata": {},
 "source": [
-"Note that this function can only return the values that are present in your provided categories. If your categories do not cover all cases, your will get wrong answers:"
+"Note that this function can only return the values that are provided in the `categories` argument. If your categories do not cover all cases, you may get wrong answers:"
 ]
 },
 {
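
Since the notebook text above describes `ai.if_` as returning a boolean series suitable for joining and filtering, here is a rough sketch of that pattern, assuming a configured BigQuery DataFrames session. The import path and the tuple-style prompt are assumptions inferred from the notebook's wording, not confirmed by this diff; the notebook cells themselves are the authoritative example.

import bigframes.pandas as bpd
from bigframes.bigquery import ai  # assumed import path; see the notebook for the actual one

df = bpd.DataFrame({"city": ["Seattle", "Paris", "Tokyo"]})

# ai.if_ is described above as producing booleans directly, so the result can
# be used as a filter mask. The prompt format (mixing literal strings and a
# Series) is an assumption for illustration only.
mask = ai.if_(("The city ", df["city"], " is located in the United States."))
us_cities = df[mask]
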
diff_bool.sql (new snapshot)

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `bool_col` AS `bfcol_0`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    `bfcol_0` <> LAG(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_1` AS `diff_bool`
+FROM `bfcte_1`

diff_int.sql (new snapshot)

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `int64_col` AS `bfcol_0`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    `bfcol_0` - LAG(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_1` AS `diff_int`
+FROM `bfcte_1`

lag.sql (new snapshot)

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `int64_col` AS `bfcol_0`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    LAG(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_1` AS `lag`
+FROM `bfcte_1`

lead.sql (new snapshot)

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `int64_col` AS `bfcol_0`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    LEAD(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_1` AS `lead`
+FROM `bfcte_1`

noop.sql (new snapshot)

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `int64_col` AS `bfcol_0`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    `bfcol_0` AS `bfcol_1`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_1` AS `noop`
+FROM `bfcte_1`
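
These snapshots capture the SQL the sqlglot compiler is expected to emit for the new ops. A rough way to inspect similar output yourself, assuming access to the `bigframes-dev.sqlglot_test.scalar_types` table (exact CTE and column names will differ, and the emitted SQL depends on the compiler backend in use):

import bigframes.pandas as bpd

df = bpd.read_gbq("bigframes-dev.sqlglot_test.scalar_types")[["int64_col"]]

# shift(1) exercises the ShiftOp/LAG path, diff(1) the DiffOp path.
shifted = df["int64_col"].shift(1).to_frame()
diffed = df["int64_col"].diff(1).to_frame()

# DataFrame.sql exposes the SQL BigQuery DataFrames would send; the snapshot
# files above are the unit-test equivalent of this output.
print(shifted.sql)
print(diffed.sql)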

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 49 additions & 0 deletions
@@ -127,6 +127,28 @@ def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot):
     snapshot.assert_match(sql, "out.sql")


+def test_diff(scalar_types_df: bpd.DataFrame, snapshot):
+    # Test integer
+    int_col = "int64_col"
+    bf_df_int = scalar_types_df[[int_col]]
+    window = window_spec.WindowSpec(ordering=(ordering.ascending_over(int_col),))
+    int_op = agg_exprs.UnaryAggregation(
+        agg_ops.DiffOp(periods=1), expression.deref(int_col)
+    )
+    int_sql = _apply_unary_window_op(bf_df_int, int_op, window, "diff_int")
+    snapshot.assert_match(int_sql, "diff_int.sql")
+
+    # Test boolean
+    bool_col = "bool_col"
+    bf_df_bool = scalar_types_df[[bool_col]]
+    window = window_spec.WindowSpec(ordering=(ordering.ascending_over(bool_col),))
+    bool_op = agg_exprs.UnaryAggregation(
+        agg_ops.DiffOp(periods=1), expression.deref(bool_col)
+    )
+    bool_sql = _apply_unary_window_op(bf_df_bool, bool_op, window, "diff_bool")
+    snapshot.assert_match(bool_sql, "diff_bool.sql")
+
+
 def test_first(scalar_types_df: bpd.DataFrame, snapshot):
     if sys.version_info < (3, 12):
         pytest.skip(

@@ -271,6 +293,33 @@ def test_rank(scalar_types_df: bpd.DataFrame, snapshot):
     snapshot.assert_match(sql, "out.sql")


+def test_shift(scalar_types_df: bpd.DataFrame, snapshot):
+    col_name = "int64_col"
+    bf_df = scalar_types_df[[col_name]]
+    window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
+
+    # Test lag
+    lag_op = agg_exprs.UnaryAggregation(
+        agg_ops.ShiftOp(periods=1), expression.deref(col_name)
+    )
+    lag_sql = _apply_unary_window_op(bf_df, lag_op, window, "lag")
+    snapshot.assert_match(lag_sql, "lag.sql")
+
+    # Test lead
+    lead_op = agg_exprs.UnaryAggregation(
+        agg_ops.ShiftOp(periods=-1), expression.deref(col_name)
+    )
+    lead_sql = _apply_unary_window_op(bf_df, lead_op, window, "lead")
+    snapshot.assert_match(lead_sql, "lead.sql")
+
+    # Test no-op
+    noop_op = agg_exprs.UnaryAggregation(
+        agg_ops.ShiftOp(periods=0), expression.deref(col_name)
+    )
+    noop_sql = _apply_unary_window_op(bf_df, noop_op, window, "noop")
+    snapshot.assert_match(noop_sql, "noop.sql")
+
+
 def test_sum(scalar_types_df: bpd.DataFrame, snapshot):
     bf_df = scalar_types_df[["int64_col", "bool_col"]]
     agg_ops_map = {
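
These tests compare the compiled SQL against the new snapshot files above. Assuming the repository's standard pytest setup, an invocation along the lines of `pytest tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py -k "diff or shift"` should run just the new cases.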

tests/unit/session/test_read_gbq_query.py

Lines changed: 1 addition & 0 deletions
@@ -35,3 +35,4 @@ def test_read_gbq_query_sets_destination_table():

     assert query == "SELECT 'my-test-query';"
     assert config.destination is not None
+    session.close()
