diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 05bbfb6b..bf77c646 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -32,6 +32,8 @@ Fixed `__). - ``default_config_files`` with settings for multiple subcommands not working correctly (`#819 `__). +- ``register_type`` not checking that the given type is a class (`#820 + `__). Changed ^^^^^^^ diff --git a/jsonargparse/typing.py b/jsonargparse/typing.py index de1eb7a9..6b3d9bbe 100644 --- a/jsonargparse/typing.py +++ b/jsonargparse/typing.py @@ -309,7 +309,7 @@ def deserializer(self, value): def register_type( - type_class: Any, + type_class: type, serializer: Callable = str, deserializer: Optional[Callable] = None, deserializer_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = ( @@ -324,7 +324,7 @@ def register_type( """Registers a new type for use in jsonargparse parsers. Args: - type_class: The type object to be registered. + type_class: The class to be registered. serializer: Function that converts an instance of the class to a basic type. deserializer: Function that converts a basic type to an instance of the class. Default instantiates type_class. deserializer_exceptions: Exceptions that deserializer raises when it fails. @@ -332,6 +332,8 @@ def register_type( fail_already_registered: Whether to fail if type has already been registered. uniqueness_key: Key to determine uniqueness of type. """ + if not inspect.isclass(type_class): + raise ValueError(f"Expected type_class to be a class, got {type_class!r}") type_handler = RegisteredType(type_class, serializer, deserializer, deserializer_exceptions, type_check) fail_already_registered = globals().get("_fail_already_registered", fail_already_registered) if not uniqueness_key and fail_already_registered and get_registered_type(type_class): diff --git a/jsonargparse_tests/test_typing.py b/jsonargparse_tests/test_typing.py index e9aee0be..5c375779 100644 --- a/jsonargparse_tests/test_typing.py +++ b/jsonargparse_tests/test_typing.py @@ -376,6 +376,14 @@ def deserializer(v): pytest.raises(ValueError, lambda: register_type(datetime)) # different registration not okay +def test_register_not_a_class_type_failure(): + class SomeClass: + pass + + with pytest.raises(ValueError, match="Expected type_class to be a class"): + register_type(Union[SomeClass, int]) + + class RegisterOnFirstUse: pass