Skip to content

Commit 8509966

Browse files
committed
refactor: add agg_ops.CutOp to the sqlglot compiler
1 parent f73fb98 commit 8509966

File tree

3 files changed

+266
-0
lines changed

3 files changed

+266
-0
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,164 @@ def _(
9898
return apply_window_if_present(sge.func("COUNT", column.expr), window)
9999

100100

101+
@UNARY_OP_REGISTRATION.register(agg_ops.CutOp)
def _(
    op: agg_ops.CutOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile ``agg_ops.CutOp`` into a SQL CASE expression.

    An integer ``op.bins`` yields equal-width bins derived from the column's
    MIN/MAX over the window; any other value is interpreted as explicit
    (left, right) interval pairs.
    """
    if isinstance(op.bins, int):
        bucketed = _cut_ops_w_int_bins(op, column, op.bins, window)
    else:
        # Non-integer bins are treated as explicit interval pairs.
        bucketed = _cut_ops_w_intervals(op, column, op.bins, window)
    return apply_window_if_present(bucketed, window)
112+
113+
114+
def _cut_ops_w_int_bins(
    op: agg_ops.CutOp,
    column: typed_expr.TypedExpr,
    bins: int,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Case:
    """Build a CASE expression assigning each value to one of ``bins`` equal-width bins.

    Bin edges are computed from the windowed MIN/MAX of the column. Each CASE
    branch yields either the bin index (``op.labels is False``), a caller-supplied
    label (``op.labels`` iterable), or a STRUCT of the bin's interval endpoints.

    Args:
        op: The cut operation; ``op.right`` selects right-inclusive bins,
            ``op.labels`` selects the output representation.
        column: The column being binned.
        bins: Number of equal-width bins (positive).
        window: Optional window over which MIN/MAX are computed; an empty
            (unbounded) window is used when absent.

    Returns:
        An ``sge.Case`` expression with one WHEN branch per bin.
    """
    case_expr = sge.Case()
    col_min = apply_window_if_present(
        sge.func("MIN", column.expr), window or window_spec.WindowSpec()
    )
    col_max = apply_window_if_present(
        sge.func("MAX", column.expr), window or window_spec.WindowSpec()
    )
    # Small fractional widening so boundary values land inside the outermost bin
    # (mirrors pandas.cut's edge adjustment).
    adj: sge.Expression = sge.Sub(this=col_max, expression=col_min) * sge.convert(0.001)
    bin_width: sge.Expression = sge.func(
        "IEEE_DIVIDE",
        sge.Sub(this=col_max, expression=col_min),
        sge.convert(bins),
    )

    # Materialize labels exactly once: re-running list(op.labels) inside the
    # loop would exhaust a one-shot iterable after the first bin and costs
    # O(len(labels)) per bin otherwise.
    labels_list: typing.Optional[list] = (
        list(op.labels) if isinstance(op.labels, typing.Iterable) else None
    )

    for this_bin in range(bins):
        value: sge.Expression
        if op.labels is False:
            value = ir._literal(this_bin, dtypes.INT_DTYPE)
        elif labels_list is not None:
            value = ir._literal(labels_list[this_bin], dtypes.STRING_DTYPE)
        else:
            # labels=None: emit the interval itself as a STRUCT, widening the
            # closed-edge side of the outermost bins by `adj`.
            left_adj: sge.Expression = (
                adj if this_bin == 0 and op.right else sge.convert(0)
            )
            right_adj: sge.Expression = (
                adj if this_bin == bins - 1 and not op.right else sge.convert(0)
            )

            left: sge.Expression = (
                col_min + sge.convert(this_bin) * bin_width - left_adj
            )
            right: sge.Expression = (
                col_min + sge.convert(this_bin + 1) * bin_width + right_adj
            )
            if op.right:
                value = sge.Struct(
                    expressions=[
                        sge.PropertyEQ(
                            this=sge.Identifier(this="left_exclusive", quoted=True),
                            expression=left,
                        ),
                        sge.PropertyEQ(
                            this=sge.Identifier(this="right_inclusive", quoted=True),
                            expression=right,
                        ),
                    ]
                )
            else:
                value = sge.Struct(
                    expressions=[
                        sge.PropertyEQ(
                            this=sge.Identifier(this="left_inclusive", quoted=True),
                            expression=left,
                        ),
                        sge.PropertyEQ(
                            this=sge.Identifier(this="right_exclusive", quoted=True),
                            expression=right,
                        ),
                    ]
                )

        condition: sge.Expression
        if this_bin == bins - 1:
            # Last bin is a catch-all for any non-NULL value not claimed above.
            condition = sge.Is(this=column.expr, expression=sge.Not(this=sge.Null()))
        else:
            if op.right:
                condition = sge.LTE(
                    this=column.expr,
                    expression=(col_min + sge.convert(this_bin + 1) * bin_width),
                )
            else:
                condition = sge.LT(
                    this=column.expr,
                    expression=(col_min + sge.convert(this_bin + 1) * bin_width),
                )
        case_expr = case_expr.when(condition, value)
    return case_expr
197+
198+
199+
def _cut_ops_w_intervals(
    op: agg_ops.CutOp,
    column: typed_expr.TypedExpr,
    bins: typing.Iterable[typing.Tuple[typing.Any, typing.Any]],
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Case:
    """Build a CASE expression assigning each value to a caller-supplied interval.

    Args:
        op: The cut operation; ``op.right`` selects right-inclusive intervals,
            ``op.labels`` selects the output representation (bin index, label,
            or interval STRUCT).
        column: The column being binned.
        bins: Explicit (left, right) interval endpoint pairs.
        window: Unused here (endpoints are literals); kept for signature
            parity with ``_cut_ops_w_int_bins``.

    Returns:
        An ``sge.Case`` expression with one WHEN branch per interval.
    """
    case_expr = sge.Case()

    # Materialize labels exactly once: re-running list(op.labels) inside the
    # loop would exhaust a one-shot iterable after the first interval and
    # costs O(len(labels)) per interval otherwise.
    labels_list: typing.Optional[list] = (
        list(op.labels) if isinstance(op.labels, typing.Iterable) else None
    )

    for this_bin, interval in enumerate(bins):
        left: sge.Expression = ir._literal(
            interval[0], dtypes.infer_literal_type(interval[0])
        )
        right: sge.Expression = ir._literal(
            interval[1], dtypes.infer_literal_type(interval[1])
        )
        condition: sge.Expression
        if op.right:
            # (left, right]
            condition = sge.And(
                this=sge.GT(this=column.expr, expression=left),
                expression=sge.LTE(this=column.expr, expression=right),
            )
        else:
            # [left, right)
            condition = sge.And(
                this=sge.GTE(this=column.expr, expression=left),
                expression=sge.LT(this=column.expr, expression=right),
            )

        value: sge.Expression
        if op.labels is False:
            value = ir._literal(this_bin, dtypes.INT_DTYPE)
        elif labels_list is not None:
            value = ir._literal(labels_list[this_bin], dtypes.STRING_DTYPE)
        else:
            # labels=None: emit the interval itself as a STRUCT.
            if op.right:
                value = sge.Struct(
                    expressions=[
                        sge.PropertyEQ(
                            this=sge.Identifier(this="left_exclusive"), expression=left
                        ),
                        sge.PropertyEQ(
                            this=sge.Identifier(this="right_inclusive"),
                            expression=right,
                        ),
                    ]
                )
            else:
                value = sge.Struct(
                    expressions=[
                        sge.PropertyEQ(
                            this=sge.Identifier(this="left_inclusive"), expression=left
                        ),
                        sge.PropertyEQ(
                            this=sge.Identifier(this="right_exclusive"),
                            expression=right,
                        ),
                    ]
                )
        case_expr = case_expr.when(condition, value)
    return case_expr
257+
258+
101259
@UNARY_OP_REGISTRATION.register(agg_ops.DateSeriesDiffOp)
102260
def _(
103261
op: agg_ops.DateSeriesDiffOp,
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
CASE
8+
WHEN `int64_col` <= MIN(`int64_col`) OVER () + (
9+
1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
10+
)
11+
THEN STRUCT(
12+
(
13+
MIN(`int64_col`) OVER () + (
14+
0 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
15+
)
16+
) - (
17+
(
18+
MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER ()
19+
) * 0.001
20+
) AS `left_exclusive`,
21+
MIN(`int64_col`) OVER () + (
22+
1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
23+
) + 0 AS `right_inclusive`
24+
)
25+
WHEN `int64_col` <= MIN(`int64_col`) OVER () + (
26+
2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
27+
)
28+
THEN STRUCT(
29+
(
30+
MIN(`int64_col`) OVER () + (
31+
1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
32+
)
33+
) - 0 AS `left_exclusive`,
34+
MIN(`int64_col`) OVER () + (
35+
2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
36+
) + 0 AS `right_inclusive`
37+
)
38+
WHEN `int64_col` IS NOT NULL
39+
THEN STRUCT(
40+
(
41+
MIN(`int64_col`) OVER () + (
42+
2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
43+
)
44+
) - 0 AS `left_exclusive`,
45+
MIN(`int64_col`) OVER () + (
46+
3 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
47+
) + 0 AS `right_inclusive`
48+
)
49+
END AS `bfcol_1`,
50+
CASE
51+
WHEN `int64_col` > 0 AND `int64_col` <= 1
52+
THEN STRUCT(0 AS left_exclusive, 1 AS right_inclusive)
53+
WHEN `int64_col` > 1 AND `int64_col` <= 2
54+
THEN STRUCT(1 AS left_exclusive, 2 AS right_inclusive)
55+
END AS `bfcol_2`,
56+
CASE
57+
WHEN `int64_col` < MIN(`int64_col`) OVER () + (
58+
1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
59+
)
60+
THEN 'a'
61+
WHEN `int64_col` < MIN(`int64_col`) OVER () + (
62+
2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
63+
)
64+
THEN 'b'
65+
WHEN `int64_col` IS NOT NULL
66+
THEN 'c'
67+
END AS `bfcol_3`,
68+
CASE
69+
WHEN `int64_col` > 0 AND `int64_col` <= 1
70+
THEN 0
71+
WHEN `int64_col` > 1 AND `int64_col` <= 2
72+
THEN 1
73+
END AS `bfcol_4`
74+
FROM `bfcte_0`
75+
)
76+
SELECT
77+
`bfcol_1` AS `int_bins`,
78+
`bfcol_2` AS `interval_bins`,
79+
`bfcol_3` AS `int_bins_labels`,
80+
`bfcol_4` AS `interval_bins_labels`
81+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,33 @@ def test_count(scalar_types_df: bpd.DataFrame, snapshot):
160160
snapshot.assert_match(sql_window_partition, "window_partition_out.sql")
161161

162162

163+
def test_cut(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL for CutOp over int bins, intervals, labels, and indices."""
    col_name = "int64_col"
    bf_df = scalar_types_df[[col_name]]
    col_ref = expression.deref(col_name)

    # Output name -> CutOp variant; order determines output column order.
    cut_ops = {
        "int_bins": agg_ops.CutOp(bins=3, right=True, labels=None),
        "interval_bins": agg_ops.CutOp(bins=((0, 1), (1, 2)), right=True, labels=None),
        "int_bins_labels": agg_ops.CutOp(bins=3, labels=("a", "b", "c"), right=False),
        "interval_bins_labels": agg_ops.CutOp(
            bins=((0, 1), (1, 2)), labels=False, right=True
        ),
    }
    aggregations = [
        agg_exprs.UnaryAggregation(op, col_ref) for op in cut_ops.values()
    ]
    sql = _apply_unary_agg_ops(bf_df, aggregations, list(cut_ops.keys()))

    snapshot.assert_match(sql, "out.sql")
188+
189+
163190
def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot):
164191
col_name = "int64_col"
165192
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)