Skip to content

Commit 109f486

Browse files
committed
fix: add warnings for duplicated or conflicting type hints in bigframes function
1 parent 770918e commit 109f486

File tree

5 files changed

+108
-24
lines changed

5 files changed

+108
-24
lines changed

bigframes/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ class PreviewWarning(Warning):
4141
"""The feature is in preview."""
4242

4343

44+
class FunctionRedundantTypeHintWarning(UserWarning):
45+
"""Redundant or conflicting type hints in a BigFrames function."""
46+
47+
4448
class NullIndexPreviewWarning(PreviewWarning):
4549
"""Unused. Kept for backwards compatibility.
4650

bigframes/functions/_function_session.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,11 @@ def wrapper(func):
534534
**signature_kwargs,
535535
)
536536
if input_types is not None:
537+
if _utils.has_input_type(py_sig):
538+
msg = bfe.format_message(
539+
"Redundant or conflicting input types detected, using the one from the decorator."
540+
)
541+
warnings.warn(msg, category=bfe.FunctionRedundantTypeHintWarning)
537542
if not isinstance(input_types, collections.abc.Sequence):
538543
input_types = [input_types]
539544
py_sig = py_sig.replace(
@@ -543,6 +548,11 @@ def wrapper(func):
543548
]
544549
)
545550
if output_type:
551+
if _utils.has_output_type(py_sig):
552+
msg = bfe.format_message(
553+
"Redundant or conflicting return type detected, using the one from the decorator."
554+
)
555+
warnings.warn(msg, category=bfe.FunctionRedundantTypeHintWarning)
546556
py_sig = py_sig.replace(return_annotation=output_type)
547557

548558
# Try to get input types via type annotations.
@@ -836,6 +846,11 @@ def wrapper(func):
836846
**signature_kwargs,
837847
)
838848
if input_types is not None:
849+
if _utils.has_input_type(py_sig):
850+
msg = bfe.format_message(
851+
"Redundant or conflicting input types detected, using the one from the decorator."
852+
)
853+
warnings.warn(msg, category=bfe.FunctionRedundantTypeHintWarning)
839854
if not isinstance(input_types, collections.abc.Sequence):
840855
input_types = [input_types]
841856
py_sig = py_sig.replace(
@@ -845,6 +860,11 @@ def wrapper(func):
845860
]
846861
)
847862
if output_type:
863+
if _utils.has_output_type(py_sig):
864+
msg = bfe.format_message(
865+
"Redundant or conflicting return type detected, using the one from the decorator."
866+
)
867+
warnings.warn(msg, category=bfe.FunctionRedundantTypeHintWarning)
848868
py_sig = py_sig.replace(return_annotation=output_type)
849869

850870
udf_sig = udf_def.UdfSignature.from_py_signature(py_sig)

bigframes/functions/_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
import hashlib
17+
import inspect
1718
import json
1819
import sys
1920
import typing
@@ -269,3 +270,16 @@ def post_process(input):
269270
return bbq.json_extract_string_array(input, value_dtype=result_dtype)
270271

271272
return post_process
273+
274+
275+
def has_input_type(signature: inspect.Signature) -> bool:
276+
"""Checks if any parameter in the signature has a type annotation."""
277+
for param in signature.parameters.values():
278+
if param.annotation is not inspect.Parameter.empty:
279+
return True
280+
return False
281+
282+
283+
def has_output_type(signature: inspect.Signature) -> bool:
284+
"""Checks if the signature has a return type annotation."""
285+
return signature.return_annotation is not inspect.Parameter.empty

tests/system/large/functions/test_managed_function.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import warnings
16+
1517
import google.api_core.exceptions
1618
import pandas
1719
import pyarrow
@@ -31,12 +33,21 @@
3133
def test_managed_function_array_output(session, scalars_dfs, dataset_id):
3234
try:
3335

34-
@session.udf(
35-
dataset=dataset_id,
36-
name=prefixer.create_prefix(),
36+
with warnings.catch_warnings(record=True) as record:
37+
38+
@session.udf(
39+
dataset=dataset_id,
40+
name=prefixer.create_prefix(),
41+
)
42+
def featurize(x: int) -> list[float]:
43+
return [float(i) for i in [x, x + 1, x + 2]]
44+
45+
input_type_warning = "Redundant or conflicting input types detected."
46+
return_type_warning = "Redundant or conflicting return type detected"
47+
assert not any(input_type_warning in str(warning.message) for warning in record)
48+
assert not any(
49+
return_type_warning in str(warning.message) for warning in record
3750
)
38-
def featurize(x: int) -> list[float]:
39-
return [float(i) for i in [x, x + 1, x + 2]]
4051

4152
scalars_df, scalars_pandas_df = scalars_dfs
4253

@@ -222,7 +233,10 @@ def add(x: int, y: int) -> int:
222233
def test_managed_function_series_combine_array_output(session, dataset_id, scalars_dfs):
223234
try:
224235

225-
def add_list(x: int, y: int) -> list[int]:
236+
# The type hints in this function's signature are redundant. The
237+
# `input_types` and `output_type` arguments from udf decorator take
238+
# precedence and will be used instead.
239+
def add_list(x, y: bool) -> list[bool]:
226240
return [x, y]
227241

228242
scalars_df, scalars_pandas_df = scalars_dfs
@@ -234,9 +248,18 @@ def add_list(x: int, y: int) -> list[int]:
234248
# Make sure there are NA values in the test column.
235249
assert any([pandas.isna(val) for val in bf_df[int_col_name_with_nulls]])
236250

237-
add_list_managed_func = session.udf(
238-
dataset=dataset_id, name=prefixer.create_prefix()
239-
)(add_list)
251+
with warnings.catch_warnings(record=True) as record:
252+
add_list_managed_func = session.udf(
253+
input_types=[int, int],
254+
output_type=list[int],
255+
dataset=dataset_id,
256+
name=prefixer.create_prefix(),
257+
)(add_list)
258+
259+
input_type_warning = "Redundant or conflicting input types detected"
260+
assert any(input_type_warning in str(warning.message) for warning in record)
261+
return_type_warning = "Redundant or conflicting return type detected"
262+
assert any(return_type_warning in str(warning.message) for warning in record)
240263

241264
# After filtering out nulls the managed function application should work
242265
# similar to pandas.

tests/system/large/functions/test_remote_function.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -843,22 +843,31 @@ def test_remote_function_with_external_package_dependencies(
843843
):
844844
try:
845845

846-
def pd_np_foo(x):
846+
# The return type hint in this function's signature is redundant. The
847+
# `output_type` argument from remote_function decorator takes precedence
848+
# and will be used instead.
849+
def pd_np_foo(x) -> None:
847850
import numpy as mynp
848851
import pandas as mypd
849852

850853
return mypd.Series([x, mynp.sqrt(mynp.abs(x))]).sum()
851854

852-
# Create the remote function with the name provided explicitly
853-
pd_np_foo_remote = session.remote_function(
854-
input_types=[int],
855-
output_type=float,
856-
dataset=dataset_id,
857-
bigquery_connection=bq_cf_connection,
858-
reuse=False,
859-
packages=["numpy", "pandas >= 2.0.0"],
860-
cloud_function_service_account="default",
861-
)(pd_np_foo)
855+
with warnings.catch_warnings(record=True) as record:
856+
# Create the remote function with the name provided explicitly
857+
pd_np_foo_remote = session.remote_function(
858+
input_types=[int],
859+
output_type=float,
860+
dataset=dataset_id,
861+
bigquery_connection=bq_cf_connection,
862+
reuse=False,
863+
packages=["numpy", "pandas >= 2.0.0"],
864+
cloud_function_service_account="default",
865+
)(pd_np_foo)
866+
867+
input_type_warning = "Redundant or conflicting input types detected"
868+
assert not any(input_type_warning in str(warning.message) for warning in record)
869+
return_type_warning = "Redundant or conflicting return type detected"
870+
assert any(return_type_warning in str(warning.message) for warning in record)
862871

863872
# The behavior of the created remote function should be as expected
864873
scalars_df, scalars_pandas_df = scalars_dfs
@@ -1999,10 +2008,24 @@ def test_remote_function_unnamed_removed_w_session_cleanup():
19992008
# create a clean session
20002009
session = bigframes.connect()
20012010

2002-
# create an unnamed remote function in the session
2003-
@session.remote_function(reuse=False, cloud_function_service_account="default")
2004-
def foo(x: int) -> int:
2005-
return x + 1
2011+
with warnings.catch_warnings(record=True) as record:
2012+
# create an unnamed remote function in the session.
2013+
# The type hints in this function's signature are redundant. The
2014+
# `input_types` and `output_type` arguments from remote_function
2015+
# decorator take precedence and will be used instead.
2016+
@session.remote_function(
2017+
input_types=[int],
2018+
output_type=int,
2019+
reuse=False,
2020+
cloud_function_service_account="default",
2021+
)
2022+
def foo(x: int) -> int:
2023+
return x + 1
2024+
2025+
input_type_warning = "Redundant or conflicting input types detected"
2026+
assert any(input_type_warning in str(warning.message) for warning in record)
2027+
return_type_warning = "Redundant or conflicting return type detected"
2028+
assert any(return_type_warning in str(warning.message) for warning in record)
20062029

20072030
# ensure that remote function artifacts are created
20082031
assert foo.bigframes_remote_function is not None

0 commit comments

Comments
 (0)