Skip to content

Commit e7db6ed

Browse files
committed
Merge remote-tracking branch 'origin/main' into search-function-18023861678525870072
2 parents 0bd920b + f9b145e commit e7db6ed

File tree

22 files changed

+480
-204
lines changed

22 files changed

+480
-204
lines changed

bigframes/core/block_transforms.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -625,21 +625,7 @@ def skew(
625625
# counts, moment3 for each column
626626
aggregations = []
627627
for col in original_columns:
628-
delta3_expr = _mean_delta_to_power(3, col)
629-
count_agg = agg_expressions.UnaryAggregation(
630-
agg_ops.count_op,
631-
ex.deref(col),
632-
)
633-
moment3_agg = agg_expressions.UnaryAggregation(
634-
agg_ops.mean_op,
635-
delta3_expr,
636-
)
637-
variance_agg = agg_expressions.UnaryAggregation(
638-
agg_ops.PopVarOp(),
639-
ex.deref(col),
640-
)
641-
skew_expr = _skew_from_moments_and_count(count_agg, moment3_agg, variance_agg)
642-
aggregations.append(skew_expr)
628+
aggregations.append(skew_expr(ex.deref(col)))
643629

644630
block = block.aggregate(
645631
aggregations, grouping_column_ids, column_labels=column_labels
@@ -663,16 +649,7 @@ def kurt(
663649
# counts, moment4 for each column
664650
kurt_exprs = []
665651
for col in original_columns:
666-
delta_4_expr = _mean_delta_to_power(4, col)
667-
count_agg = agg_expressions.UnaryAggregation(agg_ops.count_op, ex.deref(col))
668-
moment4_agg = agg_expressions.UnaryAggregation(agg_ops.mean_op, delta_4_expr)
669-
variance_agg = agg_expressions.UnaryAggregation(
670-
agg_ops.PopVarOp(), ex.deref(col)
671-
)
672-
673-
# Corresponds to order of aggregations in preceding loop
674-
kurt_expr = _kurt_from_moments_and_count(count_agg, moment4_agg, variance_agg)
675-
kurt_exprs.append(kurt_expr)
652+
kurt_exprs.append(kurt_expr(ex.deref(col)))
676653

677654
block = block.aggregate(
678655
kurt_exprs, grouping_column_ids, column_labels=column_labels
@@ -686,13 +663,38 @@ def kurt(
686663
return block
687664

688665

666+
def skew_expr(expr: ex.Expression) -> ex.Expression:
667+
delta3_expr = _mean_delta_to_power(3, expr)
668+
count_agg = agg_expressions.UnaryAggregation(
669+
agg_ops.count_op,
670+
expr,
671+
)
672+
moment3_agg = agg_expressions.UnaryAggregation(
673+
agg_ops.mean_op,
674+
delta3_expr,
675+
)
676+
variance_agg = agg_expressions.UnaryAggregation(
677+
agg_ops.PopVarOp(),
678+
expr,
679+
)
680+
return _skew_from_moments_and_count(count_agg, moment3_agg, variance_agg)
681+
682+
683+
def kurt_expr(expr: ex.Expression) -> ex.Expression:
684+
delta_4_expr = _mean_delta_to_power(4, expr)
685+
count_agg = agg_expressions.UnaryAggregation(agg_ops.count_op, expr)
686+
moment4_agg = agg_expressions.UnaryAggregation(agg_ops.mean_op, delta_4_expr)
687+
variance_agg = agg_expressions.UnaryAggregation(agg_ops.PopVarOp(), expr)
688+
return _kurt_from_moments_and_count(count_agg, moment4_agg, variance_agg)
689+
690+
689691
def _mean_delta_to_power(
690692
n_power: int,
691-
val_id: str,
693+
col_expr: ex.Expression,
692694
) -> ex.Expression:
693695
"""Calculate (x-mean(x))^n. Useful for calculating moment statistics such as skew and kurtosis."""
694-
mean_expr = agg_expressions.UnaryAggregation(agg_ops.mean_op, ex.deref(val_id))
695-
delta = ops.sub_op.as_expr(val_id, mean_expr)
696+
mean_expr = agg_expressions.UnaryAggregation(agg_ops.mean_op, col_expr)
697+
delta = ops.sub_op.as_expr(col_expr, mean_expr)
696698
return ops.pow_op.as_expr(delta, ex.const(n_power))
697699

698700

bigframes/core/bq_data.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,21 @@ def from_table(table: bq.Table, columns: Sequence[str] = ()) -> GbqTable:
6464
else tuple(table.clustering_fields),
6565
)
6666

67+
@staticmethod
68+
def from_ref_and_schema(
69+
table_ref: bq.TableReference,
70+
schema: Sequence[bq.SchemaField],
71+
cluster_cols: Optional[Sequence[str]] = None,
72+
) -> GbqTable:
73+
return GbqTable(
74+
project_id=table_ref.project,
75+
dataset_id=table_ref.dataset_id,
76+
table_id=table_ref.table_id,
77+
physical_schema=tuple(schema),
78+
is_physically_stored=True,
79+
cluster_cols=tuple(cluster_cols) if cluster_cols else None,
80+
)
81+
6782
def get_table_ref(self) -> bq.TableReference:
6883
return bq.TableReference(
6984
bq.DatasetReference(self.project_id, self.dataset_id), self.table_id

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotI
378378
window_op = sge.Case(ifs=when_expressions, default=window_op)
379379

380380
# TODO: check if we can directly window the expression.
381-
result = child.window(
381+
result = result.window(
382382
window_op=window_op,
383383
output_column_id=cdef.id.sql,
384384
)

bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def _construct_prompt(
9393
for elem in prompt_context:
9494
if elem is None:
9595
prompt.append(exprs[column_ref_idx].expr)
96+
column_ref_idx += 1
9697
else:
9798
prompt.append(sge.Literal.string(elem))
9899

bigframes/core/compile/sqlglot/expressions/json_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ def _(expr: TypedExpr) -> sge.Expression:
6969
return sge.func("PARSE_JSON", expr.expr)
7070

7171

72+
@register_unary_op(ops.ToJSON)
73+
def _(expr: TypedExpr) -> sge.Expression:
74+
return sge.func("TO_JSON", expr.expr)
75+
76+
7277
@register_unary_op(ops.ToJSONString)
7378
def _(expr: TypedExpr) -> sge.Expression:
7479
return sge.func("TO_JSON_STRING", expr.expr)

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,15 @@
2121

2222
from google.cloud import bigquery
2323
import numpy as np
24+
import pandas as pd
2425
import pyarrow as pa
2526
import sqlglot as sg
2627
import sqlglot.dialects.bigquery
2728
import sqlglot.expressions as sge
2829

2930
from bigframes import dtypes
3031
from bigframes.core import guid, local_data, schema, utils
31-
from bigframes.core.compile.sqlglot.expressions import typed_expr
32+
from bigframes.core.compile.sqlglot.expressions import constants, typed_expr
3233
import bigframes.core.compile.sqlglot.sqlglot_types as sgt
3334

3435
# shapely.wkt.dumps was moved to shapely.io.to_wkt in 2.0.
@@ -639,12 +640,30 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:
639640
def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
640641
sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None
641642
if sqlglot_type is None:
642-
if value is not None:
643-
raise ValueError("Cannot infer SQLGlot type from None dtype.")
643+
if not pd.isna(value):
644+
raise ValueError(f"Cannot infer SQLGlot type from None dtype: {value}")
644645
return sge.Null()
645646

646647
if value is None:
647648
return _cast(sge.Null(), sqlglot_type)
649+
if dtypes.is_struct_like(dtype):
650+
items = [
651+
_literal(value=value[field_name], dtype=field_dtype).as_(
652+
field_name, quoted=True
653+
)
654+
for field_name, field_dtype in dtypes.get_struct_fields(dtype).items()
655+
]
656+
return sge.Struct.from_arg_list(items)
657+
elif dtypes.is_array_like(dtype):
658+
value_type = dtypes.get_array_inner_type(dtype)
659+
values = sge.Array(
660+
expressions=[_literal(value=v, dtype=value_type) for v in value]
661+
)
662+
return values if len(value) > 0 else _cast(values, sqlglot_type)
663+
elif pd.isna(value):
664+
return _cast(sge.Null(), sqlglot_type)
665+
elif dtype == dtypes.JSON_DTYPE:
666+
return sge.ParseJSON(this=sge.convert(str(value)))
648667
elif dtype == dtypes.BYTES_DTYPE:
649668
return _cast(str(value), sqlglot_type)
650669
elif dtypes.is_time_like(dtype):
@@ -658,24 +677,12 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
658677
elif dtypes.is_geo_like(dtype):
659678
wkt = value if isinstance(value, str) else to_wkt(value)
660679
return sge.func("ST_GEOGFROMTEXT", sge.convert(wkt))
661-
elif dtype == dtypes.JSON_DTYPE:
662-
return sge.ParseJSON(this=sge.convert(str(value)))
663680
elif dtype == dtypes.TIMEDELTA_DTYPE:
664681
return sge.convert(utils.timedelta_to_micros(value))
665-
elif dtypes.is_struct_like(dtype):
666-
items = [
667-
_literal(value=value[field_name], dtype=field_dtype).as_(
668-
field_name, quoted=True
669-
)
670-
for field_name, field_dtype in dtypes.get_struct_fields(dtype).items()
671-
]
672-
return sge.Struct.from_arg_list(items)
673-
elif dtypes.is_array_like(dtype):
674-
value_type = dtypes.get_array_inner_type(dtype)
675-
values = sge.Array(
676-
expressions=[_literal(value=v, dtype=value_type) for v in value]
677-
)
678-
return values if len(value) > 0 else _cast(values, sqlglot_type)
682+
elif dtype == dtypes.FLOAT_DTYPE:
683+
if np.isinf(value):
684+
return constants._INF if value > 0 else constants._NEG_INF
685+
return sge.convert(value)
679686
else:
680687
if isinstance(value, np.generic):
681688
value = value.item()

bigframes/core/expression.py

Lines changed: 1 addition & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@
1515
from __future__ import annotations
1616

1717
import abc
18-
import collections
1918
import dataclasses
2019
import functools
2120
import itertools
2221
import typing
23-
from typing import Callable, Dict, Generator, Mapping, Tuple, TypeVar, Union
22+
from typing import Callable, Generator, Mapping, TypeVar, Union
2423

2524
import pandas as pd
2625

@@ -162,57 +161,6 @@ def walk(self) -> Generator[Expression, None, None]:
162161
for child in self.children:
163162
yield from child.children
164163

165-
def unique_nodes(
166-
self: Expression,
167-
) -> Generator[Expression, None, None]:
168-
"""Walks the tree for unique nodes"""
169-
seen = set()
170-
stack: list[Expression] = [self]
171-
while stack:
172-
item = stack.pop()
173-
if item not in seen:
174-
yield item
175-
seen.add(item)
176-
stack.extend(item.children)
177-
178-
def iter_nodes_topo(
179-
self: Expression,
180-
) -> Generator[Expression, None, None]:
181-
"""Returns nodes in reverse topological order, using Kahn's algorithm."""
182-
child_to_parents: Dict[Expression, list[Expression]] = collections.defaultdict(
183-
list
184-
)
185-
out_degree: Dict[Expression, int] = collections.defaultdict(int)
186-
187-
queue: collections.deque["Expression"] = collections.deque()
188-
for node in list(self.unique_nodes()):
189-
num_children = len(node.children)
190-
out_degree[node] = num_children
191-
if num_children == 0:
192-
queue.append(node)
193-
for child in node.children:
194-
child_to_parents[child].append(node)
195-
196-
while queue:
197-
item = queue.popleft()
198-
yield item
199-
parents = child_to_parents.get(item, [])
200-
for parent in parents:
201-
out_degree[parent] -= 1
202-
if out_degree[parent] == 0:
203-
queue.append(parent)
204-
205-
def reduce_up(self, reduction: Callable[[Expression, Tuple[T, ...]], T]) -> T:
206-
"""Apply a bottom-up reduction to the tree."""
207-
results: dict[Expression, T] = {}
208-
for node in list(self.iter_nodes_topo()):
209-
# child nodes have already been transformed
210-
child_results = tuple(results[child] for child in node.children)
211-
result = reduction(node, child_results)
212-
results[node] = result
213-
214-
return results[self]
215-
216164

217165
@dataclasses.dataclass(frozen=True)
218166
class ScalarConstantExpression(Expression):

0 commit comments

Comments
 (0)