Skip to content

Commit d7ddb9b

Browse files
Merge remote-tracking branch 'github/main' into crosstab
2 parents c2e4450 + 0d7d7e4 commit d7ddb9b

File tree

17 files changed

+219
-68
lines changed

17 files changed

+219
-68
lines changed

.github/workflows/docs.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ on:
22
pull_request:
33
branches:
44
- main
5+
push:
6+
branches:
7+
- main
58
name: docs
69
jobs:
710
docs:

.github/workflows/lint.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ on:
22
pull_request:
33
branches:
44
- main
5+
push:
6+
branches:
7+
- main
58
name: lint
69
jobs:
710
lint:

.github/workflows/mypy.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ on:
22
pull_request:
33
branches:
44
- main
5+
push:
6+
branches:
7+
- main
58
name: mypy
69
jobs:
710
mypy:

.github/workflows/unittest.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ on:
22
pull_request:
33
branches:
44
- main
5+
push:
6+
branches:
7+
- main
58
name: unittest
69
jobs:
710
unit:

bigframes/core/compile/sqlglot/expressions/geo_ops.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2222

2323
register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
24+
register_binary_op = scalar_compiler.scalar_op_compiler.register_binary_op
2425

2526

2627
@register_unary_op(ops.geo_area_op)
@@ -108,3 +109,8 @@ def _(expr: TypedExpr) -> sge.Expression:
108109
@register_unary_op(ops.geo_y_op)
109110
def _(expr: TypedExpr) -> sge.Expression:
110111
return sge.func("SAFE.ST_Y", expr.expr)
112+
113+
114+
@register_binary_op(ops.geo_st_difference_op)
115+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
116+
return sge.func("ST_DIFFERENCE", left.expr, right.expr)

bigframes/core/compile/sqlglot/expressions/numeric_ops.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,13 @@ def _(expr: TypedExpr) -> sge.Expression:
7777
return sge.func("ASINH", expr.expr)
7878

7979

80+
@register_binary_op(ops.arctan2_op)
81+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
82+
left_expr = _coerce_bool_to_int(left)
83+
right_expr = _coerce_bool_to_int(right)
84+
return sge.func("ATAN2", left_expr, right_expr)
85+
86+
8087
@register_unary_op(ops.arctan_op)
8188
def _(expr: TypedExpr) -> sge.Expression:
8289
return sge.func("ATAN", expr.expr)
@@ -118,6 +125,18 @@ def _(expr: TypedExpr) -> sge.Expression:
118125
)
119126

120127

128+
@register_binary_op(ops.cosine_distance_op)
129+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
130+
return sge.Anonymous(
131+
this="ML.DISTANCE",
132+
expressions=[
133+
left.expr,
134+
right.expr,
135+
sge.Literal.string("COSINE"),
136+
],
137+
)
138+
139+
121140
@register_unary_op(ops.exp_op)
122141
def _(expr: TypedExpr) -> sge.Expression:
123142
return sge.Case(

bigframes/core/local_data.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,16 @@ def _(
253253
value_generator = iter_array(
254254
array.flatten(), bigframes.dtypes.get_array_inner_type(dtype)
255255
)
256-
for (start, end) in _pairwise(array.offsets):
257-
arr_size = end.as_py() - start.as_py()
258-
yield list(itertools.islice(value_generator, arr_size))
256+
offset_generator = iter_array(array.offsets, bigframes.dtypes.INT_DTYPE)
257+
258+
start_offset = None
259+
end_offset = None
260+
for offset in offset_generator:
261+
start_offset = end_offset
262+
end_offset = offset
263+
if start_offset is not None:
264+
arr_size = end_offset - start_offset
265+
yield list(itertools.islice(value_generator, arr_size))
259266

260267
@iter_array.register
261268
def _(
@@ -267,8 +274,15 @@ def _(
267274
sub_generators[field_name] = iter_array(array.field(field_name), dtype)
268275

269276
keys = list(sub_generators.keys())
270-
for row_values in zip(*sub_generators.values()):
271-
yield {key: value for key, value in zip(keys, row_values)}
277+
is_null_generator = iter_array(array.is_null(), bigframes.dtypes.BOOL_DTYPE)
278+
279+
for values in zip(is_null_generator, *sub_generators.values()):
280+
is_row_null = values[0]
281+
row_values = values[1:]
282+
if not is_row_null:
283+
yield {key: value for key, value in zip(keys, row_values)}
284+
else:
285+
yield None
272286

273287
for batch in table.to_batches():
274288
sub_generators: dict[str, Generator[Any, None, None]] = {}
@@ -491,16 +505,3 @@ def _schema_durations_to_ints(schema: pa.Schema) -> pa.Schema:
491505
return pa.schema(
492506
pa.field(field.name, _durations_to_ints(field.type)) for field in schema
493507
)
494-
495-
496-
def _pairwise(iterable):
497-
do_yield = False
498-
a = None
499-
b = None
500-
for item in iterable:
501-
a = b
502-
b = item
503-
if do_yield:
504-
yield (a, b)
505-
else:
506-
do_yield = True

owlbot.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,9 @@
5858
".kokoro/build.sh",
5959
".kokoro/continuous/common.cfg",
6060
".kokoro/presubmit/common.cfg",
61-
# Temporary workaround to update docs job to use python 3.10
6261
".github/workflows/docs.yml",
62+
".github/workflows/lint.yml",
63+
".github/workflows/unittest.yml",
6364
],
6465
)
6566

tests/system/small/engines/test_read_local.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,9 @@ def test_engines_read_local_w_zero_row_source(
8888
assert_equivalence_execution(local_node, REFERENCE_ENGINE, engine)
8989

9090

91-
# TODO: Fix sqlglot impl
92-
@pytest.mark.parametrize("engine", ["polars", "bq", "pyarrow"], indirect=True)
91+
@pytest.mark.parametrize(
92+
"engine", ["polars", "bq", "pyarrow", "bq-sqlglot"], indirect=True
93+
)
9394
def test_engines_read_local_w_nested_source(
9495
fake_session: bigframes.Session,
9596
nested_data_source: local_data.ManagedArrowTable,
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`geography_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
ST_DIFFERENCE(`bfcol_0`, `bfcol_0`) AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `geography_col`
13+
FROM `bfcte_1`

0 commit comments

Comments
 (0)