Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion bigframes/functions/_function_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,17 @@ def _convert_row_processor_sig(
if len(signature.parameters) >= 1:
first_param = next(iter(signature.parameters.values()))
param_type = first_param.annotation
if (param_type == bf_series.Series) or (param_type == pandas.Series):
# Type hints for Series inputs must use pandas.Series because the
# underlying serialization process converts the input to a string
# representation of a pandas Series (not a bigframes Series). Using a
# bigframes Series would lead to a TypeError when creating the function
# remotely, so it is rejected here. See b/445182819 for details.
if param_type == bf_series.Series:
raise bf_formatting.create_exception_with_feedback_link(
TypeError,
"Argument type hint must be Pandas Series, not BigFrames Series.",
)
if param_type == pandas.Series:
msg = bfe.format_message("input_types=Series is in preview.")
warnings.warn(msg, stacklevel=1, category=bfe.PreviewWarning)
return signature.replace(
Expand Down
18 changes: 13 additions & 5 deletions bigframes/functions/function_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,16 @@ def generate_managed_function_code(
return {udf_name}(*args)"""
)

udf_code_block = textwrap.dedent(
f"{udf_code}\n{func_code}\n{bigframes_handler_code}"
)

return udf_code_block
udf_code_block = []
if not capture_references and is_row_processor:
# Enable postponed evaluation of type annotations (PEP 563). Type hints
# are then stored as strings instead of being evaluated when the function
# is defined, which is necessary for correctly handling the pandas.Series
# type annotation after the UDF code is serialized for remote execution.
# See b/445182819 for details.
udf_code_block.append("from __future__ import annotations")

udf_code_block.append(udf_code)
udf_code_block.append(func_code)
udf_code_block.append(bigframes_handler_code)

return textwrap.dedent("\n".join(udf_code_block))
3 changes: 2 additions & 1 deletion bigframes/session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2064,8 +2064,9 @@ def read_gbq_function(
note, row processor implies that the function has only one input
parameter.

>>> import pandas as pd
>>> @bpd.remote_function(cloud_function_service_account="default")
... def row_sum(s: bpd.Series) -> float:
... def row_sum(s: pd.Series) -> float:
... return s['a'] + s['b'] + s['c']

>>> row_sum_ref = bpd.read_gbq_function(
Expand Down
31 changes: 19 additions & 12 deletions tests/system/large/functions/test_managed_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,8 +701,19 @@ def serialize_row(row):
}
)

with pytest.raises(
TypeError,
match="Argument type hint must be Pandas Series, not BigFrames Series.",
):
serialize_row_mf = session.udf(
input_types=bigframes.series.Series,
output_type=str,
dataset=dataset_id,
name=prefixer.create_prefix(),
)(serialize_row)

serialize_row_mf = session.udf(
input_types=bigframes.series.Series,
input_types=pandas.Series,
output_type=str,
dataset=dataset_id,
name=prefixer.create_prefix(),
Expand Down Expand Up @@ -762,7 +773,7 @@ def analyze(row):
):

analyze_mf = session.udf(
input_types=bigframes.series.Series,
input_types=pandas.Series,
output_type=str,
dataset=dataset_id,
name=prefixer.create_prefix(),
Expand Down Expand Up @@ -876,7 +887,7 @@ def serialize_row(row):
)

serialize_row_mf = session.udf(
input_types=bigframes.series.Series,
input_types=pandas.Series,
output_type=str,
dataset=dataset_id,
name=prefixer.create_prefix(),
Expand Down Expand Up @@ -926,7 +937,7 @@ def test_managed_function_df_apply_axis_1_na_nan_inf(dataset_id, session):

try:

def float_parser(row):
def float_parser(row: pandas.Series):
import numpy as mynp
import pandas as mypd

Expand All @@ -937,7 +948,7 @@ def float_parser(row):
return float(row["text"])

float_parser_mf = session.udf(
input_types=bigframes.series.Series,
input_types=pandas.Series,
output_type=float,
dataset=dataset_id,
name=prefixer.create_prefix(),
Expand Down Expand Up @@ -1027,7 +1038,7 @@ def test_managed_function_df_apply_axis_1_series_args(session, dataset_id, scala

try:

def analyze(s, x, y):
def analyze(s: pandas.Series, x: bool, y: float) -> str:
value = f"value is {s['int64_col']} and {s['float64_col']}"
if x:
return f"{value}, x is True!"
Expand All @@ -1036,8 +1047,6 @@ def analyze(s, x, y):
return f"{value}, x is False, y is non-positive!"

analyze_mf = session.udf(
input_types=[bigframes.series.Series, bool, float],
output_type=str,
dataset=dataset_id,
name=prefixer.create_prefix(),
)(analyze)
Expand Down Expand Up @@ -1151,7 +1160,7 @@ def is_sum_positive_series(s):
return s["int64_col"] + s["int64_too"] > 0

is_sum_positive_series_mf = session.udf(
input_types=bigframes.series.Series,
input_types=pandas.Series,
output_type=bool,
dataset=dataset_id,
name=prefixer.create_prefix(),
Expand Down Expand Up @@ -1217,12 +1226,10 @@ def func_for_other(x):
def test_managed_function_df_where_other_issue(session, dataset_id, scalars_df_index):
try:

def the_sum(s):
def the_sum(s: pandas.Series) -> int:
return s["int64_col"] + s["int64_too"]

the_sum_mf = session.udf(
input_types=bigframes.series.Series,
output_type=int,
dataset=dataset_id,
name=prefixer.create_prefix(),
)(the_sum)
Expand Down
28 changes: 18 additions & 10 deletions tests/system/large/functions/test_remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1722,7 +1722,7 @@ def serialize_row(row):
)

serialize_row_remote = session.remote_function(
input_types=bigframes.series.Series,
input_types=pandas.Series,
output_type=str,
reuse=False,
cloud_function_service_account="default",
Expand Down Expand Up @@ -1771,7 +1771,7 @@ def analyze(row):
)

analyze_remote = session.remote_function(
input_types=bigframes.series.Series,
input_types=pandas.Series,
output_type=str,
reuse=False,
cloud_function_service_account="default",
Expand Down Expand Up @@ -1895,7 +1895,7 @@ def serialize_row(row):
)

serialize_row_remote = session.remote_function(
input_types=bigframes.series.Series,
input_types=pandas.Series,
output_type=str,
reuse=False,
cloud_function_service_account="default",
Expand Down Expand Up @@ -1944,7 +1944,7 @@ def test_df_apply_axis_1_na_nan_inf(session):

try:

def float_parser(row):
def float_parser(row: pandas.Series):
import numpy as mynp
import pandas as mypd

Expand All @@ -1955,7 +1955,6 @@ def float_parser(row):
return float(row["text"])

float_parser_remote = session.remote_function(
input_types=bigframes.series.Series,
output_type=float,
reuse=False,
cloud_function_service_account="default",
Expand Down Expand Up @@ -2055,12 +2054,12 @@ def test_df_apply_axis_1_series_args(session, scalars_dfs):
try:

@session.remote_function(
input_types=[bigframes.series.Series, float, str, bool],
input_types=[pandas.Series, float, str, bool],
output_type=list[str],
reuse=False,
cloud_function_service_account="default",
)
def foo_list(x, y0: float, y1, y2) -> list[str]:
def foo_list(x: pandas.Series, y0: float, y1, y2) -> list[str]:
return (
[str(x["int64_col"]), str(y0), str(y1), str(y2)]
if y2
Expand Down Expand Up @@ -3087,12 +3086,21 @@ def test_remote_function_df_where_mask_series(session, dataset_id, scalars_dfs):
try:

# The return type has to be bool type for callable where condition.
def is_sum_positive_series(s):
def is_sum_positive_series(s: pandas.Series) -> bool:
return s["int64_col"] + s["int64_too"] > 0

with pytest.raises(
TypeError,
match="Argument type hint must be Pandas Series, not BigFrames Series.",
):
session.remote_function(
input_types=bigframes.series.Series,
dataset=dataset_id,
reuse=False,
cloud_function_service_account="default",
)(is_sum_positive_series)

is_sum_positive_series_mf = session.remote_function(
input_types=bigframes.series.Series,
output_type=bool,
dataset=dataset_id,
reuse=False,
cloud_function_service_account="default",
Expand Down
13 changes: 6 additions & 7 deletions tests/system/small/functions/test_remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import bigframes_vendored.constants as constants
import google.api_core.exceptions
from google.cloud import bigquery
import pandas
import pandas as pd
import pyarrow
import pytest
Expand Down Expand Up @@ -1166,16 +1167,14 @@ def test_df_apply_axis_1(session, scalars_dfs, dataset_id_permanent):
]
scalars_df, scalars_pandas_df = scalars_dfs

def add_ints(row):
def add_ints(row: pandas.Series) -> int:
return row["int64_col"] + row["int64_too"]

with pytest.warns(
bigframes.exceptions.PreviewWarning,
match="input_types=Series is in preview.",
):
add_ints_remote = session.remote_function(
input_types=bigframes.series.Series,
output_type=int,
dataset=dataset_id_permanent,
name=get_function_name(add_ints, is_row_processor=True),
cloud_function_service_account="default",
Expand Down Expand Up @@ -1223,11 +1222,11 @@ def test_df_apply_axis_1_ordering(session, scalars_dfs, dataset_id_permanent):
ordering_columns = ["bool_col", "int64_col"]
scalars_df, scalars_pandas_df = scalars_dfs

def add_ints(row):
def add_ints(row: pandas.Series) -> int:
return row["int64_col"] + row["int64_too"]

add_ints_remote = session.remote_function(
input_types=bigframes.series.Series,
input_types=pandas.Series,
output_type=int,
dataset=dataset_id_permanent,
name=get_function_name(add_ints, is_row_processor=True),
Expand Down Expand Up @@ -1267,7 +1266,7 @@ def add_numbers(row):
return row["x"] + row["y"]

add_numbers_remote = session.remote_function(
input_types=bigframes.series.Series,
input_types=pandas.Series,
output_type=float,
dataset=dataset_id_permanent,
name=get_function_name(add_numbers, is_row_processor=True),
Expand Down Expand Up @@ -1321,7 +1320,7 @@ def echo_len(row):
return len(row)

echo_len_remote = session.remote_function(
input_types=bigframes.series.Series,
input_types=pandas.Series,
output_type=float,
dataset=dataset_id_permanent,
name=get_function_name(echo_len, is_row_processor=True),
Expand Down
19 changes: 3 additions & 16 deletions tests/unit/functions/test_remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,12 @@
import pandas
import pytest

import bigframes.exceptions
import bigframes.functions.function as bff
import bigframes.series
from bigframes.testing import mocks


@pytest.mark.parametrize(
"series_type",
(
pytest.param(
pandas.Series,
id="pandas.Series",
),
pytest.param(
bigframes.series.Series,
id="bigframes.series.Series",
),
),
)
def test_series_input_types_to_str(series_type):
def test_series_input_types_to_str():
"""Check that is_row_processor=True uses str as the input type to serialize a row."""
session = mocks.create_bigquery_session()
remote_function_decorator = bff.remote_function(
Expand All @@ -48,7 +35,7 @@ def test_series_input_types_to_str(series_type):
):

@remote_function_decorator
def axis_1_function(myparam: series_type) -> str: # type: ignore
def axis_1_function(myparam: pandas.Series) -> str: # type: ignore
return "Hello, " + myparam["str_col"] + "!" # type: ignore

# Still works as a normal function.
Expand Down