Skip to content

Commit 504f174

Browse files
committed
feat: support series input in managed function
1 parent ebdcd02 commit 504f174

File tree

4 files changed

+258
-11
lines changed

4 files changed

+258
-11
lines changed

bigframes/functions/_function_client.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def provision_bq_managed_function(
247247
# Augment user package requirements with any internal package
248248
# requirements.
249249
packages = _utils._get_updated_package_requirements(
250-
packages, is_row_processor, capture_references
250+
packages, is_row_processor, capture_references, ignore_numpy_version=True
251251
)
252252
if packages:
253253
managed_function_options["packages"] = packages
@@ -277,7 +277,7 @@ def provision_bq_managed_function(
277277
import cloudpickle
278278

279279
pickled = cloudpickle.dumps(func)
280-
udf_code = textwrap.dedent(
280+
func_code = textwrap.dedent(
281281
f"""
282282
import cloudpickle
283283
{udf_name} = cloudpickle.loads({pickled})
@@ -287,11 +287,26 @@ def provision_bq_managed_function(
287287
# This code path ensures that if the udf body is self contained,
288288
# i.e. there are no references to variables or imports outside the
289289
# body.
290-
udf_code = textwrap.dedent(inspect.getsource(func))
291-
match = re.search(r"^def ", udf_code, flags=re.MULTILINE)
290+
func_code = textwrap.dedent(inspect.getsource(func))
291+
match = re.search(r"^def ", func_code, flags=re.MULTILINE)
292292
if match is None:
293293
raise ValueError("The UDF is not defined correctly.")
294-
udf_code = udf_code[match.start() :]
294+
func_code = func_code[match.start() :]
295+
296+
if is_row_processor:
297+
udf_code = textwrap.dedent(inspect.getsource(bff_template.get_pd_series))
298+
udf_code = udf_code[udf_code.index("def") :]
299+
bigframes_handler_code = textwrap.dedent(
300+
f"""def bigframes_handler(str_arg):
301+
return {udf_name}({bff_template.get_pd_series.__name__}(str_arg))"""
302+
)
303+
else:
304+
udf_code = ""
305+
bigframes_handler_code = textwrap.dedent(
306+
f"""def bigframes_handler(*args):
307+
return {udf_name}(*args)"""
308+
)
309+
udf_code = f"{udf_code}\n{func_code}"
295310

296311
with_connection_clause = (
297312
(
@@ -311,8 +326,7 @@ def provision_bq_managed_function(
311326
OPTIONS ({managed_function_options_str})
312327
AS r'''
313328
__UDF_PLACE_HOLDER__
314-
def bigframes_handler(*args):
315-
return {udf_name}(*args)
329+
{bigframes_handler_code}
316330
'''
317331
"""
318332
)

bigframes/functions/_function_session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -847,15 +847,15 @@ def wrapper(func):
847847
if output_type:
848848
py_sig = py_sig.replace(return_annotation=output_type)
849849

850-
udf_sig = udf_def.UdfSignature.from_py_signature(py_sig)
851-
852850
# The function will actually be receiving a pandas Series, but allow
853851
# both BigQuery DataFrames and pandas object types for compatibility.
854852
is_row_processor = False
855853
if new_sig := _convert_row_processor_sig(py_sig):
856854
py_sig = new_sig
857855
is_row_processor = True
858856

857+
udf_sig = udf_def.UdfSignature.from_py_signature(py_sig)
858+
859859
managed_function_client = _function_client.FunctionClient(
860860
dataset_ref.project,
861861
bq_location,

bigframes/functions/_utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,10 @@ def get_remote_function_locations(bq_location):
6161

6262

6363
def _get_updated_package_requirements(
64-
package_requirements=None, is_row_processor=False, capture_references=True
64+
package_requirements=None,
65+
is_row_processor=False,
66+
capture_references=True,
67+
ignore_numpy_version=False,
6568
):
6669
requirements = []
6770
if capture_references:
@@ -72,9 +75,16 @@ def _get_updated_package_requirements(
7275
# would be converted to a pandas series and processed. Ensure numpy
7376
# versions match to avoid unpickling problems. See internal issue
7477
# b/347934471.
75-
requirements.append(f"numpy=={numpy.__version__}")
7678
requirements.append(f"pandas=={pandas.__version__}")
7779
requirements.append(f"pyarrow=={pyarrow.__version__}")
80+
# TODO(jialuo): Add back the version after b/410924784 is resolved.
81+
# Due to current limitations on the numpy version in Python UDFs, we use
82+
# `ignore_numpy_version` to optionally omit the version for managed
83+
# functions only.
84+
numpy_package = (
85+
"numpy" if ignore_numpy_version else f"numpy=={numpy.__version__}"
86+
)
87+
requirements.append(numpy_package)
7888

7989
if package_requirements:
8090
requirements.extend(package_requirements)

tests/system/large/functions/test_managed_function.py

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,3 +647,226 @@ def foo(x: int) -> int:
647647
container_cpu=2,
648648
container_memory="64Mi",
649649
)(foo)
650+
651+
652+
def test_managed_function_df_apply_axis_1(session, dataset_id, scalars_dfs):
653+
columns = ["bool_col", "int64_col", "int64_too", "float64_col", "string_col"]
654+
scalars_df, scalars_pandas_df = scalars_dfs
655+
try:
656+
657+
def serialize_row(row):
658+
# Explicitly casting types ensures consistent behavior between
659+
# BigFrames and pandas. Without it, BigFrames return plain Python
660+
# types, e.g. 0, while pandas return NumPy types, e.g. np.int64(0),
661+
# which could lead to mismatches and requires further investigation.
662+
custom = {
663+
"name": int(row.name),
664+
"index": [idx for idx in row.index],
665+
"values": [
666+
val.item() if hasattr(val, "item") else val for val in row.values
667+
],
668+
}
669+
670+
return str(
671+
{
672+
"default": row.to_json(),
673+
"split": row.to_json(orient="split"),
674+
"records": row.to_json(orient="records"),
675+
"index": row.to_json(orient="index"),
676+
"table": row.to_json(orient="table"),
677+
"custom": custom,
678+
}
679+
)
680+
681+
serialize_row_mf = session.udf(
682+
input_types=bigframes.series.Series,
683+
output_type=str,
684+
dataset=dataset_id,
685+
name=prefixer.create_prefix(),
686+
)(serialize_row)
687+
688+
assert getattr(serialize_row_mf, "is_row_processor")
689+
690+
bf_result = scalars_df[columns].apply(serialize_row_mf, axis=1).to_pandas()
691+
pd_result = scalars_pandas_df[columns].apply(serialize_row, axis=1)
692+
693+
# bf_result.dtype is 'string[pyarrow]' while pd_result.dtype is 'object'
694+
# Ignore this mismatch by using check_dtype=False.
695+
pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
696+
697+
# Let's make sure the read_gbq_function path works for this function.
698+
serialize_row_reuse = session.read_gbq_function(
699+
serialize_row_mf.bigframes_bigquery_function, is_row_processor=True
700+
)
701+
bf_result = scalars_df[columns].apply(serialize_row_reuse, axis=1).to_pandas()
702+
pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
703+
704+
finally:
705+
# clean up the gcp assets created for the managed function.
706+
cleanup_function_assets(
707+
serialize_row_mf, session.bqclient, session.cloudfunctionsclient
708+
)
709+
710+
711+
def test_managed_function_df_apply_axis_1_aggregates(session, dataset_id, scalars_dfs):
712+
columns = ["int64_col", "int64_too", "float64_col"]
713+
scalars_df, scalars_pandas_df = scalars_dfs
714+
715+
try:
716+
717+
def analyze(row):
718+
# Explicitly casting types ensures consistent behavior between
719+
# BigFrames and pandas. Without it, BigFrames return plain Python
720+
# types, e.g. 0, while pandas return NumPy types, e.g. np.int64(0),
721+
# which could lead to mismatches and requires further investigation.
722+
return str(
723+
{
724+
"dtype": row.dtype,
725+
"count": int(row.count()),
726+
"min": int(row.min()),
727+
"max": int(row.max()),
728+
"mean": float(row.mean()),
729+
"std": float(row.std()),
730+
"var": float(row.var()),
731+
}
732+
)
733+
734+
analyze_mf = session.udf(
735+
input_types=bigframes.series.Series,
736+
output_type=str,
737+
dataset=dataset_id,
738+
name=prefixer.create_prefix(),
739+
)(analyze)
740+
741+
assert getattr(analyze_mf, "is_row_processor")
742+
743+
bf_result = scalars_df[columns].dropna().apply(analyze_mf, axis=1).to_pandas()
744+
pd_result = scalars_pandas_df[columns].dropna().apply(analyze, axis=1)
745+
746+
# bf_result.dtype is 'string[pyarrow]' while pd_result.dtype is 'object'
747+
# Ignore this mismatch by using check_dtype=False.
748+
pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
749+
750+
finally:
751+
# clean up the gcp assets created for the managed function.
752+
cleanup_function_assets(
753+
analyze_mf, session.bqclient, session.cloudfunctionsclient
754+
)
755+
756+
757+
@pytest.mark.parametrize(
758+
("pd_df",),
759+
[
760+
pytest.param(
761+
pandas.DataFrame(
762+
{
763+
"2": [1, 2, 3],
764+
2: [1.5, 3.75, 5],
765+
"name, [with. special'- chars\")/\\": [10, 20, 30],
766+
(3, 4): ["pq", "rs", "tu"],
767+
(5.0, "six", 7): [8, 9, 10],
768+
'raise Exception("hacked!")': [11, 12, 13],
769+
},
770+
# Default pandas index has non-numpy type, whereas bigframes is
771+
# always numpy-based type, so let's use the index compatible
772+
# with bigframes. See more details in b/369689696.
773+
index=pandas.Index([0, 1, 2], dtype=pandas.Int64Dtype()),
774+
),
775+
id="all-kinds-of-column-names",
776+
),
777+
pytest.param(
778+
pandas.DataFrame(
779+
{
780+
"x": [1, 2, 3],
781+
"y": [1.5, 3.75, 5],
782+
"z": ["pq", "rs", "tu"],
783+
},
784+
index=pandas.MultiIndex.from_frame(
785+
pandas.DataFrame(
786+
{
787+
"idx0": pandas.Series(
788+
["a", "a", "b"], dtype=pandas.StringDtype()
789+
),
790+
"idx1": pandas.Series(
791+
[100, 200, 300], dtype=pandas.Int64Dtype()
792+
),
793+
}
794+
)
795+
),
796+
),
797+
id="multiindex",
798+
marks=pytest.mark.skip(
799+
reason="TODO: revert this skip after this pandas bug is fixed: https://github.com/pandas-dev/pandas/issues/59908"
800+
),
801+
),
802+
pytest.param(
803+
pandas.DataFrame(
804+
[
805+
[10, 1.5, "pq"],
806+
[20, 3.75, "rs"],
807+
[30, 8.0, "tu"],
808+
],
809+
# Default pandas index has non-numpy type, whereas bigframes is
810+
# always numpy-based type, so let's use the index compatible
811+
# with bigframes. See more details in b/369689696.
812+
index=pandas.Index([0, 1, 2], dtype=pandas.Int64Dtype()),
813+
columns=pandas.MultiIndex.from_arrays(
814+
[
815+
["first", "last_two", "last_two"],
816+
[1, 2, 3],
817+
]
818+
),
819+
),
820+
id="column-multiindex",
821+
),
822+
],
823+
)
824+
def test_managed_function_df_apply_axis_1_complex(session, dataset_id, pd_df):
825+
bf_df = session.read_pandas(pd_df)
826+
827+
try:
828+
829+
def serialize_row(row):
830+
# Explicitly casting types ensures consistent behavior between
831+
# BigFrames and pandas. Without it, BigFrames return plain Python
832+
# types, e.g. 0, while pandas return NumPy types, e.g. np.int64(0),
833+
# which could lead to mismatches and requires further investigation.
834+
custom = {
835+
"name": int(row.name),
836+
"index": [idx for idx in row.index],
837+
"values": [
838+
val.item() if hasattr(val, "item") else val for val in row.values
839+
],
840+
}
841+
return str(
842+
{
843+
"default": row.to_json(),
844+
"split": row.to_json(orient="split"),
845+
"records": row.to_json(orient="records"),
846+
"index": row.to_json(orient="index"),
847+
"custom": custom,
848+
}
849+
)
850+
851+
serialize_row_mf = session.udf(
852+
input_types=bigframes.series.Series,
853+
output_type=str,
854+
dataset=dataset_id,
855+
name=prefixer.create_prefix(),
856+
)(serialize_row)
857+
858+
assert getattr(serialize_row_mf, "is_row_processor")
859+
860+
bf_result = bf_df.apply(serialize_row_mf, axis=1).to_pandas()
861+
pd_result = pd_df.apply(serialize_row, axis=1)
862+
863+
# ignore known dtype difference between pandas and bigframes.
864+
pandas.testing.assert_series_equal(
865+
pd_result, bf_result, check_dtype=False, check_index_type=False
866+
)
867+
868+
finally:
869+
# clean up the gcp assets created for the managed function.
870+
cleanup_function_assets(
871+
serialize_row_mf, session.bqclient, session.cloudfunctionsclient
872+
)

0 commit comments

Comments
 (0)