Skip to content

Commit 09ee218

Browse files
committed
Merge branch 'main' into shuowei-anywidget-html-repr
2 parents f47f87b + 33a211e commit 09ee218

File tree

23 files changed

+566
-191
lines changed

23 files changed

+566
-191
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 21 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -245,27 +245,6 @@ def _cut_ops_w_intervals(
245245
return case_expr
246246

247247

248-
@UNARY_OP_REGISTRATION.register(agg_ops.DateSeriesDiffOp)
249-
def _(
250-
op: agg_ops.DateSeriesDiffOp,
251-
column: typed_expr.TypedExpr,
252-
window: typing.Optional[window_spec.WindowSpec] = None,
253-
) -> sge.Expression:
254-
if column.dtype != dtypes.DATE_DTYPE:
255-
raise TypeError(f"Cannot perform date series diff on type {column.dtype}")
256-
shift_op_impl = UNARY_OP_REGISTRATION[agg_ops.ShiftOp(0)]
257-
shifted = shift_op_impl(agg_ops.ShiftOp(op.periods), column, window)
258-
# Conversion factor from days to microseconds
259-
conversion_factor = 24 * 60 * 60 * 1_000_000
260-
return sge.Cast(
261-
this=sge.DateDiff(
262-
this=column.expr, expression=shifted, unit=sge.Identifier(this="DAY")
263-
)
264-
* sge.convert(conversion_factor),
265-
to="INT64",
266-
)
267-
268-
269248
@UNARY_OP_REGISTRATION.register(agg_ops.DenseRankOp)
270249
def _(
271250
op: agg_ops.DenseRankOp,
@@ -327,13 +306,27 @@ def _(
327306
) -> sge.Expression:
328307
shift_op_impl = UNARY_OP_REGISTRATION[agg_ops.ShiftOp(0)]
329308
shifted = shift_op_impl(agg_ops.ShiftOp(op.periods), column, window)
330-
if column.dtype in (dtypes.BOOL_DTYPE, dtypes.INT_DTYPE, dtypes.FLOAT_DTYPE):
331-
if column.dtype == dtypes.BOOL_DTYPE:
332-
return sge.NEQ(this=column.expr, expression=shifted)
333-
else:
334-
return sge.Sub(this=column.expr, expression=shifted)
335-
else:
336-
raise TypeError(f"Cannot perform diff on type {column.dtype}")
309+
if column.dtype == dtypes.BOOL_DTYPE:
310+
return sge.NEQ(this=column.expr, expression=shifted)
311+
312+
if column.dtype in (dtypes.INT_DTYPE, dtypes.FLOAT_DTYPE):
313+
return sge.Sub(this=column.expr, expression=shifted)
314+
315+
if column.dtype == dtypes.TIMESTAMP_DTYPE:
316+
return sge.TimestampDiff(
317+
this=column.expr,
318+
expression=shifted,
319+
unit=sge.Identifier(this="MICROSECOND"),
320+
)
321+
322+
if column.dtype == dtypes.DATETIME_DTYPE:
323+
return sge.DatetimeDiff(
324+
this=column.expr,
325+
expression=shifted,
326+
unit=sge.Identifier(this="MICROSECOND"),
327+
)
328+
329+
raise TypeError(f"Cannot perform diff on type {column.dtype}")
337330

338331

339332
@UNARY_OP_REGISTRATION.register(agg_ops.MaxOp)
@@ -593,23 +586,6 @@ def _(
593586
return sge.func("IFNULL", expr, ir._literal(zero, column.dtype))
594587

595588

596-
@UNARY_OP_REGISTRATION.register(agg_ops.TimeSeriesDiffOp)
597-
def _(
598-
op: agg_ops.TimeSeriesDiffOp,
599-
column: typed_expr.TypedExpr,
600-
window: typing.Optional[window_spec.WindowSpec] = None,
601-
) -> sge.Expression:
602-
if column.dtype != dtypes.TIMESTAMP_DTYPE:
603-
raise TypeError(f"Cannot perform time series diff on type {column.dtype}")
604-
shift_op_impl = UNARY_OP_REGISTRATION[agg_ops.ShiftOp(0)]
605-
shifted = shift_op_impl(agg_ops.ShiftOp(op.periods), column, window)
606-
return sge.TimestampDiff(
607-
this=column.expr,
608-
expression=shifted,
609-
unit=sge.Identifier(this="MICROSECOND"),
610-
)
611-
612-
613589
@UNARY_OP_REGISTRATION.register(agg_ops.VarOp)
614590
def _(
615591
op: agg_ops.VarOp,

bigframes/core/compile/sqlglot/aggregations/windows.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ def apply_window_if_present(
6262

6363
# This is the key change. Don't create a spec for the default window frame
6464
# if there's no ordering. This avoids generating an `ORDER BY NULL` clause.
65-
if not window.bounds and not order:
65+
if window.is_unbounded and not order:
6666
return sge.Window(this=value, partition_by=group_by)
6767

68-
if not window.bounds and not include_framing_clauses:
68+
if window.is_unbounded and not include_framing_clauses:
6969
return sge.Window(this=value, partition_by=group_by, order=order)
7070

7171
kind = (

bigframes/core/compile/sqlglot/expressions/datetime_ops.py

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717
import sqlglot.expressions as sge
1818

19+
from bigframes import dtypes
1920
from bigframes import operations as ops
21+
from bigframes.core.compile.constants import UNIT_TO_US_CONVERSION_FACTORS
2022
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2123
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2224

@@ -81,22 +83,73 @@ def _(expr: TypedExpr) -> sge.Expression:
8183

8284
@register_unary_op(ops.StrftimeOp, pass_op=True)
8385
def _(expr: TypedExpr, op: ops.StrftimeOp) -> sge.Expression:
84-
return sge.func("FORMAT_TIMESTAMP", sge.convert(op.date_format), expr.expr)
86+
func_name = ""
87+
if expr.dtype == dtypes.DATE_DTYPE:
88+
func_name = "FORMAT_DATE"
89+
elif expr.dtype == dtypes.DATETIME_DTYPE:
90+
func_name = "FORMAT_DATETIME"
91+
elif expr.dtype == dtypes.TIME_DTYPE:
92+
func_name = "FORMAT_TIME"
93+
elif expr.dtype == dtypes.TIMESTAMP_DTYPE:
94+
func_name = "FORMAT_TIMESTAMP"
95+
96+
return sge.func(func_name, sge.convert(op.date_format), expr.expr)
8597

8698

8799
@register_unary_op(ops.time_op)
88100
def _(expr: TypedExpr) -> sge.Expression:
89101
return sge.func("TIME", expr.expr)
90102

91103

92-
@register_unary_op(ops.ToDatetimeOp)
93-
def _(expr: TypedExpr) -> sge.Expression:
94-
return sge.Cast(this=sge.func("TIMESTAMP_SECONDS", expr.expr), to="DATETIME")
95-
104+
@register_unary_op(ops.ToDatetimeOp, pass_op=True)
105+
def _(expr: TypedExpr, op: ops.ToDatetimeOp) -> sge.Expression:
106+
if op.format:
107+
result = expr.expr
108+
if expr.dtype != dtypes.STRING_DTYPE:
109+
result = sge.Cast(this=result, to="STRING")
110+
result = sge.func(
111+
"PARSE_TIMESTAMP", sge.convert(op.format), result, sge.convert("UTC")
112+
)
113+
return sge.Cast(this=result, to="DATETIME")
114+
115+
if expr.dtype == dtypes.STRING_DTYPE:
116+
return sge.TryCast(this=expr.expr, to="DATETIME")
117+
118+
value = expr.expr
119+
unit = op.unit or "ns"
120+
factor = UNIT_TO_US_CONVERSION_FACTORS[unit]
121+
if factor != 1:
122+
value = sge.Mul(this=value, expression=sge.convert(factor))
123+
value = sge.func("TRUNC", value)
124+
return sge.Cast(
125+
this=sge.func("TIMESTAMP_MICROS", sge.Cast(this=value, to="INT64")),
126+
to="DATETIME",
127+
)
128+
129+
130+
@register_unary_op(ops.ToTimestampOp, pass_op=True)
131+
def _(expr: TypedExpr, op: ops.ToTimestampOp) -> sge.Expression:
132+
if op.format:
133+
result = expr.expr
134+
if expr.dtype != dtypes.STRING_DTYPE:
135+
result = sge.Cast(this=result, to="STRING")
136+
return sge.func(
137+
"PARSE_TIMESTAMP", sge.convert(op.format), expr.expr, sge.convert("UTC")
138+
)
96139

97-
@register_unary_op(ops.ToTimestampOp)
98-
def _(expr: TypedExpr) -> sge.Expression:
99-
return sge.func("TIMESTAMP_SECONDS", expr.expr)
140+
if expr.dtype == dtypes.STRING_DTYPE:
141+
return sge.func("TIMESTAMP", expr.expr)
142+
143+
value = expr.expr
144+
unit = op.unit or "ns"
145+
factor = UNIT_TO_US_CONVERSION_FACTORS[unit]
146+
if factor != 1:
147+
value = sge.Mul(this=value, expression=sge.convert(factor))
148+
value = sge.func("TRUNC", value)
149+
return sge.Cast(
150+
this=sge.func("TIMESTAMP_MICROS", sge.Cast(this=value, to="INT64")),
151+
to="TIMESTAMP",
152+
)
100153

101154

102155
@register_unary_op(ops.UnixMicros)

bigframes/core/compile/sqlglot/expressions/timedelta_ops.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import sqlglot.expressions as sge
1818

19+
from bigframes import dtypes
1920
from bigframes import operations as ops
2021
from bigframes.core.compile.constants import UNIT_TO_US_CONVERSION_FACTORS
2122
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
@@ -32,7 +33,12 @@ def _(expr: TypedExpr) -> sge.Expression:
3233
@register_unary_op(ops.ToTimedeltaOp, pass_op=True)
3334
def _(expr: TypedExpr, op: ops.ToTimedeltaOp) -> sge.Expression:
3435
value = expr.expr
36+
if expr.dtype == dtypes.TIMEDELTA_DTYPE:
37+
return value
38+
3539
factor = UNIT_TO_US_CONVERSION_FACTORS[op.unit]
3640
if factor != 1:
3741
value = sge.Mul(this=value, expression=sge.convert(factor))
42+
if expr.dtype == dtypes.FLOAT_DTYPE:
43+
value = sge.Cast(this=sge.Floor(this=value), to=sge.DataType(this="INT64"))
3844
return value

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,8 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
648648
elif dtype == dtypes.BYTES_DTYPE:
649649
return _cast(str(value), sqlglot_type)
650650
elif dtypes.is_time_like(dtype):
651+
if isinstance(value, str):
652+
return _cast(sge.convert(value), sqlglot_type)
651653
if isinstance(value, np.generic):
652654
value = value.item()
653655
return _cast(sge.convert(value.isoformat()), sqlglot_type)

bigframes/core/local_data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,13 @@ def to_arrow(
124124
geo_format: Literal["wkb", "wkt"] = "wkt",
125125
duration_type: Literal["int", "duration"] = "duration",
126126
json_type: Literal["string"] = "string",
127+
max_chunksize: Optional[int] = None,
127128
) -> tuple[pa.Schema, Iterable[pa.RecordBatch]]:
128129
if geo_format != "wkt":
129130
raise NotImplementedError(f"geo format {geo_format} not yet implemented")
130131
assert json_type == "string"
131132

132-
batches = self.data.to_batches()
133+
batches = self.data.to_batches(max_chunksize=max_chunksize)
133134
schema = self.data.schema
134135
if duration_type == "int":
135136
schema = _schema_durations_to_ints(schema)

0 commit comments

Comments
 (0)