Skip to content

Commit 91c32a9

Browse files
Merge branch 'main' into more_local_numerics
2 parents 1ffd2b4 + feb3ff4 commit 91c32a9

File tree

19 files changed

+636
-52
lines changed

19 files changed

+636
-52
lines changed

GEMINI.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ We use `nox` to instrument our tests.
4545
nox -r -s lint
4646
```
4747

48+
- When writing tests, use the idiomatic "pytest" style.
49+
4850
## Documentation
4951

5052
If a method or property is implementing the same interface as a third-party

bigframes/core/compile/sqlglot/expressions/binary_compiler.py

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import bigframes_vendored.constants as constants
1718
import sqlglot.expressions as sge
1819

1920
from bigframes import dtypes
@@ -35,8 +36,83 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
3536
# String addition
3637
return sge.Concat(expressions=[left.expr, right.expr])
3738

38-
# Numerical addition
39-
return sge.Add(this=left.expr, expression=right.expr)
39+
if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype):
40+
left_expr = left.expr
41+
if left.dtype == dtypes.BOOL_DTYPE:
42+
left_expr = sge.Cast(this=left_expr, to="INT64")
43+
right_expr = right.expr
44+
if right.dtype == dtypes.BOOL_DTYPE:
45+
right_expr = sge.Cast(this=right_expr, to="INT64")
46+
return sge.Add(this=left_expr, expression=right_expr)
47+
48+
if (
49+
dtypes.is_time_or_date_like(left.dtype)
50+
and right.dtype == dtypes.TIMEDELTA_DTYPE
51+
):
52+
left_expr = left.expr
53+
if left.dtype == dtypes.DATE_DTYPE:
54+
left_expr = sge.Cast(this=left_expr, to="DATETIME")
55+
return sge.TimestampAdd(
56+
this=left_expr, expression=right.expr, unit=sge.Var(this="MICROSECOND")
57+
)
58+
if (
59+
dtypes.is_time_or_date_like(right.dtype)
60+
and left.dtype == dtypes.TIMEDELTA_DTYPE
61+
):
62+
right_expr = right.expr
63+
if right.dtype == dtypes.DATE_DTYPE:
64+
right_expr = sge.Cast(this=right_expr, to="DATETIME")
65+
return sge.TimestampAdd(
66+
this=right_expr, expression=left.expr, unit=sge.Var(this="MICROSECOND")
67+
)
68+
if left.dtype == dtypes.TIMEDELTA_DTYPE and right.dtype == dtypes.TIMEDELTA_DTYPE:
69+
return sge.Add(this=left.expr, expression=right.expr)
70+
71+
raise TypeError(
72+
f"Cannot add type {left.dtype} and {right.dtype}. {constants.FEEDBACK_LINK}"
73+
)
74+
75+
76+
def _bool_to_int64(operand: TypedExpr) -> sge.Expression:
    """Return the operand's SQL expression, casting BOOL to INT64.

    SQL arithmetic operators do not accept BOOL operands directly, so boolean
    inputs are widened to INT64; every other dtype passes through unchanged.
    """
    if operand.dtype == dtypes.BOOL_DTYPE:
        return sge.Cast(this=operand.expr, to="INT64")
    return operand.expr


def _date_to_datetime(operand: TypedExpr) -> sge.Expression:
    """Return the operand's SQL expression, casting DATE to DATETIME.

    MICROSECOND-granularity timestamp functions need a DATETIME (not a bare
    DATE) input, so DATE values are widened; other temporal dtypes pass
    through unchanged.
    """
    if operand.dtype == dtypes.DATE_DTYPE:
        return sge.Cast(this=operand.expr, to="DATETIME")
    return operand.expr


@BINARY_OP_REGISTRATION.register(ops.sub_op)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile BigFrames subtraction (`sub_op`) to a sqlglot expression.

    Supported operand combinations:
      * numeric - numeric        -> SUB (BOOL operands cast to INT64)
      * temporal - timedelta     -> TIMESTAMP_SUB at MICROSECOND granularity
      * temporal - temporal      -> TIMESTAMP_DIFF in microseconds
      * timedelta - timedelta    -> plain SUB on the microsecond counts

    Raises:
        TypeError: for any other dtype combination.
    """
    # numeric - numeric
    if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype):
        return sge.Sub(
            this=_bool_to_int64(left), expression=_bool_to_int64(right)
        )

    # (date | datetime | time | timestamp) - timedelta
    if (
        dtypes.is_time_or_date_like(left.dtype)
        and right.dtype == dtypes.TIMEDELTA_DTYPE
    ):
        return sge.TimestampSub(
            this=_date_to_datetime(left),
            expression=right.expr,
            unit=sge.Var(this="MICROSECOND"),
        )

    # temporal - temporal -> elapsed microseconds
    if dtypes.is_time_or_date_like(left.dtype) and dtypes.is_time_or_date_like(
        right.dtype
    ):
        return sge.TimestampDiff(
            this=_date_to_datetime(left),
            expression=_date_to_datetime(right),
            unit=sge.Var(this="MICROSECOND"),
        )

    # timedelta - timedelta
    if left.dtype == dtypes.TIMEDELTA_DTYPE and right.dtype == dtypes.TIMEDELTA_DTYPE:
        return sge.Sub(this=left.expr, expression=right.expr)

    raise TypeError(
        f"Cannot subtract type {left.dtype} and {right.dtype}. {constants.FEEDBACK_LINK}"
    )
40116

41117

42118
@BINARY_OP_REGISTRATION.register(ops.ge_op)

bigframes/dtypes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,10 @@ def is_time_like(type_: ExpressionType) -> bool:
289289
return type_ in (DATETIME_DTYPE, TIMESTAMP_DTYPE, TIME_DTYPE)
290290

291291

292+
def is_time_or_date_like(type_: ExpressionType) -> bool:
    """Return True for any temporal dtype.

    Covers DATE plus the time-like dtypes (DATETIME, TIMESTAMP, TIME), i.e.
    the union of `is_time_like` and the DATE dtype.
    """
    return type_ == DATE_DTYPE or is_time_like(type_)
294+
295+
292296
def is_geo_like(type_: ExpressionType) -> bool:
293297
return type_ in (GEO_DTYPE,)
294298

bigframes/functions/_function_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def provision_bq_managed_function(
245245

246246
# Augment user package requirements with any internal package
247247
# requirements.
248-
packages = _utils._get_updated_package_requirements(
248+
packages = _utils.get_updated_package_requirements(
249249
packages, is_row_processor, capture_references, ignore_package_version=True
250250
)
251251
if packages:
@@ -258,7 +258,7 @@ def provision_bq_managed_function(
258258
bq_function_name = name
259259
if not bq_function_name:
260260
# Compute a unique hash representing the user code.
261-
function_hash = _utils._get_hash(func, packages)
261+
function_hash = _utils.get_hash(func, packages)
262262
bq_function_name = _utils.get_bigframes_function_name(
263263
function_hash,
264264
session_id,
@@ -539,12 +539,12 @@ def provision_bq_remote_function(
539539
"""Provision a BigQuery remote function."""
540540
# Augment user package requirements with any internal package
541541
# requirements
542-
package_requirements = _utils._get_updated_package_requirements(
542+
package_requirements = _utils.get_updated_package_requirements(
543543
package_requirements, is_row_processor
544544
)
545545

546546
# Compute a unique hash representing the user code
547-
function_hash = _utils._get_hash(def_, package_requirements)
547+
function_hash = _utils.get_hash(def_, package_requirements)
548548

549549
# If reuse of any existing function with the same name (indicated by the
550550
# same hash of its source code) is not intended, then attach a unique

bigframes/functions/_function_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,7 @@ def wrapper(func):
597597
bqrf_metadata = _utils.get_bigframes_metadata(
598598
python_output_type=py_sig.return_annotation
599599
)
600-
post_process_routine = _utils._build_unnest_post_routine(
600+
post_process_routine = _utils.build_unnest_post_routine(
601601
py_sig.return_annotation
602602
)
603603
py_sig = py_sig.replace(return_annotation=str)

bigframes/functions/_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def get_remote_function_locations(bq_location):
6363
return bq_location, cloud_function_region
6464

6565

66-
def _get_updated_package_requirements(
66+
def get_updated_package_requirements(
6767
package_requirements=None,
6868
is_row_processor=False,
6969
capture_references=True,
@@ -105,7 +105,7 @@ def _get_updated_package_requirements(
105105
return requirements
106106

107107

108-
def _clean_up_by_session_id(
108+
def clean_up_by_session_id(
109109
bqclient: bigquery.Client,
110110
gcfclient: functions_v2.FunctionServiceClient,
111111
dataset: bigquery.DatasetReference,
@@ -169,7 +169,7 @@ def _clean_up_by_session_id(
169169
pass
170170

171171

172-
def _get_hash(def_, package_requirements=None):
172+
def get_hash(def_, package_requirements=None):
173173
"Get hash (32 digits alphanumeric) of a function."
174174
# There is a known cell-id sensitivity of the cloudpickle serialization in
175175
# notebooks https://github.com/cloudpipe/cloudpickle/issues/538. Because of
@@ -279,7 +279,7 @@ def get_python_version(is_compat: bool = False) -> str:
279279
return f"python{major}{minor}" if is_compat else f"python-{major}.{minor}"
280280

281281

282-
def _build_unnest_post_routine(py_list_type: type[list]):
282+
def build_unnest_post_routine(py_list_type: type[list]):
283283
sdk_type = function_typing.sdk_array_output_type_from_python_type(py_list_type)
284284
assert sdk_type.array_element_type is not None
285285
inner_sdk_type = sdk_type.array_element_type

bigframes/functions/function.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def _try_import_routine(
9090
return BigqueryCallableRoutine(
9191
udf_def,
9292
session,
93-
post_routine=_utils._build_unnest_post_routine(override_type),
93+
post_routine=_utils.build_unnest_post_routine(override_type),
9494
)
9595
return BigqueryCallableRoutine(udf_def, session, is_managed=not is_remote)
9696

@@ -107,7 +107,7 @@ def _try_import_row_routine(
107107
return BigqueryCallableRowRoutine(
108108
udf_def,
109109
session,
110-
post_routine=_utils._build_unnest_post_routine(override_type),
110+
post_routine=_utils.build_unnest_post_routine(override_type),
111111
)
112112
return BigqueryCallableRowRoutine(udf_def, session, is_managed=not is_remote)
113113

bigframes/pandas/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def clean_up_by_session_id(
293293
session.bqclient, dataset, session_id
294294
)
295295

296-
bff_utils._clean_up_by_session_id(
296+
bff_utils.clean_up_by_session_id(
297297
session.bqclient, session.cloudfunctionsclient, dataset, session_id
298298
)
299299

bigframes/session/metrics.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ def count_job_stats(
4242
assert row_iterator is not None
4343

4444
# TODO(tswast): Pass None after making benchmark publishing robust to missing data.
45-
bytes_processed = getattr(row_iterator, "total_bytes_processed", 0)
46-
query_char_count = len(getattr(row_iterator, "query", ""))
47-
slot_millis = getattr(row_iterator, "slot_millis", 0)
45+
bytes_processed = getattr(row_iterator, "total_bytes_processed", 0) or 0
46+
query_char_count = len(getattr(row_iterator, "query", "") or "")
47+
slot_millis = getattr(row_iterator, "slot_millis", 0) or 0
4848
exec_seconds = 0.0
4949

5050
self.execution_count += 1
@@ -63,10 +63,10 @@ def count_job_stats(
6363
elif (stats := get_performance_stats(query_job)) is not None:
6464
query_char_count, bytes_processed, slot_millis, exec_seconds = stats
6565
self.execution_count += 1
66-
self.query_char_count += query_char_count
67-
self.bytes_processed += bytes_processed
68-
self.slot_millis += slot_millis
69-
self.execution_secs += exec_seconds
66+
self.query_char_count += query_char_count or 0
67+
self.bytes_processed += bytes_processed or 0
68+
self.slot_millis += slot_millis or 0
69+
self.execution_secs += exec_seconds or 0
7070
write_stats_to_disk(
7171
query_char_count=query_char_count,
7272
bytes_processed=bytes_processed,

bigframes/testing/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -440,11 +440,11 @@ def get_function_name(func, package_requirements=None, is_row_processor=False):
440440
"""Get a bigframes function name for testing given a udf."""
441441
# Augment user package requirements with any internal package
442442
# requirements.
443-
package_requirements = bff_utils._get_updated_package_requirements(
443+
package_requirements = bff_utils.get_updated_package_requirements(
444444
package_requirements, is_row_processor
445445
)
446446

447447
# Compute a unique hash representing the user code.
448-
function_hash = bff_utils._get_hash(func, package_requirements)
448+
function_hash = bff_utils.get_hash(func, package_requirements)
449449

450450
return f"bigframes_{function_hash}"

0 commit comments

Comments
 (0)