Commit 05177d6

Merge branch 'main' into udf-type

2 parents 5fa3c98 + 62a189f

File tree: 13 files changed (+443, -86 lines)


bigframes/core/compile/polars/compiler.py

Lines changed: 1 addition & 1 deletion
@@ -646,7 +646,7 @@ def _aggregate(
     def compile_explode(self, node: nodes.ExplodeNode):
         assert node.offsets_col is None
         df = self.compile_node(node.child)
-        cols = [pl.col(col.id.sql) for col in node.column_ids]
+        cols = [col.id.sql for col in node.column_ids]
         return df.explode(cols)

     @compile_node.register
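
Side note on the change above: `polars.DataFrame.explode` accepts plain column-name strings, so the `pl.col(...)` wrapping was unnecessary. A minimal standalone sketch of the behavior (illustrative data, not BigQuery DataFrames code):

    import polars as pl

    df = pl.DataFrame({"letter": ["a", "b"], "nums": [[1, 2], [3, 4]]})
    # Passing the column name as a string explodes each list element
    # into its own row, the same as passing a pl.col(...) expression.
    print(df.explode(["nums"]))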

bigframes/core/indexes/base.py

Lines changed: 13 additions & 44 deletions
@@ -27,16 +27,12 @@
 import pandas

 from bigframes import dtypes
-from bigframes.core.array_value import ArrayValue
 import bigframes.core.block_transforms as block_ops
 import bigframes.core.blocks as blocks
 import bigframes.core.expression as ex
-import bigframes.core.identifiers as ids
-import bigframes.core.nodes as nodes
 import bigframes.core.ordering as order
 import bigframes.core.utils as utils
 import bigframes.core.validations as validations
-import bigframes.core.window_spec as window_spec
 import bigframes.dtypes
 import bigframes.formatting_helpers as formatter
 import bigframes.operations as ops
@@ -272,37 +268,20 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]:
         # Get the index column from the block
         index_column = self._block.index_columns[0]

-        # Apply row numbering to the original data
-        row_number_column_id = ids.ColumnId.unique()
-        window_node = nodes.WindowOpNode(
-            child=self._block._expr.node,
-            expression=ex.NullaryAggregation(agg_ops.RowNumberOp()),
-            window_spec=window_spec.unbound(),
-            output_name=row_number_column_id,
-            never_skip_nulls=True,
-        )
-
-        windowed_array = ArrayValue(window_node)
-        windowed_block = blocks.Block(
-            windowed_array,
-            index_columns=self._block.index_columns,
-            column_labels=self._block.column_labels.insert(
-                len(self._block.column_labels), None
-            ),
-            index_labels=self._block._index_labels,
+        # Use promote_offsets to get row numbers (similar to argmax/argmin implementation)
+        block_with_offsets, offsets_id = self._block.promote_offsets(
+            "temp_get_loc_offsets_"
         )

         # Create expression to find matching positions
         match_expr = ops.eq_op.as_expr(ex.deref(index_column), ex.const(key))
-        windowed_block, match_col_id = windowed_block.project_expr(match_expr)
+        block_with_offsets, match_col_id = block_with_offsets.project_expr(match_expr)

         # Filter to only rows where the key matches
-        filtered_block = windowed_block.filter_by_id(match_col_id)
+        filtered_block = block_with_offsets.filter_by_id(match_col_id)

-        # Check if key exists at all by counting on the filtered block
-        count_agg = ex.UnaryAggregation(
-            agg_ops.count_op, ex.deref(row_number_column_id.name)
-        )
+        # Check if key exists at all by counting
+        count_agg = ex.UnaryAggregation(agg_ops.count_op, ex.deref(offsets_id))
         count_result = filtered_block._expr.aggregate([(count_agg, "count")])
         count_scalar = self._block.session._executor.execute(
             count_result
@@ -313,9 +292,7 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]:

         # If only one match, return integer position
         if count_scalar == 1:
-            min_agg = ex.UnaryAggregation(
-                agg_ops.min_op, ex.deref(row_number_column_id.name)
-            )
+            min_agg = ex.UnaryAggregation(agg_ops.min_op, ex.deref(offsets_id))
             position_result = filtered_block._expr.aggregate([(min_agg, "position")])
             position_scalar = self._block.session._executor.execute(
                 position_result
@@ -325,32 +302,24 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]:
         # Handle multiple matches based on index monotonicity
         is_monotonic = self.is_monotonic_increasing or self.is_monotonic_decreasing
         if is_monotonic:
-            return self._get_monotonic_slice(filtered_block, row_number_column_id)
+            return self._get_monotonic_slice(filtered_block, offsets_id)
         else:
             # Return boolean mask for non-monotonic duplicates
-            mask_block = windowed_block.select_columns([match_col_id])
-            # Reset the index to use positional integers instead of original index values
+            mask_block = block_with_offsets.select_columns([match_col_id])
             mask_block = mask_block.reset_index(drop=True)
-            # Ensure correct dtype and name to match pandas behavior
             result_series = bigframes.series.Series(mask_block)
             return result_series.astype("boolean")

-    def _get_monotonic_slice(
-        self, filtered_block, row_number_column_id: "ids.ColumnId"
-    ) -> slice:
+    def _get_monotonic_slice(self, filtered_block, offsets_id: str) -> slice:
         """Helper method to get a slice for monotonic duplicates with an optimized query."""
         # Combine min and max aggregations into a single query for efficiency
         min_max_aggs = [
             (
-                ex.UnaryAggregation(
-                    agg_ops.min_op, ex.deref(row_number_column_id.name)
-                ),
+                ex.UnaryAggregation(agg_ops.min_op, ex.deref(offsets_id)),
                 "min_pos",
             ),
             (
-                ex.UnaryAggregation(
-                    agg_ops.max_op, ex.deref(row_number_column_id.name)
-                ),
+                ex.UnaryAggregation(agg_ops.max_op, ex.deref(offsets_id)),
                 "max_pos",
             ),
         ]
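
The three return paths above mirror the pandas `Index.get_loc` contract that this implementation follows: an integer for a unique key, a slice for duplicates on a monotonic index, and a boolean mask otherwise. Illustrated with plain pandas:

    import pandas as pd

    print(pd.Index(["a", "b", "c"]).get_loc("b"))       # 1 (unique key)
    print(pd.Index(["a", "b", "b", "c"]).get_loc("b"))  # slice(1, 3, None) (monotonic dups)
    print(pd.Index(["b", "a", "b"]).get_loc("b"))       # [ True False  True] (non-monotonic)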

bigframes/exceptions.py

Lines changed: 7 additions & 0 deletions
@@ -107,6 +107,13 @@ class FunctionAxisOnePreviewWarning(PreviewWarning):
     """Remote Function and Managed UDF with axis=1 preview."""


+class FunctionPackageVersionWarning(PreviewWarning):
+    """
+    Managed UDF package versions for Numpy, Pandas, and Pyarrow may not
+    precisely match users' local environment or the exact versions specified.
+    """
+
+
 def format_message(message: str, fill: bool = True):
     """Formats a warning message with ANSI color codes for the warning color.
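
Because the new warning subclasses `PreviewWarning`, callers can silence it with the standard `warnings` machinery if the version mismatch is acceptable; a minimal sketch:

    import warnings

    import bigframes.exceptions as bfe

    # Suppress only the package-version warning; other BigQuery DataFrames
    # warnings remain visible.
    warnings.filterwarnings("ignore", category=bfe.FunctionPackageVersionWarning)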

bigframes/functions/_function_client.py

Lines changed: 9 additions & 27 deletions
@@ -19,7 +19,6 @@
 import logging
 import os
 import random
-import re
 import shutil
 import string
 import tempfile
@@ -247,7 +246,7 @@ def provision_bq_managed_function(
         # Augment user package requirements with any internal package
         # requirements.
         packages = _utils._get_updated_package_requirements(
-            packages, is_row_processor, capture_references
+            packages, is_row_processor, capture_references, ignore_package_version=True
         )
         if packages:
             managed_function_options["packages"] = packages
@@ -270,28 +269,6 @@ def provision_bq_managed_function(
         )

         udf_name = func.__name__
-        if capture_references:
-            # This code path ensures that if the udf body contains any
-            # references to variables and/or imports outside the body, they are
-            # captured as well.
-            import cloudpickle
-
-            pickled = cloudpickle.dumps(func)
-            udf_code = textwrap.dedent(
-                f"""
-                import cloudpickle
-                {udf_name} = cloudpickle.loads({pickled})
-                """
-            )
-        else:
-            # This code path ensures that if the udf body is self contained,
-            # i.e. there are no references to variables or imports outside the
-            # body.
-            udf_code = textwrap.dedent(inspect.getsource(func))
-            match = re.search(r"^def ", udf_code, flags=re.MULTILINE)
-            if match is None:
-                raise ValueError("The UDF is not defined correctly.")
-            udf_code = udf_code[match.start() :]

         with_connection_clause = (
             (
@@ -301,6 +278,13 @@ def provision_bq_managed_function(
             else ""
         )

+        # Generate the complete Python code block for the managed Python UDF,
+        # including the user's function, necessary imports, and the BigQuery
+        # handler wrapper.
+        python_code_block = bff_template.generate_managed_function_code(
+            func, udf_name, is_row_processor, capture_references
+        )
+
         create_function_ddl = (
             textwrap.dedent(
                 f"""
@@ -311,13 +295,11 @@ def provision_bq_managed_function(
            OPTIONS ({managed_function_options_str})
            AS r'''
            __UDF_PLACE_HOLDER__
-            def bigframes_handler(*args):
-                return {udf_name}(*args)
            '''
        """
            )
            .strip()
-            .replace("__UDF_PLACE_HOLDER__", udf_code)
+            .replace("__UDF_PLACE_HOLDER__", python_code_block)
        )

        self._ensure_dataset_exists()
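
The DDL assembly is plain string templating; a standalone sketch of the pattern, with a simplified and hypothetical DDL body (the real statement carries the full signature, OPTIONS, and language clause):

    import textwrap

    # Hypothetical generated code block; in the real flow this comes from
    # bff_template.generate_managed_function_code.
    python_code_block = "def bigframes_handler(*args):\n    return my_udf(*args)"

    ddl = (
        textwrap.dedent(
            """
            CREATE OR REPLACE FUNCTION example(x INT64) RETURNS INT64
            AS r'''
            __UDF_PLACE_HOLDER__
            '''
            """
        )
        .strip()
        .replace("__UDF_PLACE_HOLDER__", python_code_block)
    )
    print(ddl)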

bigframes/functions/_function_session.py

Lines changed: 2 additions & 2 deletions
@@ -867,15 +867,15 @@ def wrapper(func):
                 warnings.warn(msg, category=bfe.FunctionRedundantTypeHintWarning)
                 py_sig = py_sig.replace(return_annotation=output_type)

-            udf_sig = udf_def.UdfSignature.from_py_signature(py_sig)
-
             # The function will actually be receiving a pandas Series, but allow
             # both BigQuery DataFrames and pandas object types for compatibility.
             is_row_processor = False
             if new_sig := _convert_row_processor_sig(py_sig):
                 py_sig = new_sig
                 is_row_processor = True

+            udf_sig = udf_def.UdfSignature.from_py_signature(py_sig)
+
             managed_function_client = _function_client.FunctionClient(
                 dataset_ref.project,
                 bq_location,
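
The reorder matters because `_convert_row_processor_sig` can rewrite `py_sig`; deriving `udf_sig` beforehand would capture the pre-conversion signature. A toy illustration of that staleness problem using plain `inspect` (hypothetical function, not bigframes code):

    import inspect

    def f(row) -> float:
        ...

    sig = inspect.signature(f)
    stale = sig.return_annotation          # captured before the rewrite
    sig = sig.replace(return_annotation=str)
    print(stale, sig.return_annotation)    # <class 'float'> <class 'str'>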

bigframes/functions/_utils.py

Lines changed: 30 additions & 9 deletions
@@ -19,6 +19,7 @@
 import sys
 import typing
 from typing import cast, Optional, Set
+import warnings

 import cloudpickle
 import google.api_core.exceptions
@@ -27,6 +28,7 @@
 import pandas
 import pyarrow

+import bigframes.exceptions as bfe
 import bigframes.formatting_helpers as bf_formatting
 from bigframes.functions import function_typing

@@ -62,21 +64,40 @@ def get_remote_function_locations(bq_location):


 def _get_updated_package_requirements(
-    package_requirements=None, is_row_processor=False, capture_references=True
+    package_requirements=None,
+    is_row_processor=False,
+    capture_references=True,
+    ignore_package_version=False,
 ):
     requirements = []
     if capture_references:
         requirements.append(f"cloudpickle=={cloudpickle.__version__}")

     if is_row_processor:
-        # bigframes function will send an entire row of data as json, which
-        # would be converted to a pandas series and processed Ensure numpy
-        # versions match to avoid unpickling problems. See internal issue
-        # b/347934471.
-        requirements.append(f"numpy=={numpy.__version__}")
-        requirements.append(f"pandas=={pandas.__version__}")
-        requirements.append(f"pyarrow=={pyarrow.__version__}")
-
+        if ignore_package_version:
+            # TODO(jialuo): Add back the version after b/410924784 is resolved.
+            # Due to current limitations on the packages version in Python UDFs,
+            # we use `ignore_package_version` to optionally omit the version for
+            # managed functions only.
+            msg = bfe.format_message(
+                "numpy, pandas, and pyarrow versions in the function execution"
+                " environment may not precisely match your local environment."
+            )
+            warnings.warn(msg, category=bfe.FunctionPackageVersionWarning)
+            requirements.append("pandas")
+            requirements.append("pyarrow")
+            requirements.append("numpy")
+        else:
+            # bigframes function will send an entire row of data as json, which
+            # would be converted to a pandas series and processed Ensure numpy
+            # versions match to avoid unpickling problems. See internal issue
+            # b/347934471.
+            requirements.append(f"pandas=={pandas.__version__}")
+            requirements.append(f"pyarrow=={pyarrow.__version__}")
+            requirements.append(f"numpy=={numpy.__version__}")
+
+    # TODO(b/435023957): Fix the issue of potential duplicate package versions
+    # when `package_requirements` also contains `pandas/pyarrow/numpy`.
     if package_requirements:
         requirements.extend(package_requirements)
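
A minimal standalone sketch of the new branching (simplified from the function above; the pinned versions are whatever the local environment reports):

    import cloudpickle
    import numpy
    import pandas
    import pyarrow

    def updated_requirements(is_row_processor=True, ignore_package_version=False):
        reqs = [f"cloudpickle=={cloudpickle.__version__}"]
        if is_row_processor:
            if ignore_package_version:
                # Managed UDF path: names only, the runtime picks the versions.
                reqs += ["pandas", "pyarrow", "numpy"]
            else:
                # Cloud function path: pin the local versions to keep
                # pickling consistent across environments.
                reqs += [
                    f"pandas=={pandas.__version__}",
                    f"pyarrow=={pyarrow.__version__}",
                    f"numpy=={numpy.__version__}",
                ]
        return reqs

    print(updated_requirements(ignore_package_version=True))
    # ['cloudpickle==...', 'pandas', 'pyarrow', 'numpy']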

bigframes/functions/function_template.py

Lines changed: 53 additions & 0 deletions
@@ -17,6 +17,7 @@
 import inspect
 import logging
 import os
+import re
 import textwrap
 from typing import Tuple

@@ -291,3 +292,55 @@ def generate_cloud_function_main_code(
     logger.debug(f"Wrote {os.path.abspath(main_py)}:\n{open(main_py).read()}")

     return handler_func_name
+
+
+def generate_managed_function_code(
+    def_,
+    udf_name: str,
+    is_row_processor: bool,
+    capture_references: bool,
+) -> str:
+    """Generates the Python code block for managed Python UDF."""
+
+    if capture_references:
+        # This code path ensures that if the udf body contains any
+        # references to variables and/or imports outside the body, they are
+        # captured as well.
+        import cloudpickle
+
+        pickled = cloudpickle.dumps(def_)
+        func_code = textwrap.dedent(
+            f"""
+            import cloudpickle
+            {udf_name} = cloudpickle.loads({pickled})
+            """
+        )
+    else:
+        # This code path ensures that if the udf body is self contained,
+        # i.e. there are no references to variables or imports outside the
+        # body.
+        func_code = textwrap.dedent(inspect.getsource(def_))
+        match = re.search(r"^def ", func_code, flags=re.MULTILINE)
+        if match is None:
+            raise ValueError("The UDF is not defined correctly.")
+        func_code = func_code[match.start() :]
+
+    if is_row_processor:
+        udf_code = textwrap.dedent(inspect.getsource(get_pd_series))
+        udf_code = udf_code[udf_code.index("def") :]
+        bigframes_handler_code = textwrap.dedent(
+            f"""def bigframes_handler(str_arg):
+            return {udf_name}({get_pd_series.__name__}(str_arg))"""
+        )
+    else:
+        udf_code = ""
+        bigframes_handler_code = textwrap.dedent(
+            f"""def bigframes_handler(*args):
+            return {udf_name}(*args)"""
+        )
+
+    udf_code_block = textwrap.dedent(
+        f"{udf_code}\n{func_code}\n{bigframes_handler_code}"
+    )
+
+    return udf_code_block
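
For a self-contained UDF (`capture_references=False`, `is_row_processor=False`), the generated block is roughly the function source followed by the handler wrapper. For example, given a hypothetical `def add_one(x: int) -> int: return x + 1`, the returned string is approximately:

    def add_one(x: int) -> int:
        return x + 1
    def bigframes_handler(*args):
        return add_one(*args)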

bigframes/functions/function_typing.py

Lines changed: 2 additions & 1 deletion
@@ -61,7 +61,8 @@ def __init__(self, type_, supported_types):
         self.type = type_
         self.supported_types = supported_types
         super().__init__(
-            f"'{type_}' is not one of the supported types {supported_types}"
+            f"'{type_}' must be one of the supported types ({supported_types}) "
+            "or a list of one of those types."
         )
