Skip to content

Commit 9b4b18a

Browse files
committed
Merge branch 'main' into shuowei-anywidget-html-repr
2 parents 05a9245 + ed79146 commit 9b4b18a

File tree

22 files changed

+371
-12
lines changed

22 files changed

+371
-12
lines changed

bigframes/bigquery/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
json_extract,
4848
json_extract_array,
4949
json_extract_string_array,
50+
json_keys,
5051
json_query,
5152
json_query_array,
5253
json_set,
@@ -138,6 +139,7 @@
138139
"json_extract",
139140
"json_extract_array",
140141
"json_extract_string_array",
142+
"json_keys",
141143
"json_query",
142144
"json_query_array",
143145
"json_set",

bigframes/bigquery/_operations/json.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,35 @@ def json_value_array(
421421
return input._apply_unary_op(ops.JSONValueArray(json_path=json_path))
422422

423423

424+
def json_keys(
425+
input: series.Series,
426+
max_depth: Optional[int] = None,
427+
) -> series.Series:
428+
"""Returns all unique keys of a JSON object, including nested keys, as an ARRAY of STRINGs.
429+
430+
**Examples:**
431+
432+
>>> import bigframes.pandas as bpd
433+
>>> import bigframes.bigquery as bbq
434+
435+
>>> s = bpd.Series(['{"b": {"c": 2}, "a": 1}'], dtype="json")
436+
>>> bbq.json_keys(s)
437+
0 ['a' 'b' 'b.c']
438+
dtype: list<item: string>[pyarrow]
439+
440+
Args:
441+
input (bigframes.series.Series):
442+
The Series containing JSON data.
443+
max_depth (int, optional):
444+
Specifies the maximum depth of nested fields to search for keys. If not
445+
provided, keys at all levels are searched.
446+
447+
Returns:
448+
bigframes.series.Series: A new Series containing arrays of keys from the input JSON.
449+
"""
450+
return input._apply_unary_op(ops.JSONKeys(max_depth=max_depth))
451+
452+
424453
def to_json(
425454
input: series.Series,
426455
) -> series.Series:

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,6 +1234,11 @@ def json_value_array_op_impl(x: ibis_types.Value, op: ops.JSONValueArray):
12341234
return json_value_array(json_obj=x, json_path=op.json_path)
12351235

12361236

1237+
@scalar_op_compiler.register_unary_op(ops.JSONKeys, pass_op=True)
1238+
def json_keys_op_impl(x: ibis_types.Value, op: ops.JSONKeys):
1239+
return json_keys(x, op.max_depth)
1240+
1241+
12371242
# Blob Ops
12381243
@scalar_op_compiler.register_unary_op(ops.obj_fetch_metadata_op)
12391244
def obj_fetch_metadata_op_impl(obj_ref: ibis_types.Value):
@@ -2059,6 +2064,14 @@ def to_json_string(value) -> ibis_dtypes.String: # type: ignore[empty-body]
20592064
"""Convert value to JSON-formatted string."""
20602065

20612066

2067+
@ibis_udf.scalar.builtin(name="json_keys")
2068+
def json_keys( # type: ignore[empty-body]
2069+
json_obj: ibis_dtypes.JSON,
2070+
max_depth: ibis_dtypes.Int64,
2071+
) -> ibis_dtypes.Array[ibis_dtypes.String]:
2072+
"""Extracts unique JSON keys from a JSON expression."""
2073+
2074+
20622075
@ibis_udf.scalar.builtin(name="json_value")
20632076
def json_value( # type: ignore[empty-body]
20642077
json_obj: ibis_dtypes.JSON, json_path: ibis_dtypes.String

bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,5 +49,5 @@ def _(
4949
result: sge.Expression = sge.func("ROW_NUMBER")
5050
if window is None:
5151
# ROW_NUMBER always needs an OVER clause.
52-
return sge.Window(this=result)
53-
return apply_window_if_present(result, window, include_framing_clauses=False)
52+
return sge.Window(this=result) - 1
53+
return apply_window_if_present(result, window, include_framing_clauses=False) - 1

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,43 @@ def _(
400400
return apply_window_if_present(expr, window)
401401

402402

403+
@UNARY_OP_REGISTRATION.register(agg_ops.QcutOp)
404+
def _(
405+
op: agg_ops.QcutOp,
406+
column: typed_expr.TypedExpr,
407+
window: typing.Optional[window_spec.WindowSpec] = None,
408+
) -> sge.Expression:
409+
percent_ranks_order_by = sge.Ordered(this=column.expr, desc=False)
410+
percent_ranks = apply_window_if_present(
411+
sge.func("PERCENT_RANK"),
412+
window,
413+
include_framing_clauses=False,
414+
order_by_override=[percent_ranks_order_by],
415+
)
416+
if isinstance(op.quantiles, int):
417+
scaled_rank = percent_ranks * sge.convert(op.quantiles)
418+
# Calculate the 0-based bucket index.
419+
bucket_index = sge.func("CEIL", scaled_rank) - sge.convert(1)
420+
safe_bucket_index = sge.func("GREATEST", bucket_index, 0)
421+
422+
return sge.If(
423+
this=sge.Is(this=column.expr, expression=sge.Null()),
424+
true=sge.Null(),
425+
false=sge.Cast(this=safe_bucket_index, to="INT64"),
426+
)
427+
else:
428+
case = sge.Case()
429+
first_quantile = sge.convert(op.quantiles[0])
430+
case = case.when(
431+
sge.LT(this=percent_ranks, expression=first_quantile), sge.Null()
432+
)
433+
for bucket_n in range(len(op.quantiles) - 1):
434+
quantile = sge.convert(op.quantiles[bucket_n + 1])
435+
bucket = sge.convert(bucket_n)
436+
case = case.when(sge.LTE(this=percent_ranks, expression=quantile), bucket)
437+
return case.else_(sge.Null())
438+
439+
403440
@UNARY_OP_REGISTRATION.register(agg_ops.QuantileOp)
404441
def _(
405442
op: agg_ops.QuantileOp,

bigframes/core/compile/sqlglot/aggregations/windows.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def apply_window_if_present(
2626
value: sge.Expression,
2727
window: typing.Optional[window_spec.WindowSpec] = None,
2828
include_framing_clauses: bool = True,
29+
order_by_override: typing.Optional[typing.List[sge.Ordered]] = None,
2930
) -> sge.Expression:
3031
if window is None:
3132
return value
@@ -44,7 +45,11 @@ def apply_window_if_present(
4445
else:
4546
order_by = get_window_order_by(window.ordering)
4647

47-
order = sge.Order(expressions=order_by) if order_by else None
48+
order = None
49+
if order_by_override is not None and len(order_by_override) > 0:
50+
order = sge.Order(expressions=order_by_override)
51+
elif order_by:
52+
order = sge.Order(expressions=order_by)
4853

4954
group_by = (
5055
[

bigframes/core/compile/sqlglot/expressions/json_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ def _(expr: TypedExpr, op: ops.JSONExtractStringArray) -> sge.Expression:
3939
return sge.func("JSON_EXTRACT_STRING_ARRAY", expr.expr, sge.convert(op.json_path))
4040

4141

42+
@register_unary_op(ops.JSONKeys, pass_op=True)
43+
def _(expr: TypedExpr, op: ops.JSONKeys) -> sge.Expression:
44+
return sge.func("JSON_KEYS", expr.expr, sge.convert(op.max_depth))
45+
46+
4247
@register_unary_op(ops.JSONQuery, pass_op=True)
4348
def _(expr: TypedExpr, op: ops.JSONQuery) -> sge.Expression:
4449
return sge.func("JSON_QUERY", expr.expr, sge.convert(op.json_path))

bigframes/core/compile/sqlglot/expressions/numeric_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,14 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
443443
)
444444

445445

446+
@register_binary_op(ops.unsafe_pow_op)
447+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
448+
"""For internal use only - where domain and overflow checks are not needed."""
449+
left_expr = _coerce_bool_to_int(left)
450+
right_expr = _coerce_bool_to_int(right)
451+
return sge.Pow(this=left_expr, expression=right_expr)
452+
453+
446454
@register_unary_op(numeric_ops.isnan_op)
447455
def isnan(arg: TypedExpr) -> sge.Expression:
448456
return sge.IsNan(this=arg.expr)

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,12 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:
637637

638638

639639
def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
640-
sqlglot_type = sgt.from_bigframes_dtype(dtype)
640+
sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None
641+
if sqlglot_type is None:
642+
if value is not None:
643+
raise ValueError("Cannot infer SQLGlot type from None dtype.")
644+
return sge.Null()
645+
641646
if value is None:
642647
return _cast(sge.Null(), sqlglot_type)
643648
elif dtype == dtypes.BYTES_DTYPE:

bigframes/operations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@
128128
JSONExtract,
129129
JSONExtractArray,
130130
JSONExtractStringArray,
131+
JSONKeys,
131132
JSONQuery,
132133
JSONQueryArray,
133134
JSONSet,
@@ -381,6 +382,7 @@
381382
"JSONExtract",
382383
"JSONExtractArray",
383384
"JSONExtractStringArray",
385+
"JSONKeys",
384386
"JSONQuery",
385387
"JSONQueryArray",
386388
"JSONSet",

0 commit comments

Comments
 (0)