Skip to content

Commit 9f68227

Browse files
committed
Merge branch 'main' into shuowei-anywidget-html-repr
2 parents 12335b3 + bb66915 commit 9f68227

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+484
-176
lines changed

bigframes/core/array_value.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -401,37 +401,10 @@ def aggregate(
401401
)
402402
)
403403

404-
def project_window_op(
405-
self,
406-
column_name: str,
407-
op: agg_ops.UnaryWindowOp,
408-
window_spec: WindowSpec,
409-
*,
410-
never_skip_nulls=False,
411-
skip_reproject_unsafe: bool = False,
412-
) -> Tuple[ArrayValue, str]:
413-
"""
414-
Creates a new expression based on this expression with unary operation applied to one column.
415-
column_name: the id of the input column present in the expression
416-
op: the windowable operator to apply to the input column
417-
window_spec: a specification of the window over which to apply the operator
418-
output_name: the id to assign to the output of the operator, by default will replace input col if distinct output id not provided
419-
never_skip_nulls: will disable null skipping for operators that would otherwise do so
420-
skip_reproject_unsafe: skips the reprojection step, can be used when performing many non-dependent window operations, user responsible for not nesting window expressions, or using outputs as join, filter or aggregation keys before a reprojection
421-
"""
422-
423-
return self.project_window_expr(
424-
agg_expressions.UnaryAggregation(op, ex.deref(column_name)),
425-
window_spec,
426-
never_skip_nulls,
427-
skip_reproject_unsafe,
428-
)
429-
430404
def project_window_expr(
431405
self,
432406
expression: agg_expressions.Aggregation,
433407
window: WindowSpec,
434-
never_skip_nulls=False,
435408
skip_reproject_unsafe: bool = False,
436409
):
437410
output_name = self._gen_namespaced_uid()
@@ -442,7 +415,6 @@ def project_window_expr(
442415
expression=expression,
443416
window_spec=window,
444417
output_name=ids.ColumnId(output_name),
445-
never_skip_nulls=never_skip_nulls,
446418
skip_reproject_unsafe=skip_reproject_unsafe,
447419
)
448420
),

bigframes/core/blocks.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,7 +1090,6 @@ def multi_apply_window_op(
10901090
window_spec: windows.WindowSpec,
10911091
*,
10921092
skip_null_groups: bool = False,
1093-
never_skip_nulls: bool = False,
10941093
) -> typing.Tuple[Block, typing.Sequence[str]]:
10951094
block = self
10961095
result_ids = []
@@ -1103,7 +1102,6 @@ def multi_apply_window_op(
11031102
skip_reproject_unsafe=(i + 1) < len(columns),
11041103
result_label=label,
11051104
skip_null_groups=skip_null_groups,
1106-
never_skip_nulls=never_skip_nulls,
11071105
)
11081106
result_ids.append(result_id)
11091107
return block, result_ids
@@ -1184,15 +1182,13 @@ def apply_window_op(
11841182
result_label: Label = None,
11851183
skip_null_groups: bool = False,
11861184
skip_reproject_unsafe: bool = False,
1187-
never_skip_nulls: bool = False,
11881185
) -> typing.Tuple[Block, str]:
11891186
agg_expr = agg_expressions.UnaryAggregation(op, ex.deref(column))
11901187
return self.apply_analytic(
11911188
agg_expr,
11921189
window_spec,
11931190
result_label,
11941191
skip_reproject_unsafe=skip_reproject_unsafe,
1195-
never_skip_nulls=never_skip_nulls,
11961192
skip_null_groups=skip_null_groups,
11971193
)
11981194

@@ -1203,7 +1199,6 @@ def apply_analytic(
12031199
result_label: Label,
12041200
*,
12051201
skip_reproject_unsafe: bool = False,
1206-
never_skip_nulls: bool = False,
12071202
skip_null_groups: bool = False,
12081203
) -> typing.Tuple[Block, str]:
12091204
block = self
@@ -1214,7 +1209,6 @@ def apply_analytic(
12141209
agg_expr,
12151210
window,
12161211
skip_reproject_unsafe=skip_reproject_unsafe,
1217-
never_skip_nulls=never_skip_nulls,
12181212
)
12191213
block = Block(
12201214
expr,

bigframes/core/compile/compiled.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -394,16 +394,13 @@ def project_window_op(
394394
expression: ex_types.Aggregation,
395395
window_spec: WindowSpec,
396396
output_name: str,
397-
*,
398-
never_skip_nulls=False,
399397
) -> UnorderedIR:
400398
"""
401399
Creates a new expression based on this expression with unary operation applied to one column.
402400
column_name: the id of the input column present in the expression
403401
op: the windowable operator to apply to the input column
404402
window_spec: a specification of the window over which to apply the operator
405403
output_name: the id to assign to the output of the operator
406-
never_skip_nulls: will disable null skipping for operators that would otherwise do so
407404
"""
408405
# Cannot nest analytic expressions, so reproject to cte first if needed.
409406
# Also ibis cannot window literals, so need to reproject those (even though this is legal in googlesql)
@@ -425,7 +422,6 @@ def project_window_op(
425422
expression,
426423
window_spec,
427424
output_name,
428-
never_skip_nulls=never_skip_nulls,
429425
)
430426

431427
if expression.op.order_independent and window_spec.is_unbounded:
@@ -437,9 +433,6 @@ def project_window_op(
437433
expression, window_spec
438434
)
439435
clauses: list[tuple[ex.Expression, ex.Expression]] = []
440-
if expression.op.skips_nulls and not never_skip_nulls:
441-
for input in expression.inputs:
442-
clauses.append((ops.isnull_op.as_expr(input), ex.const(None)))
443436
if window_spec.min_periods and len(expression.inputs) > 0:
444437
if not expression.op.nulls_count_for_min_values:
445438
is_observation = ops.notnull_op.as_expr()

bigframes/core/compile/ibis_compiler/ibis_compiler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,6 @@ def compile_window(node: nodes.WindowOpNode, child: compiled.UnorderedIR):
269269
node.expression,
270270
node.window_spec,
271271
node.output_name.sql,
272-
never_skip_nulls=node.never_skip_nulls,
273272
)
274273
return result
275274

bigframes/core/compile/polars/compiler.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import dataclasses
1717
import functools
1818
import itertools
19-
import operator
2019
from typing import cast, Literal, Optional, Sequence, Tuple, Type, TYPE_CHECKING
2120

2221
import pandas as pd
@@ -868,26 +867,6 @@ def compile_window(self, node: nodes.WindowOpNode):
868867
df, node.expression, node.window_spec, node.output_name.sql
869868
)
870869
result = pl.concat([df, window_result], how="horizontal")
871-
872-
# Probably easier just to pull this out as a rewriter
873-
if (
874-
node.expression.op.skips_nulls
875-
and not node.never_skip_nulls
876-
and node.expression.column_references
877-
):
878-
nullity_expr = functools.reduce(
879-
operator.or_,
880-
(
881-
pl.col(column.sql).is_null()
882-
for column in node.expression.column_references
883-
),
884-
)
885-
result = result.with_columns(
886-
pl.when(nullity_expr)
887-
.then(None)
888-
.otherwise(pl.col(node.output_name.sql))
889-
.alias(node.output_name.sql)
890-
)
891870
return result
892871

893872
def _calc_row_analytic_func(

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,140 @@ def _(
111111
return apply_window_if_present(sge.func("COUNT", column.expr), window)
112112

113113

114+
@UNARY_OP_REGISTRATION.register(agg_ops.CutOp)
def _(
    op: agg_ops.CutOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile a ``CutOp`` by dispatching on the type of ``op.bins``.

    An integer ``op.bins`` is handled by ``_cut_ops_w_int_bins``; any other
    value is interpreted as a collection of intervals and handled by
    ``_cut_ops_w_intervals``. ``window`` is forwarded unchanged.
    """
    if not isinstance(op.bins, int):
        # Anything that is not a bin count is interpreted as intervals.
        return _cut_ops_w_intervals(op, column, op.bins, window)
    return _cut_ops_w_int_bins(op, column, op.bins, window)
125+
126+
127+
def _cut_ops_w_int_bins(
    op: agg_ops.CutOp,
    column: typed_expr.TypedExpr,
    bins: int,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Case:
    """Build the CASE expression that cuts ``column`` into ``bins`` equal-width bins.

    Bin edges are computed in SQL from the MIN and MAX of ``column`` taken over
    ``window`` (or over a default ``WindowSpec()`` when ``window`` is falsy).

    Args:
        op: The cut operation; ``op.labels`` and ``op.right`` are read here.
        column: The column being binned.
        bins: Number of bins; one WHEN branch is emitted per bin.
        window: Optional window over which MIN and MAX are evaluated.

    Returns:
        A CASE expression with one WHEN branch per bin. A branch yields the
        bin index when ``op.labels is False``, the matching entry of
        ``op.labels`` when it is iterable, and otherwise a struct holding the
        bin's left and right edges.
    """
    case_expr = sge.Case()

    # MIN and MAX are evaluated over the same window; resolve the default once.
    bounds_window = window or window_spec.WindowSpec()
    col_min = apply_window_if_present(sge.func("MIN", column.expr), bounds_window)
    col_max = apply_window_if_present(sge.func("MAX", column.expr), bounds_window)

    # 0.1% of the value range; used below to widen the outermost struct edge
    # on the open side (first bin if right-closed, last bin otherwise).
    adj: sge.Expression = sge.Sub(this=col_max, expression=col_min) * sge.convert(0.001)
    bin_width: sge.Expression = sge.func(
        "IEEE_DIVIDE",
        sge.Sub(this=col_max, expression=col_min),
        sge.convert(bins),
    )

    use_index_labels = op.labels is False
    # Materialize user-supplied labels once rather than once per bin.
    custom_labels = (
        list(op.labels)
        if not use_index_labels and isinstance(op.labels, typing.Iterable)
        else None
    )

    # Everything that depends only on ``op.right`` is decided up front.
    if op.right:
        left_name, right_name = "left_exclusive", "right_inclusive"
        compare = sge.LTE
    else:
        left_name, right_name = "left_inclusive", "right_exclusive"
        compare = sge.LT

    for this_bin in range(bins):
        is_last_bin = this_bin == bins - 1

        value: sge.Expression
        if use_index_labels:
            value = ir._literal(this_bin, dtypes.INT_DTYPE)
        elif custom_labels is not None:
            value = ir._literal(custom_labels[this_bin], dtypes.STRING_DTYPE)
        else:
            left_adj: sge.Expression = (
                adj if this_bin == 0 and op.right else sge.convert(0)
            )
            right_adj: sge.Expression = (
                adj if is_last_bin and not op.right else sge.convert(0)
            )
            left: sge.Expression = (
                col_min + sge.convert(this_bin) * bin_width - left_adj
            )
            right: sge.Expression = (
                col_min + sge.convert(this_bin + 1) * bin_width + right_adj
            )
            value = sge.Struct(
                expressions=[
                    sge.PropertyEQ(
                        this=sge.Identifier(this=left_name, quoted=True),
                        expression=left,
                    ),
                    sge.PropertyEQ(
                        this=sge.Identifier(this=right_name, quoted=True),
                        expression=right,
                    ),
                ]
            )

        condition: sge.Expression
        if is_last_bin:
            # Earlier WHEN branches already claimed the smaller values, so the
            # last bin takes every remaining non-NULL value.
            condition = sge.Is(this=column.expr, expression=sge.Not(this=sge.Null()))
        else:
            condition = compare(
                this=column.expr,
                expression=(col_min + sge.convert(this_bin + 1) * bin_width),
            )
        case_expr = case_expr.when(condition, value)
    return case_expr
197+
198+
199+
def _cut_ops_w_intervals(
    op: agg_ops.CutOp,
    column: typed_expr.TypedExpr,
    bins: typing.Iterable[typing.Tuple[typing.Any, typing.Any]],
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Case:
    """Build the CASE expression that cuts ``column`` into explicit intervals.

    Args:
        op: The cut operation; ``op.labels`` and ``op.right`` are read here.
        column: The column being binned.
        bins: ``(left, right)`` pairs, one per interval, in output order.
        window: Not used by this function; the interval edges are literals.

    Returns:
        A CASE expression with one WHEN branch per interval. A branch yields
        the interval index when ``op.labels is False``, the matching entry of
        ``op.labels`` when it is iterable, and otherwise a struct holding the
        interval's left and right edges.
    """
    case_expr = sge.Case()

    use_index_labels = op.labels is False
    # Materialize user-supplied labels once rather than once per interval.
    custom_labels = (
        list(op.labels)
        if not use_index_labels and isinstance(op.labels, typing.Iterable)
        else None
    )

    # Everything that depends only on ``op.right`` is decided up front.
    if op.right:
        # Right-closed: left < x <= right
        lower_cmp, upper_cmp = sge.GT, sge.LTE
        left_name, right_name = "left_exclusive", "right_inclusive"
    else:
        # Left-closed: left <= x < right
        lower_cmp, upper_cmp = sge.GTE, sge.LT
        left_name, right_name = "left_inclusive", "right_exclusive"

    for this_bin, interval in enumerate(bins):
        left: sge.Expression = ir._literal(
            interval[0], dtypes.infer_literal_type(interval[0])
        )
        right: sge.Expression = ir._literal(
            interval[1], dtypes.infer_literal_type(interval[1])
        )
        condition: sge.Expression = sge.And(
            this=lower_cmp(this=column.expr, expression=left),
            expression=upper_cmp(this=column.expr, expression=right),
        )

        value: sge.Expression
        if use_index_labels:
            value = ir._literal(this_bin, dtypes.INT_DTYPE)
        elif custom_labels is not None:
            value = ir._literal(custom_labels[this_bin], dtypes.STRING_DTYPE)
        else:
            value = sge.Struct(
                expressions=[
                    sge.PropertyEQ(
                        this=sge.Identifier(this=left_name, quoted=True),
                        expression=left,
                    ),
                    sge.PropertyEQ(
                        this=sge.Identifier(this=right_name, quoted=True),
                        expression=right,
                    ),
                ]
            )
        case_expr = case_expr.when(condition, value)
    return case_expr
246+
247+
114248
@UNARY_OP_REGISTRATION.register(agg_ops.DateSeriesDiffOp)
115249
def _(
116250
op: agg_ops.DateSeriesDiffOp,

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -324,12 +324,8 @@ def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotI
324324
)
325325

326326
clauses: list[tuple[sge.Expression, sge.Expression]] = []
327-
if node.expression.op.skips_nulls and not node.never_skip_nulls:
328-
for column in inputs:
329-
clauses.append((sge.Is(this=column, expression=sge.Null()), sge.Null()))
330-
331327
if window_spec.min_periods and len(inputs) > 0:
332-
if node.expression.op.skips_nulls:
328+
if not node.expression.op.nulls_count_for_min_values:
333329
# Most operations do not count NULL values towards min_periods
334330
not_null_columns = [
335331
sge.Not(this=sge.Is(this=column, expression=sge.Null()))

bigframes/core/compile/sqlglot/expressions/geo_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ def _(expr: TypedExpr) -> sge.Expression:
116116
return sge.func("SAFE.ST_Y", expr.expr)
117117

118118

119+
@register_binary_op(ops.GeoStDistanceOp, pass_op=True)
def _(left: TypedExpr, right: TypedExpr, op: ops.GeoStDistanceOp) -> sge.Expression:
    """Compile ``GeoStDistanceOp`` into an ``ST_DISTANCE`` function call.

    ``op.use_spheroid`` is converted to a literal and passed as the third
    argument after the two operand expressions.
    """
    spheroid_flag = sge.convert(op.use_spheroid)
    return sge.func("ST_DISTANCE", left.expr, right.expr, spheroid_flag)
122+
123+
119124
@register_binary_op(ops.geo_st_difference_op)
120125
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
121126
return sge.func("ST_DIFFERENCE", left.expr, right.expr)

0 commit comments

Comments
 (0)