Skip to content

Commit ff2c599

Browse files
authored
Fix linking entire dataclasses on instantiation (#746)
1 parent 6ebe74e commit ff2c599

File tree

3 files changed

+68
-9
lines changed

3 files changed

+68
-9
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ Fixed
1919
^^^^^
2020
- Subclass defaults incorrectly taken from base class (`#743
2121
<https://github.com/omni-us/jsonargparse/pull/743>`__).
22+
- Linking entire dataclasses on instantiation not working (`#746
23+
<https://github.com/omni-us/jsonargparse/pull/746>`__).
2224

2325

2426
v4.40.1 (2025-07-24)

jsonargparse/_link_arguments.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
_ActionPrintConfig,
1717
_ActionSubCommands,
1818
_find_parent_action,
19+
_find_parent_action_and_subcommand,
1920
filter_default_actions,
2021
)
2122
from ._namespace import Namespace, split_key, split_key_leaf
@@ -47,7 +48,7 @@ def find_subclass_action_or_class_group(
4748
parser: ArgumentParser,
4849
key: str,
4950
exclude: Optional[Union[Type[ArgparseAction], Tuple[Type[ArgparseAction], ...]]] = None,
50-
) -> Optional[Union[ArgparseAction, "ArgumentGroup"]]:
51+
) -> Optional[Union[ArgparseAction, ArgumentGroup]]:
5152
from ._typehints import ActionTypeHint
5253

5354
action = _find_parent_action(parser, key, exclude=exclude)
@@ -142,23 +143,33 @@ def __init__(
142143
]
143144

144145
# Set and check target action
145-
self.target = (target, _find_parent_action(parser, target, exclude=exclude))
146+
self.target = (target, find_parent_action_or_group(parser, target, exclude=exclude))
146147
for key, action in self.source + [self.target]:
147148
if action is None:
148149
raise ValueError(f'No action for key "{key}".')
149150
assert self.target[1] is not None
150151

152+
from ._core import ArgumentGroup
151153
from ._typehints import ActionTypeHint
152154

155+
is_target_group = isinstance(self.target[1], ArgumentGroup)
153156
is_target_subclass = ActionTypeHint.is_subclass_typehint(self.target[1], all_subtypes=False, also_lists=True)
154157
valid_target_init_arg = is_target_subclass and target.startswith(f"{self.target[1].dest}.init_args.")
155158
valid_target_leaf = self.target[1].dest == target
156159
if not valid_target_leaf and is_target_subclass and not valid_target_init_arg:
157160
prefix = f"{self.target[1].dest}.init_args."
158161
raise ValueError(f'Target key expected to start with "{prefix}", got "{target}".')
159162

163+
# Remove target group and child actions
164+
if is_target_group:
165+
parser._action_groups.remove(self.target[1])
166+
del parser.groups[target]
167+
for action in list(parser._actions):
168+
if action.dest == target or action.dest.startswith(f"{target}."):
169+
parser._actions.remove(action)
160170
# Replace target action with link action
161-
if not is_target_subclass or valid_target_leaf:
171+
elif not is_target_subclass or valid_target_leaf:
172+
assert isinstance(self.target[1], ArgparseAction)
162173
for key in self.target[1].option_strings:
163174
parser._option_string_actions[key] = self
164175
parser._actions[parser._actions.index(self.target[1])] = self
@@ -181,7 +192,7 @@ def __init__(
181192
if target in parser.required_args:
182193
parser.required_args.remove(target)
183194
if is_target_subclass and not valid_target_leaf:
184-
sub_add_kwargs = self.target[1].sub_add_kwargs # type: ignore[attr-defined]
195+
sub_add_kwargs = self.target[1].sub_add_kwargs # type: ignore[union-attr]
185196
if "linked_targets" not in sub_add_kwargs:
186197
sub_add_kwargs["linked_targets"] = set()
187198
subtarget = target.split(".init_args.", 1)[1]
@@ -209,10 +220,15 @@ def __init__(
209220
type_attr = None
210221
help_str = f"Use --{self.target[1].dest}.help for details."
211222
else:
212-
type_attr = getattr(self.target[1], "_typehint", self.target[1].type)
213-
help_str = self.target[1].help
223+
if is_target_group:
224+
type_attr = self.target[1].group_class # type: ignore[union-attr]
225+
help_str = self.target[1].title # type: ignore[union-attr]
226+
else:
227+
assert isinstance(self.target[1], ArgparseAction)
228+
type_attr = getattr(self.target[1], "_typehint", self.target[1].type)
229+
help_str = self.target[1].help
214230
if help_str == import_module("jsonargparse._formatters").empty_help:
215-
help_str = f"Target argument '{self.target[1].dest}' lacks type and help"
231+
help_str = f"Target '{self.target[1].dest}' lacks type and help"
216232

217233
super().__init__(
218234
[link_str],
@@ -392,8 +408,9 @@ def set_target_value(action: "ActionLink", value: Any, cfg: Namespace, logger) -
392408

393409
if ActionTypeHint.is_subclass_typehint(target_action, all_subtypes=False, also_lists=True):
394410
if target_key == target_action.dest:
395-
target_action._check_type(value) # type: ignore[attr-defined]
411+
target_action._check_type(value) # type: ignore[union-attr]
396412
else:
413+
assert isinstance(target_action.dest, str)
397414
parent = cfg.get(target_action.dest)
398415
child_key = target_key[len(target_action.dest) + 1 :]
399416
if isinstance(parent, list) and any(isinstance(i, Namespace) and child_key in i for i in parent):
@@ -473,6 +490,17 @@ def del_target_key(target_key):
473490
ActionLink.strip_link_target_keys(subparsers[num], cfg[subcommand])
474491

475492

493+
def find_parent_action_or_group(
494+
parser: ArgumentParser,
495+
key: str,
496+
exclude: Optional[Union[Type[ArgparseAction], Tuple[Type[ArgparseAction], ...]]] = None,
497+
) -> Optional[Union[ArgparseAction, ArgumentGroup]]:
498+
action_or_group = _find_parent_action_and_subcommand(parser, key, exclude=exclude)[0]
499+
if not action_or_group and parser.groups and key in parser.groups:
500+
return parser.groups[key]
501+
return action_or_group
502+
503+
476504
def get_link_actions(parser: ArgumentParser, apply_on: str, skip=set()) -> List[ActionLink]:
477505
if not hasattr(parser, "_links_group"):
478506
return []

jsonargparse_tests/test_link_arguments.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_on_parse_help_target_lacking_type_and_help(parser):
2525
parser.add_argument("--b")
2626
parser.link_arguments("a", "b")
2727
help_str = get_parser_help(parser)
28-
assert "Target argument 'b' lacks type and help" in help_str
28+
assert "Target 'b' lacks type and help" in help_str
2929

3030

3131
def test_on_parse_shallow_print_config(parser):
@@ -953,6 +953,35 @@ def test_on_instantiate_targets_passed_to_instantiator(parser):
953953
assert init.model.applied_instantiation_links == {"model.init_args.optimizer.init_args.num_classes": 7}
954954

955955

956+
@dataclass
957+
class DataDep:
958+
param: int = 1
959+
960+
961+
@dataclass
962+
class DepContainer:
963+
dep: DataDep
964+
ref: str = ""
965+
966+
967+
def test_on_instantiate_target_entire_dataclass(parser, tmp_cwd):
968+
parser.add_class_arguments(DataDep, "data")
969+
parser.add_class_arguments(DepContainer, "container")
970+
parser.link_arguments("data", "container.dep", apply_on="instantiate")
971+
972+
defaults = parser.get_defaults()
973+
assert defaults == Namespace(data=Namespace(param=1), container=Namespace(ref=""))
974+
cfg = parser.parse_args(["--data.param=2", "--container.ref=x"])
975+
assert cfg == Namespace(data=Namespace(param=2), container=Namespace(ref="x"))
976+
init = parser.instantiate_classes(cfg)
977+
assert init.data is init.container.dep
978+
assert init.container.dep.param == 2
979+
980+
help_str = get_parser_help(parser)
981+
assert "data --> container.dep [applied on instantiate]" in help_str
982+
assert "--container.dep" not in help_str
983+
984+
956985
# link creation failures
957986

958987

0 commit comments

Comments
 (0)