Skip to content

Commit 89afa5b

Browse files
committed
refactor: add agg_ops.CutOp to the sqlglot compiler
1 parent 08c0c0c commit 89afa5b

File tree

3 files changed

+266
-0
lines changed

3 files changed

+266
-0
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,164 @@ def _(
111111
return apply_window_if_present(sge.func("COUNT", column.expr), window)
112112

113113

114+
@UNARY_OP_REGISTRATION.register(agg_ops.CutOp)
def _(
    op: agg_ops.CutOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compiles ``agg_ops.CutOp`` into a SQL CASE expression.

    An integer ``op.bins`` means "this many equal-width bins"; any other
    value is interpreted as an explicit sequence of interval endpoints.
    """
    build = _cut_ops_w_int_bins if isinstance(op.bins, int) else _cut_ops_w_intervals
    case_expr = build(op, column, op.bins, window)
    return apply_window_if_present(case_expr, window)
125+
126+
127+
def _cut_ops_w_int_bins(
    op: agg_ops.CutOp,
    column: typed_expr.TypedExpr,
    bins: int,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Case:
    """Builds a CASE expression bucketing ``column`` into ``bins`` equal-width bins.

    Bin edges are derived from the (windowed) MIN/MAX of the column. Each
    matching row maps to the zero-based bin index when ``op.labels is False``,
    to a caller-supplied label when ``op.labels`` is an iterable, and otherwise
    to a STRUCT describing the bin's interval bounds.

    Args:
        op: The CutOp carrying ``labels`` and the ``right`` closed-side flag.
        column: The column being binned.
        bins: Number of equal-width bins.
        window: Optional window over which MIN/MAX are computed.

    Returns:
        A ``sge.Case`` with one WHEN branch per bin.
    """
    case_expr = sge.Case()
    col_min = apply_window_if_present(
        sge.func("MIN", column.expr), window or window_spec.WindowSpec()
    )
    col_max = apply_window_if_present(
        sge.func("MAX", column.expr), window or window_spec.WindowSpec()
    )
    # Small slack (0.1% of the range) so the outermost bin edge still captures
    # the exact min/max value, mirroring pandas.cut's edge adjustment.
    adj: sge.Expression = sge.Sub(this=col_max, expression=col_min) * sge.convert(0.001)
    bin_width: sge.Expression = sge.func(
        "IEEE_DIVIDE",
        sge.Sub(this=col_max, expression=col_min),
        sge.convert(bins),
    )

    # Materialize labels once: calling list(op.labels) inside the loop is
    # O(bins**2) and would silently yield wrong results for one-shot iterables.
    label_list = (
        list(op.labels)
        if op.labels is not False and isinstance(op.labels, typing.Iterable)
        else None
    )

    for this_bin in range(bins):
        value: sge.Expression
        if op.labels is False:
            # No labels requested: emit the zero-based bin index.
            value = ir._literal(this_bin, dtypes.INT_DTYPE)
        elif label_list is not None:
            value = ir._literal(label_list[this_bin], dtypes.STRING_DTYPE)
        else:
            # Default labels: a STRUCT of the interval bounds. Only the
            # open-sided outermost edge gets the ``adj`` widening.
            left_adj: sge.Expression = (
                adj if this_bin == 0 and op.right else sge.convert(0)
            )
            right_adj: sge.Expression = (
                adj if this_bin == bins - 1 and not op.right else sge.convert(0)
            )

            left: sge.Expression = (
                col_min + sge.convert(this_bin) * bin_width - left_adj
            )
            right: sge.Expression = (
                col_min + sge.convert(this_bin + 1) * bin_width + right_adj
            )
            if op.right:
                value = sge.Struct(
                    expressions=[
                        sge.PropertyEQ(
                            this=sge.Identifier(this="left_exclusive", quoted=True),
                            expression=left,
                        ),
                        sge.PropertyEQ(
                            this=sge.Identifier(this="right_inclusive", quoted=True),
                            expression=right,
                        ),
                    ]
                )
            else:
                value = sge.Struct(
                    expressions=[
                        sge.PropertyEQ(
                            this=sge.Identifier(this="left_inclusive", quoted=True),
                            expression=left,
                        ),
                        sge.PropertyEQ(
                            this=sge.Identifier(this="right_exclusive", quoted=True),
                            expression=right,
                        ),
                    ]
                )

        condition: sge.Expression
        if this_bin == bins - 1:
            # Last bin is a catch-all for any non-NULL value not matched by an
            # earlier WHEN, so float round-off cannot drop the max value.
            condition = sge.Is(this=column.expr, expression=sge.Not(this=sge.Null()))
        else:
            upper_edge = col_min + sge.convert(this_bin + 1) * bin_width
            if op.right:
                condition = sge.LTE(this=column.expr, expression=upper_edge)
            else:
                condition = sge.LT(this=column.expr, expression=upper_edge)
        case_expr = case_expr.when(condition, value)
    return case_expr
210+
211+
212+
def _cut_ops_w_intervals(
    op: agg_ops.CutOp,
    column: typed_expr.TypedExpr,
    bins: typing.Iterable[typing.Tuple[typing.Any, typing.Any]],
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Case:
    """Builds a CASE expression bucketing ``column`` into explicit intervals.

    Each ``(left, right)`` pair in ``bins`` becomes one WHEN branch. With
    ``op.right`` the interval is ``(left, right]``, otherwise ``[left, right)``.
    The matched value is the bin index when ``op.labels is False``, a
    caller-supplied label when ``op.labels`` is an iterable, and otherwise a
    STRUCT of the interval bounds.

    Args:
        op: The CutOp carrying ``labels`` and the ``right`` closed-side flag.
        column: The column being binned.
        bins: Sequence of ``(left, right)`` interval endpoints.
        window: Unused here (conditions compare against literals); kept for
            signature symmetry with ``_cut_ops_w_int_bins``.

    Returns:
        A ``sge.Case`` with one WHEN branch per interval.
    """
    case_expr = sge.Case()

    # Materialize labels once: calling list(op.labels) inside the loop is
    # quadratic and would silently yield wrong results for one-shot iterables.
    label_list = (
        list(op.labels)
        if op.labels is not False and isinstance(op.labels, typing.Iterable)
        else None
    )

    for this_bin, interval in enumerate(bins):
        left: sge.Expression = ir._literal(
            interval[0], dtypes.infer_literal_type(interval[0])
        )
        right: sge.Expression = ir._literal(
            interval[1], dtypes.infer_literal_type(interval[1])
        )

        condition: sge.Expression
        if op.right:
            condition = sge.And(
                this=sge.GT(this=column.expr, expression=left),
                expression=sge.LTE(this=column.expr, expression=right),
            )
        else:
            condition = sge.And(
                this=sge.GTE(this=column.expr, expression=left),
                expression=sge.LT(this=column.expr, expression=right),
            )

        value: sge.Expression
        if op.labels is False:
            # No labels requested: emit the zero-based bin index.
            value = ir._literal(this_bin, dtypes.INT_DTYPE)
        elif label_list is not None:
            value = ir._literal(label_list[this_bin], dtypes.STRING_DTYPE)
        else:
            # Default labels: a STRUCT of the interval bounds.
            if op.right:
                value = sge.Struct(
                    expressions=[
                        sge.PropertyEQ(
                            this=sge.Identifier(this="left_exclusive"), expression=left
                        ),
                        sge.PropertyEQ(
                            this=sge.Identifier(this="right_inclusive"),
                            expression=right,
                        ),
                    ]
                )
            else:
                value = sge.Struct(
                    expressions=[
                        sge.PropertyEQ(
                            this=sge.Identifier(this="left_inclusive"), expression=left
                        ),
                        sge.PropertyEQ(
                            this=sge.Identifier(this="right_exclusive"),
                            expression=right,
                        ),
                    ]
                )
        case_expr = case_expr.when(condition, value)
    return case_expr
270+
271+
114272
@UNARY_OP_REGISTRATION.register(agg_ops.DateSeriesDiffOp)
115273
def _(
116274
op: agg_ops.DateSeriesDiffOp,
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
CASE
8+
WHEN `int64_col` <= MIN(`int64_col`) OVER () + (
9+
1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
10+
)
11+
THEN STRUCT(
12+
(
13+
MIN(`int64_col`) OVER () + (
14+
0 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
15+
)
16+
) - (
17+
(
18+
MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER ()
19+
) * 0.001
20+
) AS `left_exclusive`,
21+
MIN(`int64_col`) OVER () + (
22+
1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
23+
) + 0 AS `right_inclusive`
24+
)
25+
WHEN `int64_col` <= MIN(`int64_col`) OVER () + (
26+
2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
27+
)
28+
THEN STRUCT(
29+
(
30+
MIN(`int64_col`) OVER () + (
31+
1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
32+
)
33+
) - 0 AS `left_exclusive`,
34+
MIN(`int64_col`) OVER () + (
35+
2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
36+
) + 0 AS `right_inclusive`
37+
)
38+
WHEN `int64_col` IS NOT NULL
39+
THEN STRUCT(
40+
(
41+
MIN(`int64_col`) OVER () + (
42+
2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
43+
)
44+
) - 0 AS `left_exclusive`,
45+
MIN(`int64_col`) OVER () + (
46+
3 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
47+
) + 0 AS `right_inclusive`
48+
)
49+
END AS `bfcol_1`,
50+
CASE
51+
WHEN `int64_col` > 0 AND `int64_col` <= 1
52+
THEN STRUCT(0 AS left_exclusive, 1 AS right_inclusive)
53+
WHEN `int64_col` > 1 AND `int64_col` <= 2
54+
THEN STRUCT(1 AS left_exclusive, 2 AS right_inclusive)
55+
END AS `bfcol_2`,
56+
CASE
57+
WHEN `int64_col` < MIN(`int64_col`) OVER () + (
58+
1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
59+
)
60+
THEN 'a'
61+
WHEN `int64_col` < MIN(`int64_col`) OVER () + (
62+
2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
63+
)
64+
THEN 'b'
65+
WHEN `int64_col` IS NOT NULL
66+
THEN 'c'
67+
END AS `bfcol_3`,
68+
CASE
69+
WHEN `int64_col` > 0 AND `int64_col` <= 1
70+
THEN 0
71+
WHEN `int64_col` > 1 AND `int64_col` <= 2
72+
THEN 1
73+
END AS `bfcol_4`
74+
FROM `bfcte_0`
75+
)
76+
SELECT
77+
`bfcol_1` AS `int_bins`,
78+
`bfcol_2` AS `interval_bins`,
79+
`bfcol_3` AS `int_bins_labels`,
80+
`bfcol_4` AS `interval_bins_labels`
81+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,33 @@ def test_count(scalar_types_df: bpd.DataFrame, snapshot):
174174
snapshot.assert_match(sql_window_partition, "window_partition_out.sql")
175175

176176

177+
def test_cut(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot-tests CutOp compilation for int bins and explicit intervals."""
    col_name = "int64_col"
    bf_df = scalar_types_df[[col_name]]

    def _agg(op: agg_ops.CutOp) -> agg_exprs.UnaryAggregation:
        # Keeps each case below down to just the CutOp parameters under test.
        return agg_exprs.UnaryAggregation(op, expression.deref(col_name))

    cases = {
        "int_bins": _agg(agg_ops.CutOp(bins=3, right=True, labels=None)),
        "interval_bins": _agg(
            agg_ops.CutOp(bins=((0, 1), (1, 2)), right=True, labels=None)
        ),
        "int_bins_labels": _agg(
            agg_ops.CutOp(bins=3, labels=("a", "b", "c"), right=False)
        ),
        "interval_bins_labels": _agg(
            agg_ops.CutOp(bins=((0, 1), (1, 2)), labels=False, right=True)
        ),
    }
    sql = _apply_unary_agg_ops(bf_df, list(cases.values()), list(cases.keys()))

    snapshot.assert_match(sql, "out.sql")
202+
203+
177204
def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot):
178205
col_name = "int64_col"
179206
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)