diff --git a/tests/unit/functions/test_remote_function_utils.py b/tests/unit/functions/test_remote_function_utils.py index 0e4ca7a2ac..687c599985 100644 --- a/tests/unit/functions/test_remote_function_utils.py +++ b/tests/unit/functions/test_remote_function_utils.py @@ -217,13 +217,62 @@ def test_package_existed_helper(): assert not _utils._package_existed([], "pandas") +# Helper functions for signature inspection tests +def _func_one_arg_annotated(x: int) -> int: + """A function with one annotated arg and an annotated return type.""" + return x + + +def _func_one_arg_unannotated(x): + """A function with one unannotated arg and no return type annotation.""" + return x + + +def _func_two_args_annotated(x: int, y: str): + """A function with two annotated args and no return type annotation.""" + return f"{x}{y}" + + +def _func_two_args_unannotated(x, y): + """A function with two unannotated args and no return type annotation.""" + return f"{x}{y}" + + +def test_has_conflict_input_type_too_few_inputs(): + """Tests conflict when there are fewer input types than parameters.""" + signature = inspect.signature(_func_one_arg_annotated) + assert _utils.has_conflict_input_type(signature, input_types=[]) + + +def test_has_conflict_input_type_too_many_inputs(): + """Tests conflict when there are more input types than parameters.""" + signature = inspect.signature(_func_one_arg_annotated) + assert _utils.has_conflict_input_type(signature, input_types=[int, str]) + + +def test_has_conflict_input_type_type_mismatch(): + """Tests has_conflict_input_type with a conflicting type annotation.""" + signature = inspect.signature(_func_two_args_annotated) + + # The second type (bool) conflicts with the annotation (str). + assert _utils.has_conflict_input_type(signature, input_types=[int, bool]) + + +def test_has_conflict_input_type_no_conflict_annotated(): + """Tests that a matching, annotated signature is compatible.""" + signature = inspect.signature(_func_two_args_annotated) + assert not _utils.has_conflict_input_type(signature, input_types=[int, str]) + + +def test_has_conflict_input_type_no_conflict_unannotated(): + """Tests that a signature with no annotations is always compatible.""" + signature = inspect.signature(_func_two_args_unannotated) + assert not _utils.has_conflict_input_type(signature, input_types=[int, float]) + + def test_has_conflict_output_type_no_conflict(): """Tests has_conflict_output_type with type annotation.""" - # Helper functions with type annotation for has_conflict_output_type. - def _func_with_return_type(x: int) -> int: - return x - - signature = inspect.signature(_func_with_return_type) + signature = inspect.signature(_func_one_arg_annotated) assert _utils.has_conflict_output_type(signature, output_type=float) assert not _utils.has_conflict_output_type(signature, output_type=int) @@ -231,11 +280,7 @@ def _func_with_return_type(x: int) -> int: def test_has_conflict_output_type_no_annotation(): """Tests has_conflict_output_type without type annotation.""" - # Helper functions without type annotation for has_conflict_output_type. - def _func_without_return_type(x): - return x - - signature = inspect.signature(_func_without_return_type) + signature = inspect.signature(_func_one_arg_unannotated) assert not _utils.has_conflict_output_type(signature, output_type=int) assert not _utils.has_conflict_output_type(signature, output_type=float)