Skip to content

Commit bce9a69

Browse files
authored
Merge branch 'main' into main_chelsealin_compilejoin
2 parents 7d3532f + 9fb3cb4 commit bce9a69

File tree

37 files changed

+1542
-63
lines changed

37 files changed

+1542
-63
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,6 @@ repos:
3939
rev: v1.15.0
4040
hooks:
4141
- id: mypy
42-
additional_dependencies: [types-requests, types-tabulate, pandas-stubs<=2.2.3.241126]
42+
additional_dependencies: [types-requests, types-tabulate, types-PyYAML, pandas-stubs<=2.2.3.241126]
4343
exclude: "^third_party"
4444
args: ["--check-untyped-defs", "--explicit-package-bases", "--ignore-missing-imports"]

bigframes/clients.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import google.api_core.exceptions
2525
import google.api_core.retry
2626
from google.cloud import bigquery_connection_v1, resourcemanager_v3
27+
from google.iam.v1 import policy_pb2
2728

2829
logger = logging.getLogger(__name__)
2930

@@ -172,10 +173,7 @@ def _ensure_iam_binding(
172173
return
173174

174175
# Create a new binding
175-
new_binding = {
176-
"role": role,
177-
"members": [service_account],
178-
} # Use a dictionary to avoid problematic google.iam namespace package.
176+
new_binding = policy_pb2.Binding(role=role, members=[service_account])
179177
policy.bindings.append(new_binding)
180178
request = {
181179
"resource": project,

bigframes/core/compile/polars/compiler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@
2323

2424
import bigframes.core
2525
from bigframes.core import identifiers, nodes, ordering, window_spec
26+
from bigframes.core.compile.polars import lowering
2627
import bigframes.core.expression as ex
2728
import bigframes.core.guid as guid
2829
import bigframes.core.rewrite
30+
import bigframes.core.rewrite.schema_binding
2931
import bigframes.dtypes
3032
import bigframes.operations as ops
3133
import bigframes.operations.aggregations as agg_ops
@@ -403,6 +405,8 @@ def compile(self, array_value: bigframes.core.ArrayValue) -> pl.LazyFrame:
403405
node = bigframes.core.rewrite.column_pruning(node)
404406
node = nodes.bottom_up(node, bigframes.core.rewrite.rewrite_slice)
405407
node = bigframes.core.rewrite.pull_out_window_order(node)
408+
node = bigframes.core.rewrite.schema_binding.bind_schema_to_tree(node)
409+
node = lowering.lower_ops_to_polars(node)
406410
return self.compile_node(node)
407411

408412
@functools.singledispatchmethod
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from bigframes import dtypes
16+
from bigframes.core import bigframe_node, expression
17+
from bigframes.core.rewrite import op_lowering
18+
from bigframes.operations import numeric_ops
19+
import bigframes.operations as ops
20+
21+
# TODO: Would be more precise to actually have separate op set for polars ops (where they diverge from the original ops)
22+
23+
24+
class LowerFloorDivRule(op_lowering.OpLoweringRule):
25+
@property
26+
def op(self) -> type[ops.ScalarOp]:
27+
return numeric_ops.FloorDivOp
28+
29+
def lower(self, expr: expression.OpExpression) -> expression.Expression:
30+
dividend = expr.children[0]
31+
divisor = expr.children[1]
32+
using_floats = (dividend.output_type == dtypes.FLOAT_DTYPE) or (
33+
divisor.output_type == dtypes.FLOAT_DTYPE
34+
)
35+
inf_or_zero = (
36+
expression.const(float("INF")) if using_floats else expression.const(0)
37+
)
38+
zero_result = ops.mul_op.as_expr(inf_or_zero, dividend)
39+
divisor_is_zero = ops.eq_op.as_expr(divisor, expression.const(0))
40+
return ops.where_op.as_expr(zero_result, divisor_is_zero, expr)
41+
42+
43+
POLARS_LOWERING_RULES = (LowerFloorDivRule(),)
44+
45+
46+
def lower_ops_to_polars(root: bigframe_node.BigFrameNode) -> bigframe_node.BigFrameNode:
47+
return op_lowering.lower_ops(root, rules=POLARS_LOWERING_RULES)

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,14 @@ def compile_concat(
244244
uid_gen=self.uid_gen,
245245
)
246246

247+
@_compile_node.register
248+
def compile_explode(
249+
self, node: nodes.ExplodeNode, child: ir.SQLGlotIR
250+
) -> ir.SQLGlotIR:
251+
offsets_col = node.offsets_col.sql if (node.offsets_col is not None) else None
252+
columns = tuple(ref.id.sql for ref in node.column_ids)
253+
return child.explode(columns, offsets_col)
254+
247255

248256
def _replace_unsupported_ops(node: nodes.BigFrameNode):
249257
node = nodes.bottom_up(node, rewrite.rewrite_slice)

bigframes/core/compile/sqlglot/expressions/binary_compiler.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,22 @@
1414

1515
from __future__ import annotations
1616

17-
import typing
18-
1917
import sqlglot.expressions as sge
2018

2119
from bigframes import dtypes
2220
from bigframes import operations as ops
2321
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
2422
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2523

26-
BinaryOpCompiler = typing.Callable[[ops.BinaryOp, TypedExpr, TypedExpr], sge.Expression]
27-
28-
BINARY_OP_REIGSTRATION = OpRegistration[BinaryOpCompiler]()
24+
BINARY_OP_REGISTRATION = OpRegistration()
2925

3026

3127
def compile(op: ops.BinaryOp, left: TypedExpr, right: TypedExpr) -> sge.Expression:
32-
return BINARY_OP_REIGSTRATION[op](op, left, right)
28+
return BINARY_OP_REGISTRATION[op](op, left, right)
3329

3430

3531
# TODO: add parenthesize for operators
36-
@BINARY_OP_REIGSTRATION.register(ops.add_op)
32+
@BINARY_OP_REGISTRATION.register(ops.add_op)
3733
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
3834
if left.dtype == dtypes.STRING_DTYPE and right.dtype == dtypes.STRING_DTYPE:
3935
# String addition
@@ -43,7 +39,6 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
4339
return sge.Add(this=left.expr, expression=right.expr)
4440

4541

46-
@BINARY_OP_REIGSTRATION.register(ops.ge_op)
47-
def compile_ge(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
48-
42+
@BINARY_OP_REGISTRATION.register(ops.ge_op)
43+
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
4944
return sge.GTE(this=left.expr, expression=right.expr)

bigframes/core/compile/sqlglot/expressions/nary_compiler.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,14 @@
1414

1515
from __future__ import annotations
1616

17-
import typing
18-
1917
import sqlglot.expressions as sge
2018

2119
from bigframes import operations as ops
2220
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
2321
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2422

25-
# No simpler way to specify that the compilation function expects varargs.
26-
NaryOpCompiler = typing.Callable[..., sge.Expression]
27-
28-
NARY_OP_REIGSTRATION = OpRegistration[NaryOpCompiler]()
23+
NARY_OP_REGISTRATION = OpRegistration()
2924

3025

3126
def compile(op: ops.NaryOp, *args: TypedExpr) -> sge.Expression:
32-
return NARY_OP_REIGSTRATION[op](op, *args)
27+
return NARY_OP_REGISTRATION[op](op, *args)

bigframes/core/compile/sqlglot/expressions/op_registration.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,30 +15,40 @@
1515
from __future__ import annotations
1616

1717
import typing
18-
from typing import Generic, TypeVar
18+
19+
from sqlglot import expressions as sge
1920

2021
from bigframes import operations as ops
2122

22-
T = TypeVar("T")
23+
# We should've been more specific about input types. Unfortunately,
24+
# MyPy doesn't support more rigorous checks.
25+
CompilationFunc = typing.Callable[..., sge.Expression]
2326

2427

25-
class OpRegistration(Generic[T]):
26-
_registered_ops: dict[str, T] = {}
28+
class OpRegistration:
29+
def __init__(self) -> None:
30+
self._registered_ops: dict[str, CompilationFunc] = {}
2731

2832
def register(
2933
self, op: ops.ScalarOp | type[ops.ScalarOp]
30-
) -> typing.Callable[[T], T]:
31-
key = typing.cast(str, op.name)
32-
33-
def decorator(item: T):
34+
) -> typing.Callable[[CompilationFunc], CompilationFunc]:
35+
def decorator(item: CompilationFunc):
36+
def arg_checker(*args, **kwargs):
37+
if not isinstance(args[0], ops.ScalarOp):
38+
raise ValueError(
39+
f"The first parameter must be an operator. Got {type(args[0])}"
40+
)
41+
return item(*args, **kwargs)
42+
43+
key = typing.cast(str, op.name)
3444
if key in self._registered_ops:
3545
raise ValueError(f"{key} is already registered")
3646
self._registered_ops[key] = item
37-
return item
47+
return arg_checker
3848

3949
return decorator
4050

41-
def __getitem__(self, key: str | ops.ScalarOp) -> T:
51+
def __getitem__(self, key: str | ops.ScalarOp) -> CompilationFunc:
4252
if isinstance(key, ops.ScalarOp):
4353
return self._registered_ops[key.name]
4454
return self._registered_ops[key]

bigframes/core/compile/sqlglot/expressions/ternary_compiler.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,16 @@
1414

1515
from __future__ import annotations
1616

17-
import typing
18-
1917
import sqlglot.expressions as sge
2018

2119
from bigframes import operations as ops
2220
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
2321
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2422

25-
TernaryOpCompiler = typing.Callable[
26-
[ops.TernaryOp, TypedExpr, TypedExpr, TypedExpr], sge.Expression
27-
]
28-
29-
TERNATRY_OP_REIGSTRATION = OpRegistration[TernaryOpCompiler]()
23+
TERNATRY_OP_REGISTRATION = OpRegistration()
3024

3125

3226
def compile(
3327
op: ops.TernaryOp, expr1: TypedExpr, expr2: TypedExpr, expr3: TypedExpr
3428
) -> sge.Expression:
35-
return TERNATRY_OP_REIGSTRATION[op](op, expr1, expr2, expr3)
29+
return TERNATRY_OP_REGISTRATION[op](op, expr1, expr2, expr3)

bigframes/core/compile/sqlglot/expressions/unary_compiler.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,57 @@
1616

1717
import typing
1818

19+
import sqlglot
1920
import sqlglot.expressions as sge
2021

2122
from bigframes import operations as ops
2223
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
2324
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2425

25-
UnaryOpCompiler = typing.Callable[[ops.UnaryOp, TypedExpr], sge.Expression]
26-
27-
UNARY_OP_REIGSTRATION = OpRegistration[UnaryOpCompiler]()
26+
UNARY_OP_REGISTRATION = OpRegistration()
2827

2928

3029
def compile(op: ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
31-
return UNARY_OP_REIGSTRATION[op](op, expr)
30+
return UNARY_OP_REGISTRATION[op](op, expr)
31+
32+
33+
@UNARY_OP_REGISTRATION.register(ops.ArrayToStringOp)
34+
def _(op: ops.ArrayToStringOp, expr: TypedExpr) -> sge.Expression:
35+
return sge.ArrayToString(this=expr.expr, expression=f"'{op.delimiter}'")
36+
37+
38+
@UNARY_OP_REGISTRATION.register(ops.ArrayIndexOp)
39+
def _(op: ops.ArrayIndexOp, expr: TypedExpr) -> sge.Expression:
40+
return sge.Bracket(
41+
this=expr.expr,
42+
expressions=[sge.Literal.number(op.index)],
43+
safe=True,
44+
offset=False,
45+
)
46+
47+
48+
@UNARY_OP_REGISTRATION.register(ops.ArraySliceOp)
49+
def _(op: ops.ArraySliceOp, expr: TypedExpr) -> sge.Expression:
50+
slice_idx = sqlglot.to_identifier("slice_idx")
51+
52+
conditions: typing.List[sge.Predicate] = [slice_idx >= op.start]
53+
54+
if op.stop is not None:
55+
conditions.append(slice_idx < op.stop)
56+
57+
# local name for each element in the array
58+
el = sqlglot.to_identifier("el")
59+
60+
selected_elements = (
61+
sge.select(el)
62+
.from_(
63+
sge.Unnest(
64+
expressions=[expr.expr],
65+
alias=sge.TableAlias(columns=[el]),
66+
offset=slice_idx,
67+
)
68+
)
69+
.where(*conditions)
70+
)
71+
72+
return sge.array(selected_elements)

0 commit comments

Comments
 (0)