@@ -65,7 +65,7 @@ class AutoConfigPlanner:
 
     def __init__(
         self,
-        architecture: str = 'mednext',
+        architecture: str = "mednext",
         target_spacing: Optional[List[float]] = None,
         median_shape: Optional[List[int]] = None,
         manual_overrides: Optional[Dict[str, Any]] = None,
@@ -93,33 +93,33 @@ def __init__(
     def _get_architecture_defaults(self) -> Dict[str, Any]:
         """Get architecture-specific default parameters."""
         defaults = {
-            'mednext': {
-                'base_features': 32,
-                'max_features': 320,
-                'lr': 1e-3,  # MedNeXt paper recommends 1e-3
-                'use_scheduler': False,  # MedNeXt uses constant LR
+            "mednext": {
+                "base_features": 32,
+                "max_features": 320,
+                "lr": 1e-3,  # MedNeXt paper recommends 1e-3
+                "use_scheduler": False,  # MedNeXt uses constant LR
             },
-            'mednext_custom': {
-                'base_features': 32,
-                'max_features': 320,
-                'lr': 1e-3,
-                'use_scheduler': False,
+            "mednext_custom": {
+                "base_features": 32,
+                "max_features": 320,
+                "lr": 1e-3,
+                "use_scheduler": False,
             },
-            'monai_basic_unet3d': {
-                'base_features': 32,
-                'max_features': 512,
-                'lr': 1e-4,
-                'use_scheduler': True,
+            "monai_basic_unet3d": {
+                "base_features": 32,
+                "max_features": 512,
+                "lr": 1e-4,
+                "use_scheduler": True,
             },
-            'monai_unet': {
-                'base_features': 32,
-                'max_features': 512,
-                'lr': 1e-4,
-                'use_scheduler': True,
+            "monai_unet": {
+                "base_features": 32,
+                "max_features": 512,
+                "lr": 1e-4,
+                "use_scheduler": True,
             },
         }
 
-        return defaults.get(self.architecture, defaults['monai_basic_unet3d'])
+        return defaults.get(self.architecture, defaults["monai_basic_unet3d"])
 
     def plan(
         self,
@@ -149,20 +149,20 @@ def plan(
         result.planning_notes.append(f"Patch size: {patch_size}")
 
         # Step 2: Get model parameters
-        result.base_features = self.arch_defaults['base_features']
-        result.max_features = self.arch_defaults['max_features']
+        result.base_features = self.arch_defaults["base_features"]
+        result.max_features = self.arch_defaults["max_features"]
 
         # Step 3: Determine precision
         result.precision = "16-mixed" if use_mixed_precision else "32"
 
         # Step 4: Estimate memory and determine batch size
-        if not self.gpu_info['cuda_available']:
+        if not self.gpu_info["cuda_available"]:
             result.batch_size = 1
             result.precision = "32"  # CPU doesn't support mixed precision well
             result.warnings.append("CUDA not available, using CPU with batch_size=1")
             result.planning_notes.append("Training on CPU (slow!)")
         else:
-            gpu_memory_gb = self.gpu_info['available_memory_gb'][0]  # Use first GPU
+            gpu_memory_gb = self.gpu_info["available_memory_gb"][0]  # Use first GPU
             result.available_gpu_memory_gb = gpu_memory_gb
 
             # Calculate number of pooling stages (log2 of patch size / 4)
@@ -199,24 +199,24 @@ def plan(
             )
             result.planning_notes.append(
                 f"Estimated memory: {result.estimated_gpu_memory_gb:.2f} GB "
-                f"({result.estimated_gpu_memory_gb/gpu_memory_gb*100:.1f}% of GPU)"
+                f"({result.estimated_gpu_memory_gb / gpu_memory_gb * 100:.1f}% of GPU)"
             )
             result.planning_notes.append(f"Batch size: {batch_size}")
 
             # Gradient accumulation if batch size is very small
             if batch_size == 1:
                 result.accumulate_grad_batches = 4
                 result.planning_notes.append(
-                    f"Using gradient accumulation (4 batches) for effective batch_size=4"
+                    "Using gradient accumulation (4 batches) for effective batch_size=4"
                 )
 
         # Step 5: Determine num_workers
-        num_gpus = self.gpu_info['num_gpus'] if self.gpu_info['cuda_available'] else 0
+        num_gpus = self.gpu_info["num_gpus"] if self.gpu_info["cuda_available"] else 0
         result.num_workers = get_optimal_num_workers(num_gpus)
         result.planning_notes.append(f"Num workers: {result.num_workers}")
 
         # Step 6: Learning rate
-        result.lr = self.arch_defaults['lr']
+        result.lr = self.arch_defaults["lr"]
         result.planning_notes.append(f"Learning rate: {result.lr}")
 
         # Step 7: Apply manual overrides
@@ -266,8 +266,8 @@ def _plan_patch_size(self) -> List[int]:
 
         # If GPU memory is limited, may need to reduce patch size
         # (This is a simplified heuristic)
-        if self.gpu_info['cuda_available']:
-            gpu_memory_gb = self.gpu_info['available_memory_gb'][0]
+        if self.gpu_info["cuda_available"]:
+            gpu_memory_gb = self.gpu_info["available_memory_gb"][0]
             if gpu_memory_gb < 8:
                 # Very limited GPU, use smaller patches
                 patch_size = np.minimum(patch_size, [64, 64, 64])
@@ -289,8 +289,10 @@ def print_plan(self, result: AutoPlanResult):
         print(f" Batch Size: {result.batch_size}")
         if result.accumulate_grad_batches > 1:
             effective_bs = result.batch_size * result.accumulate_grad_batches
-            print(f" Gradient Accumulation: {result.accumulate_grad_batches} "
-                  f"(effective batch_size={effective_bs})")
+            print(
+                f" Gradient Accumulation: {result.accumulate_grad_batches} "
+                f"(effective batch_size={effective_bs})"
+            )
         print(f" Num Workers: {result.num_workers}")
         print()
@@ -307,8 +309,10 @@ def print_plan(self, result: AutoPlanResult):
         if result.available_gpu_memory_gb > 0:
             print("💾 GPU Memory:")
             print(f" Available: {result.available_gpu_memory_gb:.2f} GB")
-            print(f" Estimated Usage: {result.estimated_gpu_memory_gb:.2f} GB "
-                  f"({result.estimated_gpu_memory_gb/result.available_gpu_memory_gb*100:.1f}%)")
+            print(
+                f" Estimated Usage: {result.estimated_gpu_memory_gb:.2f} GB "
+                f"({result.estimated_gpu_memory_gb / result.available_gpu_memory_gb * 100:.1f}%)"
+            )
             print(f" Per Sample: {result.gpu_memory_per_sample_gb:.2f} GB")
             print()
@@ -349,45 +353,50 @@ def auto_plan_config(
         Updated config with auto-planned parameters
     """
     # Check if auto-planning is disabled
-    if hasattr(config, 'system') and hasattr(config.system, 'auto_plan'):
+    if hasattr(config, "system") and hasattr(config.system, "auto_plan"):
         if not config.system.auto_plan:
             print("ℹ️ Auto-planning disabled in config")
             return config
 
     # Extract relevant config values
-    architecture = config.model.architecture if hasattr(config.model, 'architecture') else 'mednext'
-    in_channels = config.model.in_channels if hasattr(config.model, 'in_channels') else 1
-    out_channels = config.model.out_channels if hasattr(config.model, 'out_channels') else 2
-    deep_supervision = config.model.deep_supervision if hasattr(config.model, 'deep_supervision') else False
+    architecture = config.model.architecture if hasattr(config.model, "architecture") else "mednext"
+    in_channels = config.model.in_channels if hasattr(config.model, "in_channels") else 1
+    out_channels = config.model.out_channels if hasattr(config.model, "out_channels") else 2
+    deep_supervision = (
+        config.model.deep_supervision if hasattr(config.model, "deep_supervision") else False
+    )
 
     # Get target spacing and median shape if provided
     target_spacing = None
-    if hasattr(config, 'data') and hasattr(config.data, 'target_spacing'):
+    if hasattr(config, "data") and hasattr(config.data, "target_spacing"):
         target_spacing = config.data.target_spacing
 
     median_shape = None
-    if hasattr(config, 'data') and hasattr(config.data, 'median_shape'):
+    if hasattr(config, "data") and hasattr(config.data, "median_shape"):
         median_shape = config.data.median_shape
 
     # Collect manual overrides (values explicitly set in config)
     manual_overrides = {}
-    if hasattr(config, 'data'):
-        if hasattr(config.data, 'batch_size') and config.data.batch_size is not None:
-            manual_overrides['batch_size'] = config.data.batch_size
-        if hasattr(config.data, 'num_workers') and config.data.num_workers is not None:
-            manual_overrides['num_workers'] = config.data.num_workers
-        if hasattr(config.data, 'patch_size') and config.data.patch_size is not None:
-            manual_overrides['patch_size'] = config.data.patch_size
-
-    if hasattr(config, 'training'):
-        if hasattr(config.training, 'precision') and config.training.precision is not None:
-            manual_overrides['precision'] = config.training.precision
-        if hasattr(config.training, 'accumulate_grad_batches') and config.training.accumulate_grad_batches is not None:
-            manual_overrides['accumulate_grad_batches'] = config.training.accumulate_grad_batches
-
-    if hasattr(config, 'optimizer'):
-        if hasattr(config.optimizer, 'lr') and config.optimizer.lr is not None:
-            manual_overrides['lr'] = config.optimizer.lr
+    training_cfg = getattr(config.system, "training", None) if hasattr(config, "system") else None
+    if hasattr(config, "data"):
+        if training_cfg and getattr(training_cfg, "batch_size", None) is not None:
+            manual_overrides["batch_size"] = training_cfg.batch_size
+        if training_cfg and getattr(training_cfg, "num_workers", None) is not None:
+            manual_overrides["num_workers"] = training_cfg.num_workers
+        if hasattr(config.data, "patch_size") and config.data.patch_size is not None:
+            manual_overrides["patch_size"] = config.data.patch_size
+
+    if hasattr(config, "optimization"):
+        if getattr(config.optimization, "precision", None) is not None:
+            manual_overrides["precision"] = config.optimization.precision
+        if getattr(config.optimization, "accumulate_grad_batches", None) is not None:
+            manual_overrides["accumulate_grad_batches"] = (
+                config.optimization.accumulate_grad_batches
+            )
+
+        opt_cfg = getattr(config.optimization, "optimizer", None)
+        if opt_cfg and getattr(opt_cfg, "lr", None) is not None:
+            manual_overrides["lr"] = opt_cfg.lr
 
     # Create planner
     planner = AutoConfigPlanner(
@@ -398,9 +407,9 @@ def auto_plan_config(
     )
 
     # Plan
-    use_mixed_precision = not (hasattr(config, 'training') and
-                               hasattr(config.training, 'precision') and
-                               config.training.precision == "32")
+    use_mixed_precision = not (
+        hasattr(config, "optimization") and getattr(config.optimization, "precision", None) == "32"
+    )
 
     result = planner.plan(
         in_channels=in_channels,
@@ -412,20 +421,20 @@ def auto_plan_config(
     # Update config with planned values (if not manually overridden)
     OmegaConf.set_struct(config, False)  # Allow adding new fields
 
-    if 'batch_size' not in manual_overrides:
-        config.data.batch_size = result.batch_size
-    if 'num_workers' not in manual_overrides:
-        config.data.num_workers = result.num_workers
-    if 'patch_size' not in manual_overrides:
+    if "batch_size" not in manual_overrides and training_cfg is not None:
+        training_cfg.batch_size = result.batch_size
+    if "num_workers" not in manual_overrides and training_cfg is not None:
+        training_cfg.num_workers = result.num_workers
+    if "patch_size" not in manual_overrides:
         config.data.patch_size = result.patch_size
 
-    if 'precision' not in manual_overrides:
-        config.training.precision = result.precision
-    if 'accumulate_grad_batches' not in manual_overrides:
-        config.training.accumulate_grad_batches = result.accumulate_grad_batches
+    if "precision" not in manual_overrides:
+        config.optimization.precision = result.precision
+    if "accumulate_grad_batches" not in manual_overrides:
+        config.optimization.accumulate_grad_batches = result.accumulate_grad_batches
 
-    if 'lr' not in manual_overrides:
-        config.optimizer.lr = result.lr
+    if "lr" not in manual_overrides and hasattr(config, "optimization"):
+        config.optimization.optimizer.lr = result.lr
 
     OmegaConf.set_struct(config, True)  # Re-enable struct mode
 
@@ -436,21 +445,20 @@ def auto_plan_config(
     return config
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # Test auto planning
     from connectomics.config import Config
-    from omegaconf import OmegaConf
 
     # Create test config
     cfg = OmegaConf.structured(Config())
-    cfg.model.architecture = 'mednext'
+    cfg.model.architecture = "mednext"
     cfg.model.deep_supervision = True
 
     # Auto plan
     cfg = auto_plan_config(cfg, print_results=True)
 
     print("\nFinal Config Values:")
-    print(f" batch_size: {cfg.data.batch_size}")
+    print(f" batch_size: {cfg.system.training.batch_size}")
     print(f" patch_size: {cfg.data.patch_size}")
     print(f" precision: {cfg.optimization.precision}")
     print(f" lr: {cfg.optimization.optimizer.lr}")