@@ -125,12 +125,22 @@ def _load_data(self, data: np.ndarray | str | Path, name: str) -> np.ndarray:
125125
126126 def _validate_data (self ):
127127 """Validate data shapes and types."""
128+ # Handle 2D data: (C, H, W) → (C, 1, H, W)
129+ if self .predictions .ndim == 3 :
130+ print (f" 📐 2D data detected, expanding predictions: { self .predictions .shape } → { self .predictions .shape [:1 ] + (1 ,) + self .predictions .shape [1 :]} " )
131+ self .predictions = self .predictions [:, np .newaxis , :, :]
132+
128133 # Predictions should be (C, D, H, W)
129134 if self .predictions .ndim != 4 :
130135 raise ValueError (
131136 f"Predictions should be 4D (C, D, H, W), got shape { self .predictions .shape } "
132137 )
133138
139+ # Handle 2D ground truth: (H, W) → (1, H, W)
140+ if self .ground_truth .ndim == 2 :
141+ print (f" 📐 2D ground truth detected, expanding: { self .ground_truth .shape } → { (1 ,) + self .ground_truth .shape } " )
142+ self .ground_truth = self .ground_truth [np .newaxis , :, :]
143+
134144 # Ground truth should be (D, H, W)
135145 if self .ground_truth .ndim != 3 :
136146 raise ValueError (
@@ -145,8 +155,12 @@ def _validate_data(self):
145155 f"ground_truth { self .ground_truth .shape } "
146156 )
147157
148- # Check mask if provided
158+ # Handle 2D mask if provided
149159 if self .mask is not None :
160+ if self .mask .ndim == 2 :
161+ print (f" 📐 2D mask detected, expanding: { self .mask .shape } → { (1 ,) + self .mask .shape } " )
162+ self .mask = self .mask [np .newaxis , :, :]
163+
150164 if self .mask .shape != self .ground_truth .shape :
151165 raise ValueError (
152166 f"Mask shape { self .mask .shape } doesn't match "
@@ -170,7 +184,7 @@ def optimize(self) -> optuna.Study:
170184 direction = self ._get_optimization_direction ()
171185
172186 # Create storage directory if using SQLite
173- storage = self .tune_cfg . get ( "storage" , None )
187+ storage = getattr ( self .tune_cfg , "storage" , None )
174188 if storage and storage .startswith ("sqlite:///" ):
175189 # Extract database file path from SQLite URL
176190 db_path = storage .replace ("sqlite:///" , "" )
@@ -204,7 +218,7 @@ def optimize(self) -> optuna.Study:
204218 self ._objective ,
205219 n_trials = n_trials ,
206220 timeout = timeout ,
207- show_progress_bar = self .tune_cfg .logging . get ( "show_progress_bar" , True ),
221+ show_progress_bar = getattr ( self .tune_cfg .logging , "show_progress_bar" , True ),
208222 )
209223
210224 # Print results
@@ -219,7 +233,7 @@ def _create_sampler(self) -> optuna.samplers.BaseSampler:
219233 """Create Optuna sampler from config."""
220234 sampler_cfg = self .tune_cfg .sampler
221235 sampler_name = sampler_cfg ["name" ]
222- sampler_kwargs = sampler_cfg . get ( "kwargs" , {})
236+ sampler_kwargs = getattr ( sampler_cfg , "kwargs" , {})
223237
224238 # Convert OmegaConf to dict
225239 if isinstance (sampler_kwargs , DictConfig ):
@@ -236,13 +250,13 @@ def _create_sampler(self) -> optuna.samplers.BaseSampler:
236250
237251 def _create_pruner (self ) -> Optional [optuna .pruners .BasePruner ]:
238252 """Create Optuna pruner from config."""
239- pruner_cfg = self .tune_cfg . get ( "pruner" , None )
253+ pruner_cfg = getattr ( self .tune_cfg , "pruner" , None )
240254
241- if pruner_cfg is None or not pruner_cfg . get ( "enabled" , False ):
255+ if pruner_cfg is None or not getattr ( pruner_cfg , "enabled" , False ):
242256 return None
243257
244- pruner_name = pruner_cfg . get ( "name" , "Median" )
245- pruner_kwargs = pruner_cfg . get ( "kwargs" , {})
258+ pruner_name = getattr ( pruner_cfg , "name" , "Median" )
259+ pruner_kwargs = getattr ( pruner_cfg , "kwargs" , {})
246260
247261 # Convert OmegaConf to dict
248262 if isinstance (pruner_kwargs , DictConfig ):
@@ -338,7 +352,7 @@ def _objective(self, trial: optuna.Trial) -> float:
338352 )
339353
340354 # Print progress
341- if self .tune_cfg .logging . get ( "verbose" , True ):
355+ if getattr ( self .tune_cfg .logging , "verbose" , True ):
342356 print (f"Trial { self .trial_count :3d} : { metric_name } ={ metric_value :.4f} " )
343357
344358 return metric_value
@@ -546,7 +560,7 @@ def _print_results(self, study: optuna.Study):
546560 for key , value in best_decoding_params .items ():
547561 print (f" { key } : { value } " )
548562
549- if self .param_space_cfg . get ( "postprocessing" , {}). get ( "enabled" , False ):
563+ if getattr ( self .param_space_cfg , "postprocessing" , None ) and getattr ( self . param_space_cfg . postprocessing , "enabled" , False ):
550564 best_postproc_params = self ._reconstruct_postproc_params (study .best_params )
551565 if best_postproc_params :
552566 print (f"\n Post-processing params:" )
@@ -662,8 +676,13 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
662676 print ("\n [1/4] Running inference on tuning dataset..." )
663677
664678 # Get tune config sections (used later for loading predictions, ground truth, masks)
665- tune_data = cfg .tune .get ("data" , {})
666- tune_output = cfg .tune .get ("output" , {})
679+ tune_data = getattr (cfg .tune , "data" , None )
680+ tune_output = getattr (cfg .tune , "output" , None )
681+
682+ if tune_data is None :
683+ raise ValueError ("Missing tune.data in configuration" )
684+ if tune_output is None :
685+ raise ValueError ("Missing tune.output in configuration" )
667686
668687 # Create datamodule with tune mode (reads from cfg.tune.data)
669688 # Uses inference settings from cfg.inference (sliding window, TTA, save_predictions, etc.)
@@ -677,8 +696,8 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
677696
678697 # Step 2: Load predictions from saved files
679698 print ("\n [2/4] Loading predictions from saved files..." )
680- output_pred_dir = tune_output . get ( "output_pred" , str (output_dir .parent / "results" ))
681- cache_suffix = tune_output . get ( "cache_suffix" , "_tta_prediction.h5" )
699+ output_pred_dir = getattr ( tune_output , "output_pred" , str (output_dir .parent / "results" ))
700+ cache_suffix = getattr ( tune_output , "cache_suffix" , "_tta_prediction.h5" )
682701 predictions_dir = Path (output_pred_dir )
683702
684703 # Find all prediction files using cache_suffix from config
@@ -711,13 +730,21 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
711730
712731 # Step 3: Load ground truth
713732 print ("\n [3/4] Loading ground truth labels..." )
714- tune_label_pattern = tune_data . get ( "tune_label" , None )
733+ tune_label_pattern = getattr ( tune_data , "tune_label" , None )
715734
716735 if tune_label_pattern is None :
717736 raise ValueError ("Missing tune.data.tune_label in configuration" )
718737
719- # Handle glob patterns (can match multiple files)
720- label_files = sorted (glob .glob (tune_label_pattern ))
738+ # Handle both string patterns and pre-resolved lists
739+ if isinstance (tune_label_pattern , list ):
740+ # Already resolved to list of files
741+ label_files = sorted (tune_label_pattern )
742+ elif isinstance (tune_label_pattern , str ):
743+ # Glob pattern - expand it
744+ label_files = sorted (glob .glob (tune_label_pattern ))
745+ else :
746+ raise TypeError (f"tune_label must be string or list, got { type (tune_label_pattern )} " )
747+
721748 if not label_files :
722749 raise FileNotFoundError (f"No label files found matching pattern: { tune_label_pattern } " )
723750
@@ -740,10 +767,16 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
740767
741768 # Load mask if available
742769 mask = None
743- tune_mask_pattern = tune_data . get ( "tune_mask" , None )
770+ tune_mask_pattern = getattr ( tune_data , "tune_mask" , None )
744771 if tune_mask_pattern :
745- # Handle glob patterns
746- mask_files = sorted (glob .glob (tune_mask_pattern ))
772+ # Handle both string patterns and pre-resolved lists
773+ if isinstance (tune_mask_pattern , list ):
774+ mask_files = sorted (tune_mask_pattern )
775+ elif isinstance (tune_mask_pattern , str ):
776+ mask_files = sorted (glob .glob (tune_mask_pattern ))
777+ else :
778+ raise TypeError (f"tune_mask must be string or list, got { type (tune_mask_pattern )} " )
779+
747780 if not mask_files :
748781 print (f" ⚠️ No mask files found matching pattern: { tune_mask_pattern } " )
749782 else :
@@ -820,12 +853,11 @@ def load_and_apply_best_params(cfg):
820853 print (OmegaConf .to_yaml (best_params ))
821854
822855 # Apply to test.decoding config
823- # Note: test is Dict[str, Any], so we need to handle it carefully
824856 if cfg .test is None :
825- cfg .test = {}
857+ cfg .test = OmegaConf . create ({})
826858
827- if "decoding" not in cfg .test :
828- cfg .test [ " decoding" ] = []
859+ if not hasattr ( cfg . test , "decoding" ) or cfg .test . decoding is None :
860+ cfg .test . decoding = []
829861
830862 # Find the decoding function in test.decoding that matches the tuned function
831863 decoding_function = best_params .get ("decoding_function" , None )
@@ -836,24 +868,32 @@ def load_and_apply_best_params(cfg):
836868 else :
837869 # Find decoder with matching function name
838870 decoder_idx = None
839- for idx , decoder in enumerate (cfg .test ["decoding" ]):
840- if decoder .get ("name" ) == decoding_function :
871+ for idx , decoder in enumerate (cfg .test .decoding ):
872+ decoder_name = decoder .get ("name" ) if isinstance (decoder , dict ) else getattr (decoder , "name" , None )
873+ if decoder_name == decoding_function :
841874 decoder_idx = idx
842875 break
843876
844877 if decoder_idx is None :
845878 # Create new decoder entry
846- decoder_idx = len (cfg .test [ " decoding" ] )
847- cfg .test [ " decoding" ] .append ({"name" : decoding_function , "kwargs" : {}})
879+ decoder_idx = len (cfg .test . decoding )
880+ cfg .test . decoding .append ({"name" : decoding_function , "kwargs" : {}})
848881
849882 # Update parameters
850- if decoder_idx < len (cfg .test ["decoding" ]):
851- decoder = cfg .test ["decoding" ][decoder_idx ]
852- if "kwargs" not in decoder :
853- decoder ["kwargs" ] = {}
854-
855- # Apply best parameters
856- decoder ["kwargs" ].update (OmegaConf .to_container (best_params ["parameters" ]))
883+ if decoder_idx < len (cfg .test .decoding ):
884+ decoder = cfg .test .decoding [decoder_idx ]
885+
886+ # Handle both dict and config object
887+ if isinstance (decoder , dict ):
888+ if "kwargs" not in decoder :
889+ decoder ["kwargs" ] = {}
890+ decoder ["kwargs" ].update (OmegaConf .to_container (best_params ["decoding_params" ]))
891+ else :
892+ if not hasattr (decoder , "kwargs" ) or decoder .kwargs is None :
893+ decoder .kwargs = {}
894+ # Update kwargs with best parameters
895+ for key , value in best_params ["decoding_params" ].items ():
896+ decoder .kwargs [key ] = value
857897
858898 print (f"✓ Applied best parameters to test.decoding[{ decoder_idx } ]" )
859899
0 commit comments