Skip to content

Commit 26499ff

Browse files
committed
only warm conflict
1 parent 109f486 commit 26499ff

File tree

5 files changed

+62
-40
lines changed

5 files changed

+62
-40
lines changed

bigframes/exceptions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,6 @@ class PreviewWarning(Warning):
4141
"""The feature is in preview."""
4242

4343

44-
class FunctionRedundantTypeHintWarning(UserWarning):
45-
"""Redundant or conflicting type hints in a BigFrames function."""
46-
47-
4844
class NullIndexPreviewWarning(PreviewWarning):
4945
"""Unused. Kept for backwards compatibility.
5046
@@ -107,6 +103,10 @@ class FunctionAxisOnePreviewWarning(PreviewWarning):
107103
"""Remote Function and Managed UDF with axis=1 preview."""
108104

109105

106+
class FunctionConflictTypeHintWarning(UserWarning):
107+
"""Conflicting type hints in a BigFrames function."""
108+
109+
110110
def format_message(message: str, fill: bool = True):
111111
"""Formats a warning message with ANSI color codes for the warning color.
112112

bigframes/functions/_function_session.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -534,25 +534,25 @@ def wrapper(func):
534534
**signature_kwargs,
535535
)
536536
if input_types is not None:
537-
if _utils.has_input_type(py_sig):
538-
msg = bfe.format_message(
539-
"Redundant or conflicting input types detected, using the one from the decorator."
540-
)
541-
warnings.warn(msg, category=bfe.FunctionRedundantTypeHintWarning)
542537
if not isinstance(input_types, collections.abc.Sequence):
543538
input_types = [input_types]
539+
if _utils.has_conflict_input_type(py_sig, input_types):
540+
msg = bfe.format_message(
541+
"Conflicting input types detected, using the one from the decorator."
542+
)
543+
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
544544
py_sig = py_sig.replace(
545545
parameters=[
546546
par.replace(annotation=itype)
547547
for par, itype in zip(py_sig.parameters.values(), input_types)
548548
]
549549
)
550550
if output_type:
551-
if _utils.has_output_type(py_sig):
551+
if _utils.has_conflict_output_type(py_sig, output_type):
552552
msg = bfe.format_message(
553-
"Redundant or conflicting return type detected, using the one from the decorator."
553+
"Conflicting return type detected, using the one from the decorator."
554554
)
555-
warnings.warn(msg, category=bfe.FunctionRedundantTypeHintWarning)
555+
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
556556
py_sig = py_sig.replace(return_annotation=output_type)
557557

558558
# Try to get input types via type annotations.
@@ -846,25 +846,25 @@ def wrapper(func):
846846
**signature_kwargs,
847847
)
848848
if input_types is not None:
849-
if _utils.has_input_type(py_sig):
850-
msg = bfe.format_message(
851-
"Redundant or conflicting input types detected, using the one from the decorator."
852-
)
853-
warnings.warn(msg, category=bfe.FunctionRedundantTypeHintWarning)
854849
if not isinstance(input_types, collections.abc.Sequence):
855850
input_types = [input_types]
851+
if _utils.has_conflict_input_type(py_sig, input_types):
852+
msg = bfe.format_message(
853+
"Conflicting input types detected, using the one from the decorator."
854+
)
855+
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
856856
py_sig = py_sig.replace(
857857
parameters=[
858858
par.replace(annotation=itype)
859859
for par, itype in zip(py_sig.parameters.values(), input_types)
860860
]
861861
)
862862
if output_type:
863-
if _utils.has_output_type(py_sig):
863+
if _utils.has_conflict_output_type(py_sig, output_type):
864864
msg = bfe.format_message(
865-
"Redundant or conflicting return type detected, using the one from the decorator."
865+
"Conflicting return type detected, using the one from the decorator."
866866
)
867-
warnings.warn(msg, category=bfe.FunctionRedundantTypeHintWarning)
867+
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
868868
py_sig = py_sig.replace(return_annotation=output_type)
869869

870870
udf_sig = udf_def.UdfSignature.from_py_signature(py_sig)

bigframes/functions/_utils.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import json
1919
import sys
2020
import typing
21-
from typing import cast, Optional, Set
21+
from typing import Any, cast, Optional, Sequence, Set
2222

2323
import cloudpickle
2424
import google.api_core.exceptions
@@ -272,14 +272,34 @@ def post_process(input):
272272
return post_process
273273

274274

275-
def has_input_type(signature: inspect.Signature) -> bool:
276-
"""Checks if any parameter in the signature has a type annotation."""
277-
for param in signature.parameters.values():
275+
def has_conflict_input_type(
276+
signature: inspect.Signature,
277+
input_types: Sequence[Any],
278+
) -> bool:
279+
"""Checks if the parameters have any conflict with the input_types."""
280+
params = list(signature.parameters.values())
281+
282+
if len(params) != len(input_types):
283+
return True
284+
285+
# Check for conflicts type hints.
286+
for i, param in enumerate(params):
278287
if param.annotation is not inspect.Parameter.empty:
279-
return True
288+
if param.annotation != input_types[i]:
289+
return True
290+
291+
# No conflicts were found after checking all parameters.
280292
return False
281293

282294

283-
def has_output_type(signature: inspect.Signature) -> bool:
284-
"""Checks if the signature has a return type annotation."""
285-
return signature.return_annotation is not inspect.Parameter.empty
295+
def has_conflict_output_type(
296+
signature: inspect.Signature,
297+
output_type: Any,
298+
) -> bool:
299+
"""Checks if the return type annotation conflicts with the output_type."""
300+
return_annotation = signature.return_annotation
301+
302+
if return_annotation is inspect.Parameter.empty:
303+
return False
304+
305+
return return_annotation != output_type

tests/system/large/functions/test_managed_function.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,9 @@ def test_managed_function_array_output(session, scalars_dfs, dataset_id):
4242
def featurize(x: int) -> list[float]:
4343
return [float(i) for i in [x, x + 1, x + 2]]
4444

45-
input_type_warning = "Redundant or conflicting input types detected."
46-
return_type_warning = "Redundant or conflicting return type detected"
45+
# No following conflict warning when there is no redundant type hints.
46+
input_type_warning = "Conflicting input types detected"
47+
return_type_warning = "Conflicting return type detected"
4748
assert not any(input_type_warning in str(warning.message) for warning in record)
4849
assert not any(
4950
return_type_warning in str(warning.message) for warning in record
@@ -233,7 +234,7 @@ def add(x: int, y: int) -> int:
233234
def test_managed_function_series_combine_array_output(session, dataset_id, scalars_dfs):
234235
try:
235236

236-
# The type hints in this function's signature are redundant. The
237+
# The type hints in this function's signature has conflicts. The
237238
# `input_types` and `output_type` arguments from udf decorator take
238239
# precedence and will be used instead.
239240
def add_list(x, y: bool) -> list[bool]:
@@ -256,9 +257,9 @@ def add_list(x, y: bool) -> list[bool]:
256257
name=prefixer.create_prefix(),
257258
)(add_list)
258259

259-
input_type_warning = "Redundant or conflicting input types detected"
260+
input_type_warning = "Conflicting input types detected"
260261
assert any(input_type_warning in str(warning.message) for warning in record)
261-
return_type_warning = "Redundant or conflicting return type detected"
262+
return_type_warning = "Conflicting return type detected"
262263
assert any(return_type_warning in str(warning.message) for warning in record)
263264

264265
# After filtering out nulls the managed function application should work

tests/system/large/functions/test_remote_function.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -843,7 +843,7 @@ def test_remote_function_with_external_package_dependencies(
843843
):
844844
try:
845845

846-
# The return type hint in this function's signature is redundant. The
846+
# The return type hint in this function's signature has conflict. The
847847
# `output_type` argument from remote_function decorator takes precedence
848848
# and will be used instead.
849849
def pd_np_foo(x) -> None:
@@ -864,9 +864,9 @@ def pd_np_foo(x) -> None:
864864
cloud_function_service_account="default",
865865
)(pd_np_foo)
866866

867-
input_type_warning = "Redundant or conflicting input types detected"
867+
input_type_warning = "Conflicting input types detected"
868868
assert not any(input_type_warning in str(warning.message) for warning in record)
869-
return_type_warning = "Redundant or conflicting return type detected"
869+
return_type_warning = "Conflicting return type detected"
870870
assert any(return_type_warning in str(warning.message) for warning in record)
871871

872872
# The behavior of the created remote function should be as expected
@@ -2022,10 +2022,11 @@ def test_remote_function_unnamed_removed_w_session_cleanup():
20222022
def foo(x: int) -> int:
20232023
return x + 1
20242024

2025-
input_type_warning = "Redundant or conflicting input types detected"
2026-
assert any(input_type_warning in str(warning.message) for warning in record)
2027-
return_type_warning = "Redundant or conflicting return type detected"
2028-
assert any(return_type_warning in str(warning.message) for warning in record)
2025+
# No following warning with only redundant type hints (no conflict).
2026+
input_type_warning = "Conflicting input types detected"
2027+
assert not any(input_type_warning in str(warning.message) for warning in record)
2028+
return_type_warning = "Conflicting return type detected"
2029+
assert not any(return_type_warning in str(warning.message) for warning in record)
20292030

20302031
# ensure that remote function artifacts are created
20312032
assert foo.bigframes_remote_function is not None

0 commit comments

Comments
 (0)