diff --git a/processor/flow.py b/processor/flow.py index 0a417c1..b064d4c 100644 --- a/processor/flow.py +++ b/processor/flow.py @@ -567,8 +567,6 @@ def __init__( """ del input_volinfo_or_ts_spec - self._config = config - if config.patch_size % config.stride != 0: raise ValueError( f'patch_size {config.patch_size} not a multiple of stride' @@ -583,13 +581,20 @@ def __init__( ) if config.mask_configs: - config.mask_configs = self._get_mask_configs(config.mask_configs) + config = dataclasses.replace( + config, mask_configs=self._get_mask_configs(config.mask_configs) + ) if config.selection_mask_configs: - config.selection_mask_configs = self._get_mask_configs( - config.selection_mask_configs + config.selection_mask_configs = dataclasses.replace( + config, + selection_mask_configs=self._get_mask_configs( + config.selection_mask_configs + ), ) + self._config = config + def _build_mask( self, mask_configs: mask_lib.MaskConfigs,