Skip to content

Commit 8795fb0

Browse files
authored
Fix register_type not checking that the given type is a class (#820)
1 parent 5d7ecb4 commit 8795fb0

File tree

3 files changed

+14
-2
lines changed

3 files changed

+14
-2
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ Fixed
3232
<https://github.com/omni-us/jsonargparse/pull/818>`__).
3333
- ``default_config_files`` with settings for multiple subcommands not working
3434
correctly (`#819 <https://github.com/omni-us/jsonargparse/pull/819>`__).
35+
- ``register_type`` not checking that the given type is a class (`#820
36+
<https://github.com/omni-us/jsonargparse/pull/820>`__).
3537

3638
Changed
3739
^^^^^^^

jsonargparse/typing.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def deserializer(self, value):
309309

310310

311311
def register_type(
312-
type_class: Any,
312+
type_class: type,
313313
serializer: Callable = str,
314314
deserializer: Optional[Callable] = None,
315315
deserializer_exceptions: Union[type[Exception], tuple[type[Exception], ...]] = (
@@ -324,14 +324,16 @@ def register_type(
324324
"""Registers a new type for use in jsonargparse parsers.
325325
326326
Args:
327-
type_class: The type object to be registered.
327+
type_class: The class to be registered.
328328
serializer: Function that converts an instance of the class to a basic type.
329329
deserializer: Function that converts a basic type to an instance of the class. Default instantiates type_class.
330330
deserializer_exceptions: Exceptions that deserializer raises when it fails.
331331
type_check: Function to check if a value is of type_class. Gets as arguments the value and type_class.
332332
fail_already_registered: Whether to fail if type has already been registered.
333333
uniqueness_key: Key to determine uniqueness of type.
334334
"""
335+
if not inspect.isclass(type_class):
336+
raise ValueError(f"Expected type_class to be a class, got {type_class!r}")
335337
type_handler = RegisteredType(type_class, serializer, deserializer, deserializer_exceptions, type_check)
336338
fail_already_registered = globals().get("_fail_already_registered", fail_already_registered)
337339
if not uniqueness_key and fail_already_registered and get_registered_type(type_class):

jsonargparse_tests/test_typing.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,14 @@ def deserializer(v):
376376
pytest.raises(ValueError, lambda: register_type(datetime)) # different registration not okay
377377

378378

379+
def test_register_not_a_class_type_failure():
380+
class SomeClass:
381+
pass
382+
383+
with pytest.raises(ValueError, match="Expected type_class to be a class"):
384+
register_type(Union[SomeClass, int])
385+
386+
379387
class RegisterOnFirstUse:
380388
pass
381389

0 commit comments

Comments
 (0)