From 1502994ac1caa3f6e956bf76f18abc41b9b9df82 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Fri, 15 Aug 2025 14:09:06 +0200 Subject: [PATCH 1/2] Add support callable protocols for instance factory dependency injection (#755). --- CHANGELOG.rst | 2 + jsonargparse/_actions.py | 24 +++---- jsonargparse/_signatures.py | 2 +- jsonargparse/_typehints.py | 98 ++++++++++++++++++---------- jsonargparse_tests/test_typehints.py | 67 +++++++++++++++++++ 5 files changed, 147 insertions(+), 46 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index f0271368..66e58753 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -19,6 +19,8 @@ Added ^^^^^ - Support for Python 3.14 (`#753 `__). +- Support callable protocols for instance factory dependency injection (`#758 + `__). Fixed ^^^^^ diff --git a/jsonargparse/_actions.py b/jsonargparse/_actions.py index 087b500d..4982975a 100644 --- a/jsonargparse/_actions.py +++ b/jsonargparse/_actions.py @@ -350,7 +350,6 @@ def __init__(self, typehint=None, **kwargs): if typehint is not None: self._typehint = typehint else: - self._typehint = kwargs.pop("_typehint") self.update_init_kwargs(kwargs) super().__init__(**kwargs) @@ -363,13 +362,14 @@ def update_init_kwargs(self, kwargs): is_protocol, ) - typehint = get_unaliased_type(get_optional_arg(self._typehint)) + typehint = get_unaliased_type(get_optional_arg(kwargs.pop("_typehint"))) if get_typehint_origin(typehint) is not Union: assert "nargs" not in kwargs kwargs["nargs"] = "?" 
- self._basename = iter_to_set_str(get_subclass_names(self._typehint, callable_return=True)) + self._typehint = typehint + self._basename = iter_to_set_str(get_subclass_names(typehint, callable_return=True)) self._baseclasses = get_subclass_types(typehint, callable_return=True) - assert self._baseclasses + assert self._baseclasses and all(isinstance(b, type) for b in self._baseclasses) self._kind = "subclass of" if any(is_protocol(b) for b in self._baseclasses): @@ -391,28 +391,28 @@ def __call__(self, *args, **kwargs): def print_help(self, call_args): from ._typehints import ( - ActionTypeHint, - get_optional_arg, - get_unaliased_type, + adapt_partial_callable_class, implements_protocol, resolve_class_path_by_name, ) parser, _, value, option_string = call_args try: - typehint = get_unaliased_type(get_optional_arg(self._typehint)) if self.nargs == "?" and value is None: - val_class = typehint + val_class = self._typehint else: - val_class = import_object(resolve_class_path_by_name(typehint, value)) + val_class = import_object(resolve_class_path_by_name(self._baseclasses, value)) except Exception as ex: raise TypeError(f"{option_string}: {ex}") from ex + if not any(is_subclass(val_class, b) or implements_protocol(val_class, b) for b in self._baseclasses): raise TypeError(f'{option_string}: Class "{value}" is not a {self._kind} {self._basename}') dest = re.sub("\\.help$", "", self.dest) subparser = type(parser)(description=f"Help for {option_string}={get_import_path(val_class)}") - if ActionTypeHint.is_callable_typehint(typehint) and hasattr(typehint, "__args__"): - self.sub_add_kwargs["skip"] = {max(0, len(typehint.__args__) - 1)} + val = Namespace(class_path=get_import_path(val_class)) + _, partial_skip_args = adapt_partial_callable_class(self._typehint, val) + if partial_skip_args: + self.sub_add_kwargs["skip"] = partial_skip_args subparser.add_class_arguments(val_class, dest, **self.sub_add_kwargs) subparser._inner_parser = True remove_actions(subparser, 
(_HelpAction, _ActionPrintConfig, _ActionConfigLoad)) diff --git a/jsonargparse/_signatures.py b/jsonargparse/_signatures.py index 2ab26c57..3af652d8 100644 --- a/jsonargparse/_signatures.py +++ b/jsonargparse/_signatures.py @@ -267,7 +267,7 @@ def _add_signature_arguments( """ params = get_signature_parameters(function_or_class, method_name, logger=self.logger) - skip_positionals = [s for s in (skip or []) if isinstance(s, int)] + skip_positionals = [s for s in (skip or []) if isinstance(s, int) and s != 0] if skip_positionals: if len(skip_positionals) > 1 or any(p <= 0 for p in skip_positionals): raise ValueError(f"Unexpected number of positionals to skip: {skip_positionals}") diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index 24d6c717..0f83a333 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -271,14 +271,14 @@ def normalize_default(self, default): default = default.name elif is_callable_type(self._typehint) and callable(default) and not inspect.isclass(default): default = get_import_path(default) + elif ActionTypeHint.is_return_subclass_typehint(self._typehint) and inspect.isclass(default): + default = {"class_path": get_import_path(default)} elif is_subclass_type and not allow_default_instance.get(): from ._parameter_resolvers import UnknownDefault default_type = type(default) if not is_subclass(default_type, UnknownDefault) and self.is_subclass_typehint(default_type): raise ValueError("Subclass types require as default either a dict with class_path or a lazy instance.") - elif ActionTypeHint.is_return_subclass_typehint(self._typehint) and inspect.isclass(default): - default = {"class_path": get_import_path(default)} return default @staticmethod @@ -352,7 +352,7 @@ def is_subclass_typehint(typehint, all_subtypes=True, also_lists=False): def is_return_subclass_typehint(typehint): typehint = get_unaliased_type(get_optional_arg(get_unaliased_type(typehint))) typehint_origin = get_typehint_origin(typehint) - if 
typehint_origin in callable_origin_types: + if typehint_origin in callable_origin_types or is_instance_factory_protocol(typehint): return_type = get_callable_return_type(typehint) if ActionTypeHint.is_subclass_typehint(return_type): return True @@ -626,15 +626,18 @@ def instantiate_classes(self, value): return value if islist else value[0] @staticmethod - def get_class_parser(val_class, sub_add_kwargs=None, skip_args=0): + def get_class_parser(val_class, sub_add_kwargs=None, skip_args=None): if isinstance(val_class, str): val_class = import_object(val_class) kwargs = dict(sub_add_kwargs) if sub_add_kwargs else {} if skip_args: - kwargs.setdefault("skip", set()).add(skip_args) + kwargs.setdefault("skip", set()).update(skip_args) if is_subclass_spec(kwargs.get("default")): kwargs["default"] = kwargs["default"].get("init_args") parser = parent_parser.get() + from ._core import ArgumentParser + + assert isinstance(parser, ArgumentParser) parser = type(parser)(exit_on_error=False, logger=parser.logger, parser_mode=parser.parser_mode) remove_actions(parser, (ActionConfigFile, _ActionPrintConfig)) if inspect.isclass(val_class) or inspect.isclass(get_typehint_origin(val_class)): @@ -658,11 +661,10 @@ def get_class_parser(val_class, sub_add_kwargs=None, skip_args=0): def extra_help(self): extra = "" typehint = get_optional_arg(self._typehint) - if self.is_subclass_typehint(typehint, all_subtypes=False) or get_typehint_origin( - typehint - ) in callable_origin_types.union({Type, type}): - if self.is_callable_typehint(typehint) and getattr(typehint, "__args__", None): - typehint = get_callable_return_type(get_optional_arg(typehint)) + typehint = get_callable_return_type(typehint) or typehint + if get_typehint_origin(typehint) is type: + typehint = typehint.__args__[0] + if self.is_subclass_typehint(typehint, all_subtypes=False): class_paths = get_all_subclass_paths(typehint) if class_paths: extra = ", known subclasses: " + ", ".join(class_paths) @@ -967,11 +969,15 @@ def 
adapt_typehints( val = adapt_typehints(val, subtypehints[0], **adapt_kwargs) # Callable - elif typehint_origin in callable_origin_types or typehint in callable_origin_types: + elif ( + typehint_origin in callable_origin_types + or typehint in callable_origin_types + or is_instance_factory_protocol(typehint, logger) + ): if serialize: if is_subclass_spec(val): - val, _, num_partial_args = adapt_partial_callable_class(typehint, val) - val = adapt_class_type(val, True, False, sub_add_kwargs, skip_args=num_partial_args) + val, partial_skip_args = adapt_partial_callable_class(typehint, val) + val = adapt_class_type(val, True, False, sub_add_kwargs, partial_skip_args=partial_skip_args) else: val = object_path_serializer(val) else: @@ -1000,12 +1006,13 @@ def adapt_typehints( raise ImportError( f"Dict must include a class_path and optionally init_args, but got {val_input}" ) - val, partial_classes, num_partial_args = adapt_partial_callable_class(typehint, val) + val, partial_skip_args = adapt_partial_callable_class(typehint, val) val_class = import_object(val["class_path"]) - if inspect.isclass(val_class) and not (partial_classes or callable_instances(val_class)): + if inspect.isclass(val_class) and not (partial_skip_args or callable_instances(val_class)): + base_type = get_callable_return_type(typehint) or typehint raise ImportError( f"Expected '{val['class_path']}' to be a class that instantiates into callable " - f"or a subclass of {partial_classes}." + f"or a subclass of {base_type}." 
) val["class_path"] = get_import_path(val_class) val = adapt_class_type( @@ -1013,8 +1020,7 @@ def adapt_typehints( False, instantiate_classes, sub_add_kwargs, - skip_args=num_partial_args, - partial_classes=partial_classes, + partial_skip_args=partial_skip_args, prev_val=prev_val, ) except (ImportError, AttributeError, ArgumentError) as ex: @@ -1172,6 +1178,15 @@ def is_instance_or_supports_protocol(value, class_type): return isinstance(value, class_type) +def is_instance_factory_protocol(class_type, logger=None): + if not is_protocol(class_type) or not callable_instances(class_type): + return False + from ._postponed_annotations import get_return_type + + return_type = get_return_type(class_type.__call__, logger) + return ActionTypeHint.is_subclass_typehint(return_type) + + def is_subclass_spec(val): is_class = isinstance(val, (dict, Namespace)) and "class_path" in val if is_class: @@ -1214,9 +1229,14 @@ def subclass_spec_as_namespace(val, prev_val=None): def get_callable_return_type(typehint): return_type = None - args = getattr(typehint, "__args__", None) - if isinstance(args, tuple) and len(args) > 0: - return_type = args[-1] + if is_instance_factory_protocol(typehint): + from ._postponed_annotations import get_return_type + + return_type = get_return_type(typehint.__call__) + elif get_typehint_origin(typehint) in callable_origin_types: + args = getattr(typehint, "__args__", None) + if isinstance(args, tuple) and len(args) > 0: + return_type = args[-1] return return_type @@ -1238,7 +1258,7 @@ def yield_subclass_types(typehint, also_lists=False, callable_return=False): return typehint = get_unaliased_type(get_optional_arg(get_unaliased_type(typehint))) typehint_origin = get_typehint_origin(typehint) - if callable_return and typehint_origin in callable_origin_types: + if callable_return and (typehint_origin in callable_origin_types or is_instance_factory_protocol(typehint)): return_type = get_callable_return_type(typehint) if return_type: k = {"also_lists": 
also_lists, "callable_return": callable_return} @@ -1261,8 +1281,7 @@ def get_subclass_names(typehint, callable_return=False): def adapt_partial_callable_class(callable_type, subclass_spec): - partial_classes = False - num_partial_args = 0 + partial_skip_args = None return_type = get_callable_return_type(callable_type) if return_type: subclass_types = get_subclass_types(return_type) @@ -1270,9 +1289,18 @@ def adapt_partial_callable_class(callable_type, subclass_spec): if subclass_types and is_subclass(class_type, subclass_types): subclass_spec = subclass_spec.clone() subclass_spec["class_path"] = get_import_path(class_type) - partial_classes = True - num_partial_args = len(callable_type.__args__) - 1 - return subclass_spec, partial_classes, num_partial_args + if is_protocol(callable_type): + from ._parameter_resolvers import get_signature_parameters + + params = get_signature_parameters(callable_type, "__call__") + partial_skip_args = set() + positionals = [p for p in params if "POSITIONAL_ONLY" in str(p.kind)] + if positionals: + partial_skip_args.add(len(positionals)) + partial_skip_args.update(p.name for p in params if "POSITIONAL_ONLY" not in str(p.kind)) + else: + partial_skip_args = {len(callable_type.__args__) - 1} + return subclass_spec, partial_skip_args def get_all_subclass_paths(cls: Type) -> List[str]: @@ -1318,9 +1346,15 @@ def add_subclasses(cl): return subclass_list -def resolve_class_path_by_name(cls: Type, name: str) -> str: +def resolve_class_path_by_name(cls: Union[Type, Tuple[Type]], name: str) -> str: class_path = name if "." not in class_path: + if isinstance(cls, tuple): + for cls_n in cls: + class_path = resolve_class_path_by_name(cls_n, name) + if "." 
in class_path: + break + return class_path subclass_dict = defaultdict(list) for subclass in get_all_subclass_paths(cls): subclass_name = subclass.rsplit(".", 1)[1] @@ -1376,13 +1410,11 @@ def discard_init_args_on_class_path_change(parser_or_action, prev_val, value): ) -def adapt_class_type( - value, serialize, instantiate_classes, sub_add_kwargs, prev_val=None, skip_args=0, partial_classes=False -): +def adapt_class_type(value, serialize, instantiate_classes, sub_add_kwargs, prev_val=None, partial_skip_args=None): prev_val = subclass_spec_as_namespace(prev_val) value = subclass_spec_as_namespace(value) val_class = import_object(value.class_path) - parser = ActionTypeHint.get_class_parser(val_class, sub_add_kwargs, skip_args=skip_args) + parser = ActionTypeHint.get_class_parser(val_class, sub_add_kwargs, skip_args=partial_skip_args) # No need to re-create the linked arg but just "inform" the corresponding parser actions that it exists upstream. for target in sub_add_kwargs.get("linked_targets", []): @@ -1415,7 +1447,7 @@ def adapt_class_type( instantiator_fn = get_class_instantiator() - if partial_classes: + if partial_skip_args: return partial( instantiator_fn, val_class, diff --git a/jsonargparse_tests/test_typehints.py b/jsonargparse_tests/test_typehints.py index a9f87c34..1d07daab 100644 --- a/jsonargparse_tests/test_typehints.py +++ b/jsonargparse_tests/test_typehints.py @@ -21,6 +21,7 @@ Literal, Mapping, Optional, + Protocol, Sequence, Set, Tuple, @@ -1058,6 +1059,72 @@ def test_callable_args_return_type_class(parser, subtests): assert "--optimizer.params" not in help_str +class OptimizerFactory(Protocol): + def __call__(self, params: List[float]) -> Optimizer: ... 
+ + +class DifferentParamsOrder(Optimizer): + def __init__(self, lr: float, params: List[float], momentum: float = 0.0): + super().__init__(lr=lr, params=params, momentum=momentum) + + +def test_callable_protocol_instance_factory(parser, subtests): + parser.add_argument("--optimizer", type=OptimizerFactory, default=SGD) + + with subtests.test("default"): + cfg = parser.get_defaults() + assert cfg.optimizer.class_path == f"{__name__}.SGD" + init = parser.instantiate_classes(cfg) + optimizer = init.optimizer(params=[1, 2]) + assert isinstance(optimizer, SGD) + assert optimizer.params == [1, 2] + assert optimizer.lr == 1e-3 + assert optimizer.momentum == 0.0 + + with subtests.test("parse dict"): + value = { + "class_path": "Adam", + "init_args": { + "lr": 0.01, + "momentum": 0.9, + }, + } + cfg = parser.parse_args([f"--optimizer={json.dumps(value)}"]) + init = parser.instantiate_classes(cfg) + optimizer = init.optimizer(params=[3, 2, 1]) + assert isinstance(optimizer, Adam) + assert optimizer.params == [3, 2, 1] + assert optimizer.lr == 0.01 + assert optimizer.momentum == 0.9 + + with subtests.test("params order"): + value = { + "class_path": "DifferentParamsOrder", + "init_args": { + "lr": 0.1, + "momentum": 0.8, + }, + } + cfg = parser.parse_args([f"--optimizer={json.dumps(value)}"]) + init = parser.instantiate_classes(cfg) + optimizer = init.optimizer(params=[3, 2]) + assert isinstance(optimizer, DifferentParamsOrder) + assert optimizer.params == [3, 2] + assert optimizer.lr == 0.1 + assert optimizer.momentum == 0.8 + dump = parser.dump(cfg) + assert json_or_yaml_load(dump) == cfg.as_dict() + + with subtests.test("help"): + help_str = get_parser_help(parser) + assert "--optimizer.help" in help_str + assert "Show the help for the given subclass or implementer of protocol {Optimizer,OptimizerFactory" in help_str + help_str = get_parse_args_stdout(parser, [f"--optimizer.help={__name__}.DifferentParamsOrder"]) + assert f"Help for 
--optimizer.help={__name__}.DifferentParamsOrder" in help_str + assert "--optimizer.lr" in help_str + assert "--optimizer.params" not in help_str + + def test_optional_callable_return_type_help(parser): parser.add_argument("--optimizer", type=Optional[Callable[[List[float]], Optimizer]]) help_str = get_parser_help(parser) From ae39652fc27c4ba2a56ac8cb5054fd50b02d92bf Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Mon, 18 Aug 2025 07:16:51 +0200 Subject: [PATCH 2/2] Documentation and a test for positional --- DOCUMENTATION.rst | 330 +++++++++++++++++---------- jsonargparse/_typehints.py | 2 - jsonargparse_tests/test_typehints.py | 22 ++ sphinx/conf.py | 2 +- 4 files changed, 234 insertions(+), 122 deletions(-) diff --git a/DOCUMENTATION.rst b/DOCUMENTATION.rst index c922ffd5..1980903c 100644 --- a/DOCUMENTATION.rst +++ b/DOCUMENTATION.rst @@ -1035,96 +1035,9 @@ A second option is a class that once instantiated becomes callable: >>> init.callable(5) 8 -The third option is only applicable when the type is a callable that has a class -as return type or a ``Union`` including a class. This is useful to support -dependency injection for classes that require a parameter that is only available -after injection. The parser supports this automatically by providing a function -that receives this parameter and returns the instance of the class. Take for -example the classes: - -.. testcode:: callable - - class Optimizer: - def __init__(self, params: Iterable): - self.params = params - - - class SGD(Optimizer): - def __init__(self, params: Iterable, lr: float): - super().__init__(params) - self.lr = lr - -.. testcode:: callable - :hide: - - doctest_mock_class_in_main(SGD) - -A possible parser and callable behavior would be: - -.. doctest:: callable - - >>> value = { - ... "class_path": "SGD", - ... "init_args": { - ... "lr": 0.01, - ... }, - ... 
} - - >>> parser.add_argument("--optimizer", type=Callable[[Iterable], Optimizer]) # doctest: +IGNORE_RESULT - >>> cfg = parser.parse_args(["--optimizer", str(value)]) - >>> cfg.optimizer - Namespace(class_path='__main__.SGD', init_args=Namespace(lr=0.01)) - >>> init = parser.instantiate_classes(cfg) - >>> optimizer = init.optimizer([1, 2, 3]) - >>> isinstance(optimizer, SGD) - True - >>> optimizer.params, optimizer.lr - ([1, 2, 3], 0.01) - -Multiple arguments available after injection are also supported and can be -specified the same way with a ``Callable`` type hint. For example, for two -``Iterable`` arguments, you can use the following syntax: ``Callable[[Iterable, -Iterable], Type]``. Please be aware that the arguments are passed as positional -arguments, this means that the injected function would be called like -``function(value1, value2)``. Similarly, for a callable that accepts zero -arguments, the syntax would be ``Callable[[], Type]``. - -.. note:: - - When the ``Callable`` has a class return type, it is possible to specify the - ``class_path`` giving only its name if imported before parsing, as explained - in :ref:`sub-classes-command-line`. - -If the same type above is used as type hint of a parameter of another class, a -default can be set using a lambda, for example: - -.. testcode:: callable - - class Model: - def __init__( - self, - optimizer: Callable[[Iterable], Optimizer] = lambda p: SGD(p, lr=0.05), - ): - self.optimizer = optimizer - -Then a parser and behavior could be: - -.. code-block:: - - >>> parser.add_class_arguments(Model, 'model') - >>> cfg = parser.get_defaults() - >>> cfg.model.optimizer - Namespace(class_path='__main__.SGD', init_args=Namespace(lr=0.05)) - >>> init = parser.instantiate_classes(cfg) - >>> optimizer = init.model.optimizer([1, 2, 3]) - >>> optimizer.params, optimizer.lr - ([1, 2, 3], 0.05) - -See :ref:`ast-resolver` for limitations of lambda defaults in signatures. 
-Providing a lambda default to :py:meth:`.ActionsContainer.add_argument` does not -work since there is no AST resolving. In this case, a dict with ``class_path`` -and ``init_args`` can be used as default. - +The third option is only applicable when the type is a callable that returns +class instances. This is a form of :ref:`dependency-injection`, so this third +case is explained in section :ref:`instance-factories`. .. _registering-types: @@ -1967,28 +1880,55 @@ the stubs. In these cases in the parser help the default is shown as ``Unknown`` and not included in :py:meth:`.ArgumentParser.get_defaults` or the output of ``--print_config``. + +.. _dependency-injection: + +Dependency injection +==================== + +Dependency injection is a software design pattern that separates the +instantiation details of objects from their usage, resulting in more loosely +coupled programs, see the `wikipedia article +`__. Because of its +benefits, support for dependency injection has been a design goal of +jsonargparse. + +In python, dependency injection is achieved by: + +- Using as type hint a class, such that the parameter accepts an instance of + this class or any subclass, e.g. ``module: ModuleBaseClass``. +- Using as type hint a callable that returns an instance of a class, such that + the parameter accepts a function for instantiation. This could be either + using ``Callable``, e.g. ``module: Callable[[int], ModuleBaseClass]``, or a + protocol, e.g. ``module: ModuleFactoryProtocol``. + .. _sub-classes: Class type and sub-classes -========================== - -It is possible to use an arbitrary class as a type such that the argument -accepts an instance of this class or any derived subclass. This practice is -known as `dependency injection -`__. In the config file a -class is represented by a dictionary with a ``class_path`` entry indicating the -dot notation expression to import the class, and optionally some ``init_args`` -that would be used to instantiate it. 
When parsing, it will be checked that the -class can be imported, that it is a subclass of the given type and that -``init_args`` values correspond to valid arguments to instantiate it. After -parsing, the config object will include the ``class_path`` and ``init_args`` -entries. To get a config object with all sub-classes instantiated, the -:py:meth:`.ArgumentParser.instantiate_classes` method is used. The ``skip`` -parameter of the signature methods can also be used to exclude arguments within -subclasses. This is done by giving its relative destination key, i.e. as -``param.init_args.subparam``. - -A simple example would be having some config file ``config.yaml`` as: +-------------------------- + +When a class is used as a type hint, jsonargparse expects in config files a +dictionary with a ``class_path`` entry indicating the dot notation expression to +import the class, and optionally some ``init_args`` that would be used to +instantiate it. When parsing, it will be checked that the class can be imported, +that it is a subclass of the given type and that ``init_args`` values correspond +to valid arguments to instantiate it. After parsing, the config object will +include the ``class_path`` and ``init_args`` entries. To get a config object +with all nested sub-classes instantiated, the +:py:meth:`.ArgumentParser.instantiate_classes` method is used. + +Additional to using a class as type hint in signatures, for low level +construction of parsers, there are also the methods +:py:meth:`.SignatureArguments.add_class_arguments` and +:py:meth:`.SignatureArguments.add_subclass_arguments`. These methods accept a +``skip`` argument that can be used to exclude parameters within subclasses. This +is done by giving its relative destination key, i.e. as +``param.init_args.subparam``. An individual argument can also be added having as +type a class, i.e. ``parser.add_argument("--module", type=ModuleBase)``. 
+ +A simple example with a top-level class to instantiate, with a parameter that +expects an injected class instance, would be having some config file +``config.yaml`` as: .. code-block:: yaml @@ -2030,6 +1970,10 @@ Then in python: {'class_path': 'calendar.Calendar', 'init_args': {'firstweekday': 1}} >>> cfg = parser.instantiate_classes(cfg) + >>> isinstance(cfg.myclass, MyClass) + True + >>> isinstance(cfg.myclass.calendar, Calendar) + True >>> cfg.myclass.calendar.getfirstweekday() 1 @@ -2037,13 +1981,20 @@ In this example the ``class_path`` points to the same class used for the type. But a subclass of ``Calendar`` with an extended set of init parameters would also work. -An individual argument can also be added having as type a class, i.e. -``parser.add_argument('--calendar', type=Calendar)``. There is also another -method :py:meth:`.SignatureArguments.add_subclass_arguments` which does the same -as ``add_argument``, but has some added benefits: 1) the argument is added in a -new group automatically; 2) the argument values can be given in an independent -config file by specifying a path to it; and 3) by default sets a useful -``metavar`` and ``help`` strings. +If the previous example were changed to use +:py:meth:`.SignatureArguments.add_subclass_arguments` instead of +:py:meth:`.SignatureArguments.add_class_arguments`, then subclasses of ``MyClass`` +would also be accepted. In this case the config would be like: + +.. code-block:: yaml + + myclass: + class_path: my_module.MyClass + init_args: + calendar: + class_path: calendar.TextCalendar + init_args: + firstweekday: 1 .. note:: @@ -2057,14 +2008,149 @@ config file by specifying a path to it; and 3) by default sets a useful type a class. The accepted ``init_args`` would be the parameters of that function. +.. 
_instance-factories: + +Instance factories +------------------ + +As explained at the beginning of section :ref:`dependency-injection`, callables +that return instances of classes, referred to as instance factories, represent +an alternative approach to dependency injection. This is useful to support +dependency injection of classes that require parameters that are only available +after injection. For this case, when +:py:meth:`.ArgumentParser.instantiate_classes` is run, a partial function is +provided, which might accept parameters and returns the instance of the class. +Two options are possible, either using ``Callable`` or ``Protocol``. First to +illustrate the ``Callable`` option, take for example the classes: + +.. testcode:: callable + + class Optimizer: + def __init__(self, params: Iterable): + self.params = params + + + class SGD(Optimizer): + def __init__(self, params: Iterable, lr: float): + super().__init__(params) + self.lr = lr + +.. testcode:: callable + :hide: + + doctest_mock_class_in_main(SGD) + +A possible parser and callable behavior would be: + +.. doctest:: callable + + >>> value = { + ... "class_path": "SGD", + ... "init_args": { + ... "lr": 0.01, + ... }, + ... } + + >>> parser.add_argument("--optimizer", type=Callable[[Iterable], Optimizer]) # doctest: +IGNORE_RESULT + >>> cfg = parser.parse_args(["--optimizer", str(value)]) + >>> cfg.optimizer + Namespace(class_path='__main__.SGD', init_args=Namespace(lr=0.01)) + >>> init = parser.instantiate_classes(cfg) + >>> optimizer = init.optimizer([1, 2, 3]) + >>> isinstance(optimizer, SGD) + True + >>> optimizer.params, optimizer.lr + ([1, 2, 3], 0.01) + +.. note:: + + When the ``Callable`` has a class return type, it is possible to specify the + ``class_path`` giving only its name if imported before parsing, as explained + in :ref:`sub-classes-command-line`. + +If the same type above is used as type hint of a parameter of another class, a +default can be set using a lambda, for example: + +.. 
testcode:: callable + + class Model: + def __init__( + self, + optimizer: Callable[[Iterable], Optimizer] = lambda p: SGD(p, lr=0.05), + ): + self.optimizer = optimizer + +Then a parser and behavior could be: + +.. code-block:: + + >>> parser.add_class_arguments(Model, 'model') + >>> cfg = parser.get_defaults() + >>> cfg.model.optimizer + Namespace(class_path='__main__.SGD', init_args=Namespace(lr=0.05)) + >>> init = parser.instantiate_classes(cfg) + >>> optimizer = init.model.optimizer([1, 2, 3]) + >>> optimizer.params, optimizer.lr + ([1, 2, 3], 0.05) + +See :ref:`ast-resolver` for limitations of lambda defaults in signatures. +Providing a lambda default to :py:meth:`.ActionsContainer.add_argument` does not +work since there is no AST resolving. In this case, a dict with ``class_path`` +and ``init_args`` can be used as default. + +Multiple arguments required after injection are also supported and can be +specified the same way with a ``Callable``. For example, for two +``Iterable`` arguments, you can use the syntax: ``Callable[[Iterable, +Iterable], Type]``. Similarly, for a callable that accepts zero +arguments, the syntax would be ``Callable[[], Type]``. + +Note the big limitation that ``Callable`` has. It is only possible to specify +positional and unnamed parameters. To overcome this limitation, the second +option, a callable ``Protocol`` can be used instead. Building up from the same +example, an ``OptimizerFactory`` protocol can be defined as: + +.. testcode:: callable + + class OptimizerFactory(Protocol): + def __call__(self, params: Iterable) -> Optimizer: ... + +Then a parser and protocol behavior would be: + +.. testcode:: callable + :hide: + + parser = ArgumentParser() + +.. doctest:: callable + + >>> value = { + ... "class_path": "SGD", + ... "init_args": { + ... "lr": 0.02, + ... }, + ... 
} + + >>> parser.add_argument("--optimizer", type=OptimizerFactory) # doctest: +IGNORE_RESULT + >>> cfg = parser.parse_args(["--optimizer", str(value)]) + >>> cfg.optimizer + Namespace(class_path='__main__.SGD', init_args=Namespace(lr=0.02)) + >>> init = parser.instantiate_classes(cfg) + >>> optimizer = init.optimizer(params=[6, 5]) + >>> optimizer.params, optimizer.lr + ([6, 5], 0.02) + +The key difference with respect to the ``Callable`` is being able to call +``init.optimizer()`` with keyword arguments ``params=[6, 5]``. + .. _sub-classes-command-line: Command line ------------ -The help of the parser does not show details for a type class since this depends -on the subclass. To get details for a particular subclass there is a specific -help option that receives the import path. Take for example a parser defined as: +The help of the parser does not show accepted parameters of a class since this +depends on the chosen subclass. To get details for a particular subclass there +is a help option that receives the import path. Take for example a parser +defined as: .. testcode:: @@ -2163,6 +2249,12 @@ example above, this would be: Like this, the parsed default will be a dict with ``class_path`` and ``init_args``, again avoiding the risk of mutability. +The use of :func:`.lazy_instance` is somewhat discouraged. A function that +delays the initialization of instances, and works for all possible cases out +there, is challenging. The current implementation is known to have some +problems. Instead of using :func:`.lazy_instance`, you could consider switching +to :ref:`instance-factories`. + .. 
note:: In python there can be some classes or functions for which it is not diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index 0f83a333..18ba9cf0 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -1332,8 +1332,6 @@ def add_subclasses(cl): add_subclasses(subclass) if get_typehint_origin(cls) in callable_origin_types: - if len(getattr(cls, "__args__", [])) < 2: - return subclass_list cls = cls.__args__[-1] if get_typehint_origin(cls) in {Union, Type, type}: diff --git a/jsonargparse_tests/test_typehints.py b/jsonargparse_tests/test_typehints.py index 1d07daab..f591c10f 100644 --- a/jsonargparse_tests/test_typehints.py +++ b/jsonargparse_tests/test_typehints.py @@ -1125,6 +1125,28 @@ def test_callable_protocol_instance_factory(parser, subtests): assert "--optimizer.params" not in help_str +class OptimizerFactoryPositionalAndKeyword(Protocol): + def __call__(self, lr: float, /, params: List[float]) -> Optimizer: ... + + +def test_callable_protocol_instance_factory_with_positional(parser): + parser.add_argument("--optimizer", type=OptimizerFactoryPositionalAndKeyword) + + value = { + "class_path": "DifferentParamsOrder", + "init_args": { + "momentum": 0.9, + }, + } + cfg = parser.parse_args([f"--optimizer={json.dumps(value)}"]) + init = parser.instantiate_classes(cfg) + optimizer = init.optimizer(0.2, params=[0, 1]) + assert optimizer.lr == 0.2 + assert optimizer.params == [0, 1] + assert optimizer.momentum == 0.9 + assert isinstance(optimizer, DifferentParamsOrder) + + def test_optional_callable_return_type_help(parser): parser.add_argument("--optimizer", type=Optional[Callable[[List[float]], Optimizer]]) help_str = get_parser_help(parser) diff --git a/sphinx/conf.py b/sphinx/conf.py index 98f282db..93d5050d 100644 --- a/sphinx/conf.py +++ b/sphinx/conf.py @@ -82,7 +82,7 @@ def check_output(self, want, got, optionflags): from calendar import Calendar from dataclasses import dataclass from io import StringIO -from typing 
import Callable, Iterable, List +from typing import Callable, Iterable, List, Protocol import jsonargparse_tests from jsonargparse import * from jsonargparse.typing import *