Skip to content

Commit 4688d70

Browse files
committed
resolve the comments
1 parent 504f174 commit 4688d70

File tree

5 files changed

+158
-49
lines changed

5 files changed

+158
-49
lines changed

bigframes/exceptions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ class FunctionAxisOnePreviewWarning(PreviewWarning):
103103
"""Remote Function and Managed UDF with axis=1 preview."""
104104

105105

106+
class FunctionPackageVersionWarning(PreviewWarning):
    """Managed UDF package versions may differ from what was requested.

    The package versions used by a managed UDF may not precisely match the
    user's local environment or the exact versions specified.
    """
111+
112+
106113
def format_message(message: str, fill: bool = True):
107114
"""Formats a warning message with ANSI color codes for the warning color.
108115

bigframes/functions/_function_client.py

Lines changed: 8 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import logging
2020
import os
2121
import random
22-
import re
2322
import shutil
2423
import string
2524
import tempfile
@@ -270,43 +269,6 @@ def provision_bq_managed_function(
270269
)
271270

272271
udf_name = func.__name__
273-
if capture_references:
274-
# This code path ensures that if the udf body contains any
275-
# references to variables and/or imports outside the body, they are
276-
# captured as well.
277-
import cloudpickle
278-
279-
pickled = cloudpickle.dumps(func)
280-
func_code = textwrap.dedent(
281-
f"""
282-
import cloudpickle
283-
{udf_name} = cloudpickle.loads({pickled})
284-
"""
285-
)
286-
else:
287-
# This code path ensures that if the udf body is self contained,
288-
# i.e. there are no references to variables or imports outside the
289-
# body.
290-
func_code = textwrap.dedent(inspect.getsource(func))
291-
match = re.search(r"^def ", func_code, flags=re.MULTILINE)
292-
if match is None:
293-
raise ValueError("The UDF is not defined correctly.")
294-
func_code = func_code[match.start() :]
295-
296-
if is_row_processor:
297-
udf_code = textwrap.dedent(inspect.getsource(bff_template.get_pd_series))
298-
udf_code = udf_code[udf_code.index("def") :]
299-
bigframes_handler_code = textwrap.dedent(
300-
f"""def bigframes_handler(str_arg):
301-
return {udf_name}({bff_template.get_pd_series.__name__}(str_arg))"""
302-
)
303-
else:
304-
udf_code = ""
305-
bigframes_handler_code = textwrap.dedent(
306-
f"""def bigframes_handler(*args):
307-
return {udf_name}(*args)"""
308-
)
309-
udf_code = f"{udf_code}\n{func_code}"
310272

311273
with_connection_clause = (
312274
(
@@ -316,6 +278,13 @@ def provision_bq_managed_function(
316278
else ""
317279
)
318280

281+
# Generate the complete Python code block for the managed Python UDF,
282+
# including the user's function, necessary imports, and the BigQuery
283+
# handler wrapper.
284+
python_code_block = bff_template.generate_managed_function_code(
285+
func, udf_name, is_row_processor, capture_references
286+
)
287+
319288
create_function_ddl = (
320289
textwrap.dedent(
321290
f"""
@@ -326,12 +295,11 @@ def provision_bq_managed_function(
326295
OPTIONS ({managed_function_options_str})
327296
AS r'''
328297
__UDF_PLACE_HOLDER__
329-
{bigframes_handler_code}
330298
'''
331299
"""
332300
)
333301
.strip()
334-
.replace("__UDF_PLACE_HOLDER__", udf_code)
302+
.replace("__UDF_PLACE_HOLDER__", python_code_block)
335303
)
336304

337305
self._ensure_dataset_exists()

bigframes/functions/_utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import sys
1919
import typing
2020
from typing import cast, Optional, Set
21+
import warnings
2122

2223
import cloudpickle
2324
import google.api_core.exceptions
@@ -26,6 +27,7 @@
2627
import pandas
2728
import pyarrow
2829

30+
import bigframes.exceptions as bfe
2931
import bigframes.formatting_helpers as bf_formatting
3032
from bigframes.functions import function_typing
3133

@@ -81,9 +83,14 @@ def _get_updated_package_requirements(
8183
# Due to current limitations on the numpy version in Python UDFs, we use
8284
# `ignore_numpy_version` to optionally omit the version for managed
8385
# functions only.
84-
numpy_package = (
85-
"numpy" if ignore_numpy_version else f"numpy=={numpy.__version__}"
86-
)
86+
if ignore_numpy_version:
87+
msg = bfe.format_message(
88+
"Numpy version may not precisely match your local environment."
89+
)
90+
warnings.warn(msg, category=bfe.PreviewWarning)
91+
numpy_package = "numpy"
92+
else:
93+
numpy_package = f"numpy=={numpy.__version__}"
8794
requirements.append(numpy_package)
8895

8996
if package_requirements:

bigframes/functions/function_template.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import inspect
1818
import logging
1919
import os
20+
import re
2021
import textwrap
2122
from typing import Tuple
2223

@@ -291,3 +292,55 @@ def generate_cloud_function_main_code(
291292
logger.debug(f"Wrote {os.path.abspath(main_py)}:\n{open(main_py).read()}")
292293

293294
return handler_func_name
295+
296+
297+
def _source_from_def(func, error_message: str) -> str:
    """Return the dedented source of ``func`` starting at its ``def`` line.

    Anything preceding the first line that begins with ``def `` (for example
    decorator lines) is dropped.

    Raises:
        ValueError: If the dedented source has no line starting with ``def ``.
    """
    source = textwrap.dedent(inspect.getsource(func))
    # Anchor on a line start so that "def" appearing inside a decorator name,
    # comment or string ahead of the definition is not mistaken for it.
    match = re.search(r"^def ", source, flags=re.MULTILINE)
    if match is None:
        raise ValueError(error_message)
    return source[match.start() :]


def generate_managed_function_code(
    def_,
    udf_name: str,
    is_row_processor: bool,
    capture_references: bool,
) -> str:
    """Generates the Python code block for managed Python UDF.

    The returned text is made of up to three parts joined by newlines: the
    source of ``get_pd_series`` (row processors only), the code that defines
    ``udf_name``, and a ``bigframes_handler`` wrapper that calls it.

    Args:
        def_: The user's function to embed in the generated code.
        udf_name: The name the user's function is bound to (and called by)
            in the generated code.
        is_row_processor: If True, the handler takes a single ``str_arg`` and
            passes ``get_pd_series(str_arg)`` to the UDF; otherwise the
            handler forwards ``*args`` unchanged.
        capture_references: If True, the function is embedded as a
            cloudpickle payload; otherwise its source text is embedded.

    Returns:
        The generated Python code as a string.

    Raises:
        ValueError: If source is embedded and no line starting with ``def ``
            can be found in it.
    """
    if capture_references:
        # This code path ensures that if the udf body contains any
        # references to variables and/or imports outside the body, they are
        # captured as well.
        import cloudpickle

        pickled = cloudpickle.dumps(def_)
        func_code = (
            f"\nimport cloudpickle\n{udf_name} = cloudpickle.loads({pickled})\n"
        )
    else:
        # This code path ensures that if the udf body is self contained,
        # i.e. there are no references to variables or imports outside the
        # body.
        func_code = _source_from_def(def_, "The UDF is not defined correctly.")

    # The handler text is spelled out with an explicit 4-space body indent so
    # the generated code does not depend on how this module is formatted.
    if is_row_processor:
        udf_code = _source_from_def(
            get_pd_series, "The get_pd_series helper is not defined correctly."
        )
        bigframes_handler_code = (
            "def bigframes_handler(str_arg):\n"
            f"    return {udf_name}({get_pd_series.__name__}(str_arg))"
        )
    else:
        udf_code = ""
        bigframes_handler_code = (
            "def bigframes_handler(*args):\n"
            f"    return {udf_name}(*args)"
        )

    return f"{udf_code}\n{func_code}\n{bigframes_handler_code}"

tests/system/large/functions/test_managed_function.py

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,7 @@ def serialize_row(row):
659659
# BigFrames and pandas. Without it, BigFrames return plain Python
660660
# types, e.g. 0, while pandas return NumPy types, e.g. np.int64(0),
661661
# which could lead to mismatches and requires further investigation.
662+
# See b/435021126.
662663
custom = {
663664
"name": int(row.name),
664665
"index": [idx for idx in row.index],
@@ -719,6 +720,7 @@ def analyze(row):
719720
# BigFrames and pandas. Without it, BigFrames return plain Python
720721
# types, e.g. 0, while pandas return NumPy types, e.g. np.int64(0),
721722
# which could lead to mismatches and requires further investigation.
723+
# See b/435021126.
722724
return str(
723725
{
724726
"dtype": row.dtype,
@@ -731,12 +733,17 @@ def analyze(row):
731733
}
732734
)
733735

734-
analyze_mf = session.udf(
735-
input_types=bigframes.series.Series,
736-
output_type=str,
737-
dataset=dataset_id,
738-
name=prefixer.create_prefix(),
739-
)(analyze)
736+
with pytest.warns(
737+
bfe.PreviewWarning,
738+
match=("Numpy version may not precisely match your local environment."),
739+
):
740+
741+
analyze_mf = session.udf(
742+
input_types=bigframes.series.Series,
743+
output_type=str,
744+
dataset=dataset_id,
745+
name=prefixer.create_prefix(),
746+
)(analyze)
740747

741748
assert getattr(analyze_mf, "is_row_processor")
742749

@@ -831,6 +838,7 @@ def serialize_row(row):
831838
# BigFrames and pandas. Without it, BigFrames return plain Python
832839
# types, e.g. 0, while pandas return NumPy types, e.g. np.int64(0),
833840
# which could lead to mismatches and requires further investigation.
841+
# See b/435021126.
834842
custom = {
835843
"name": int(row.name),
836844
"index": [idx for idx in row.index],
@@ -870,3 +878,69 @@ def serialize_row(row):
870878
cleanup_function_assets(
871879
serialize_row_mf, session.bqclient, session.cloudfunctionsclient
872880
)
881+
882+
883+
@pytest.mark.skip(reason="Revert after this bug b/435018880 is fixed.")
def test_managed_function_df_apply_axis_1_na_nan_inf(dataset_id, session):
    """This test is for special cases of float values, to make sure any (nan,
    inf, -inf) produced by user code is honored.
    """
    bf_df = session.read_gbq(
        """\
        SELECT "1" AS text, 1 AS num
        UNION ALL
        SELECT "2.5" AS text, 2.5 AS num
        UNION ALL
        SELECT "nan" AS text, IEEE_DIVIDE(0, 0) AS num
        UNION ALL
        SELECT "inf" AS text, IEEE_DIVIDE(1, 0) AS num
        UNION ALL
        SELECT "-inf" AS text, IEEE_DIVIDE(-1, 0) AS num
        UNION ALL
        SELECT "numpy nan" AS text, IEEE_DIVIDE(0, 0) AS num
        UNION ALL
        SELECT "pandas na" AS text, NULL AS num
        """
    )

    pd_df = bf_df.to_pandas()

    def float_parser(row):
        import numpy as mynp
        import pandas as mypd

        if row["text"] == "pandas na":
            return mypd.NA
        if row["text"] == "numpy nan":
            return mynp.nan
        return float(row["text"])

    # Create the managed function before entering the try block: if creation
    # itself fails, the finally clause would otherwise reference an unbound
    # name and hide the original error.
    float_parser_mf = session.udf(
        input_types=bigframes.series.Series,
        output_type=float,
        dataset=dataset_id,
        name=prefixer.create_prefix(),
    )(float_parser)

    try:
        assert getattr(float_parser_mf, "is_row_processor")

        pd_result = pd_df.apply(float_parser, axis=1)
        bf_result = bf_df.apply(float_parser_mf, axis=1).to_pandas()

        # bf_result.dtype is 'Float64' while pd_result.dtype is 'object';
        # ignore this mismatch by using check_dtype=False.
        pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)

        # Let's also assert that the data is consistent in this round trip
        # (BQ -> BigFrames -> BQ -> GCF -> BQ -> BigFrames) w.r.t. their
        # expected values in BQ.
        bq_result = bf_df["num"].to_pandas()
        bq_result.name = None
        pandas.testing.assert_series_equal(bq_result, bf_result)
    finally:
        # clean up the gcp assets created for the managed function.
        cleanup_function_assets(
            float_parser_mf, session.bqclient, session.cloudfunctionsclient
        )

0 commit comments

Comments
 (0)