Skip to content

Commit 8e4ea88

Browse files
committed
Merge branch 'main' into shuowei-anywidget-col
2 parents 6801ca4 + 4c98c95 commit 8e4ea88

File tree

25 files changed

+712
-107
lines changed

25 files changed

+712
-107
lines changed

bigframes/core/compile/polars/compiler.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,26 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
328328
assert isinstance(op, string_ops.StrContainsRegexOp)
329329
return input.str.contains(pattern=op.pat, literal=False)
330330

331+
@compile_op.register(string_ops.UpperOp)
332+
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
333+
assert isinstance(op, string_ops.UpperOp)
334+
return input.str.to_uppercase()
335+
336+
@compile_op.register(string_ops.LowerOp)
337+
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
338+
assert isinstance(op, string_ops.LowerOp)
339+
return input.str.to_lowercase()
340+
341+
@compile_op.register(string_ops.ArrayLenOp)
342+
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
343+
assert isinstance(op, string_ops.ArrayLenOp)
344+
return input.list.len()
345+
346+
@compile_op.register(string_ops.StrLenOp)
347+
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
348+
assert isinstance(op, string_ops.StrLenOp)
349+
return input.str.len_chars()
350+
331351
@compile_op.register(string_ops.StartsWithOp)
332352
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
333353
assert isinstance(op, string_ops.StartsWithOp)

bigframes/core/compile/polars/lowering.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
generic_ops,
2828
json_ops,
2929
numeric_ops,
30+
string_ops,
3031
)
3132
import bigframes.operations as ops
3233

@@ -347,11 +348,28 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
347348
return ops.coalesce_op.as_expr(new_isin, expression.const(False))
348349

349350

351+
class LowerLenOp(op_lowering.OpLoweringRule):
352+
@property
353+
def op(self) -> type[ops.ScalarOp]:
354+
return string_ops.LenOp
355+
356+
def lower(self, expr: expression.OpExpression) -> expression.Expression:
357+
assert isinstance(expr.op, string_ops.LenOp)
358+
arg = expr.children[0]
359+
360+
if dtypes.is_string_like(arg.output_type):
361+
return string_ops.StrLenOp().as_expr(arg)
362+
elif dtypes.is_array_like(arg.output_type):
363+
return string_ops.ArrayLenOp().as_expr(arg)
364+
else:
365+
raise ValueError(f"Unexpected type: {arg.output_type}")
366+
367+
350368
def _coerce_comparables(
351369
expr1: expression.Expression,
352370
expr2: expression.Expression,
353371
*,
354-
bools_only: bool = False
372+
bools_only: bool = False,
355373
):
356374
if bools_only:
357375
if (
@@ -446,6 +464,7 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
446464
LowerAsTypeRule(),
447465
LowerInvertOp(),
448466
LowerIsinOp(),
467+
LowerLenOp(),
449468
)
450469

451470

bigframes/core/compile/polars/operations/numeric_ops.py

Lines changed: 72 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,6 @@
2929
import polars as pl
3030

3131

32-
@polars_compiler.register_op(numeric_ops.CosOp)
33-
def cos_op_impl(
34-
compiler: polars_compiler.PolarsExpressionCompiler,
35-
op: numeric_ops.CosOp, # type: ignore
36-
input: pl.Expr,
37-
) -> pl.Expr:
38-
return input.cos()
39-
40-
4132
@polars_compiler.register_op(numeric_ops.LnOp)
4233
def ln_op_impl(
4334
compiler: polars_compiler.PolarsExpressionCompiler,
@@ -80,6 +71,78 @@ def sin_op_impl(
8071
return input.sin()
8172

8273

74+
@polars_compiler.register_op(numeric_ops.CosOp)
75+
def cos_op_impl(
76+
compiler: polars_compiler.PolarsExpressionCompiler,
77+
op: numeric_ops.CosOp, # type: ignore
78+
input: pl.Expr,
79+
) -> pl.Expr:
80+
return input.cos()
81+
82+
83+
@polars_compiler.register_op(numeric_ops.TanOp)
84+
def tan_op_impl(
85+
compiler: polars_compiler.PolarsExpressionCompiler,
86+
op: numeric_ops.SinOp, # type: ignore
87+
input: pl.Expr,
88+
) -> pl.Expr:
89+
return input.tan()
90+
91+
92+
@polars_compiler.register_op(numeric_ops.SinhOp)
93+
def sinh_op_impl(
94+
compiler: polars_compiler.PolarsExpressionCompiler,
95+
op: numeric_ops.SinOp, # type: ignore
96+
input: pl.Expr,
97+
) -> pl.Expr:
98+
return input.sinh()
99+
100+
101+
@polars_compiler.register_op(numeric_ops.CoshOp)
102+
def cosh_op_impl(
103+
compiler: polars_compiler.PolarsExpressionCompiler,
104+
op: numeric_ops.CosOp, # type: ignore
105+
input: pl.Expr,
106+
) -> pl.Expr:
107+
return input.cosh()
108+
109+
110+
@polars_compiler.register_op(numeric_ops.TanhOp)
111+
def tanh_op_impl(
112+
compiler: polars_compiler.PolarsExpressionCompiler,
113+
op: numeric_ops.SinOp, # type: ignore
114+
input: pl.Expr,
115+
) -> pl.Expr:
116+
return input.tanh()
117+
118+
119+
@polars_compiler.register_op(numeric_ops.ArcsinOp)
120+
def asin_op_impl(
121+
compiler: polars_compiler.PolarsExpressionCompiler,
122+
op: numeric_ops.ArcsinOp, # type: ignore
123+
input: pl.Expr,
124+
) -> pl.Expr:
125+
return input.arcsin()
126+
127+
128+
@polars_compiler.register_op(numeric_ops.ArccosOp)
129+
def acos_op_impl(
130+
compiler: polars_compiler.PolarsExpressionCompiler,
131+
op: numeric_ops.ArccosOp, # type: ignore
132+
input: pl.Expr,
133+
) -> pl.Expr:
134+
return input.arccos()
135+
136+
137+
@polars_compiler.register_op(numeric_ops.ArctanOp)
138+
def atan_op_impl(
139+
compiler: polars_compiler.PolarsExpressionCompiler,
140+
op: numeric_ops.ArctanOp, # type: ignore
141+
input: pl.Expr,
142+
) -> pl.Expr:
143+
return input.arctan()
144+
145+
83146
@polars_compiler.register_op(numeric_ops.SqrtOp)
84147
def sqrt_op_impl(
85148
compiler: polars_compiler.PolarsExpressionCompiler,

bigframes/core/compile/sqlglot/expressions/date_ops.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,7 @@ def _(expr: TypedExpr) -> sge.Expression:
3535

3636
@register_unary_op(ops.dayofweek_op)
3737
def _(expr: TypedExpr) -> sge.Expression:
38-
# Adjust the 1-based day-of-week index (from SQL) to a 0-based index.
39-
return sge.Extract(
40-
this=sge.Identifier(this="DAYOFWEEK"), expression=expr.expr
41-
) - sge.convert(1)
38+
return dayofweek_op_impl(expr)
4239

4340

4441
@register_unary_op(ops.dayofyear_op)
@@ -48,7 +45,8 @@ def _(expr: TypedExpr) -> sge.Expression:
4845

4946
@register_unary_op(ops.iso_day_op)
5047
def _(expr: TypedExpr) -> sge.Expression:
51-
return sge.Extract(this=sge.Identifier(this="DAYOFWEEK"), expression=expr.expr)
48+
# Plus 1 because iso day of week uses 1-based indexing
49+
return dayofweek_op_impl(expr) + sge.convert(1)
5250

5351

5452
@register_unary_op(ops.iso_week_op)
@@ -59,3 +57,16 @@ def _(expr: TypedExpr) -> sge.Expression:
5957
@register_unary_op(ops.iso_year_op)
6058
def _(expr: TypedExpr) -> sge.Expression:
6159
return sge.Extract(this=sge.Identifier(this="ISOYEAR"), expression=expr.expr)
60+
61+
62+
# Helpers
63+
def dayofweek_op_impl(expr: TypedExpr) -> sge.Expression:
64+
# BigQuery SQL Extract(DAYOFWEEK) returns 1 for Sunday through 7 for Saturday.
65+
# We want 0 for Monday through 6 for Sunday to be compatible with Pandas.
66+
extract_expr = sge.Extract(
67+
this=sge.Identifier(this="DAYOFWEEK"), expression=expr.expr
68+
)
69+
return sge.Cast(
70+
this=sge.Mod(this=extract_expr + sge.convert(5), expression=sge.convert(7)),
71+
to="INT64",
72+
)

bigframes/core/compile/sqlglot/expressions/datetime_ops.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,28 @@
2525

2626
@register_unary_op(ops.FloorDtOp, pass_op=True)
2727
def _(expr: TypedExpr, op: ops.FloorDtOp) -> sge.Expression:
28-
# TODO: Remove this method when it is covered by ops.FloorOp
29-
return sge.TimestampTrunc(this=expr.expr, unit=sge.Identifier(this=op.freq))
28+
pandas_to_bq_freq_map = {
29+
"Y": "YEAR",
30+
"Q": "QUARTER",
31+
"M": "MONTH",
32+
"W": "WEEK(MONDAY)",
33+
"D": "DAY",
34+
"h": "HOUR",
35+
"min": "MINUTE",
36+
"s": "SECOND",
37+
"ms": "MILLISECOND",
38+
"us": "MICROSECOND",
39+
"ns": "NANOSECOND",
40+
}
41+
if op.freq not in pandas_to_bq_freq_map.keys():
42+
raise NotImplementedError(
43+
f"Unsupported freq paramater: {op.freq}"
44+
+ " Supported freq parameters are: "
45+
+ ",".join(pandas_to_bq_freq_map.keys())
46+
)
47+
48+
bq_freq = pandas_to_bq_freq_map[op.freq]
49+
return sge.TimestampTrunc(this=expr.expr, unit=sge.Identifier(this=bq_freq))
3050

3151

3252
@register_unary_op(ops.hour_op)

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import sqlglot as sg
1718
import sqlglot.expressions as sge
1819

1920
from bigframes import dtypes
@@ -80,6 +81,16 @@ def _(expr: TypedExpr) -> sge.Expression:
8081
return sge.BitwiseNot(this=sge.paren(expr.expr))
8182

8283

84+
@register_nary_op(ops.SqlScalarOp, pass_op=True)
85+
def _(*operands: TypedExpr, op: ops.SqlScalarOp) -> sge.Expression:
86+
return sg.parse_one(
87+
op.sql_template.format(
88+
*[operand.expr.sql(dialect="bigquery") for operand in operands]
89+
),
90+
dialect="bigquery",
91+
)
92+
93+
8394
@register_unary_op(ops.isnull_op)
8495
def _(expr: TypedExpr) -> sge.Expression:
8596
return sge.Is(this=expr.expr, expression=sge.Null())

bigframes/core/compile/sqlglot/expressions/struct_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2525
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2626

27+
register_nary_op = scalar_compiler.scalar_op_compiler.register_nary_op
2728
register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
2829

2930

@@ -40,3 +41,13 @@ def _(expr: TypedExpr, op: ops.StructFieldOp) -> sge.Expression:
4041
this=sge.to_identifier(name, quoted=True),
4142
catalog=expr.expr,
4243
)
44+
45+
46+
@register_nary_op(ops.StructOp, pass_op=True)
47+
def _(*exprs: TypedExpr, op: ops.StructOp) -> sge.Struct:
48+
return sge.Struct(
49+
expressions=[
50+
sge.PropertyEQ(this=sge.to_identifier(col), expression=expr.expr)
51+
for col, expr in zip(op.column_names, exprs)
52+
]
53+
)

bigframes/core/compile/sqlglot/scalar_compiler.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,37 @@ class ScalarOpCompiler:
3131
typing.Callable[[typing.Sequence[TypedExpr], ops.RowOp], sge.Expression],
3232
] = {}
3333

34+
# A set of SQLGlot classes that may need to be parenthesized
35+
SQLGLOT_NEEDS_PARENS = {
36+
# Numeric operations
37+
sge.Add,
38+
sge.Sub,
39+
sge.Mul,
40+
sge.Div,
41+
sge.Mod,
42+
sge.Pow,
43+
# Comparison operations
44+
sge.GTE,
45+
sge.GT,
46+
sge.LTE,
47+
sge.LT,
48+
sge.EQ,
49+
sge.NEQ,
50+
# Logical operations
51+
sge.And,
52+
sge.Or,
53+
sge.Xor,
54+
# Bitwise operations
55+
sge.BitwiseAnd,
56+
sge.BitwiseOr,
57+
sge.BitwiseXor,
58+
sge.BitwiseLeftShift,
59+
sge.BitwiseRightShift,
60+
sge.BitwiseNot,
61+
# Other operations
62+
sge.Is,
63+
}
64+
3465
@functools.singledispatchmethod
3566
def compile_expression(
3667
self,
@@ -110,10 +141,12 @@ def register_binary_op(
110141

111142
def decorator(impl: typing.Callable[..., sge.Expression]):
112143
def normalized_impl(args: typing.Sequence[TypedExpr], op: ops.RowOp):
144+
left = self._add_parentheses(args[0])
145+
right = self._add_parentheses(args[1])
113146
if pass_op:
114-
return impl(args[0], args[1], op)
147+
return impl(left, right, op)
115148
else:
116-
return impl(args[0], args[1])
149+
return impl(left, right)
117150

118151
self._register(key, normalized_impl)
119152
return impl
@@ -177,6 +210,12 @@ def _register(
177210
raise ValueError(f"Operation name {op_name} already registered")
178211
self._registry[op_name] = impl
179212

213+
@classmethod
214+
def _add_parentheses(cls, expr: TypedExpr) -> TypedExpr:
215+
if type(expr.expr) in cls.SQLGLOT_NEEDS_PARENS:
216+
return TypedExpr(sge.paren(expr.expr, copy=False), expr.dtype)
217+
return expr
218+
180219

181220
# Singleton compiler
182221
scalar_op_compiler = ScalarOpCompiler()

0 commit comments

Comments
 (0)