
Commit 0f69540

Fixes and updates for neurons and axons examples
1 parent 89937ca commit 0f69540

File tree: 9 files changed (+373, -142 lines)

connectomics/data/dataset/dataset_base.py

Lines changed: 18 additions & 0 deletions
@@ -214,6 +214,24 @@ def __init__(
         self.dataset_length = len(data_dicts)

     def __len__(self) -> int:
+        """
+        Return dataset length.
+
+        For CacheDataset with cache_rate < 1.0, we must return the actual
+        number of cached items, not the requested iter_num, to avoid IndexError.
+        """
+        # If using partial caching, return the actual cached data length
+        # CacheDataset stores cached indices in self._cache
+        if hasattr(self, '_cache') and len(self._cache) < len(self.data):
+            # Partial caching: return cached length for validation
+            # For training with iter_num, we still want to iterate iter_num times
+            if self.mode == 'train' and self.iter_num > 0:
+                return self.dataset_length
+            else:
+                # For validation/test, only iterate over cached items
+                return len(self._cache)
+
+        # Full caching or no caching: use dataset_length
         return self.dataset_length

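
As context for the __len__ change above: MONAI's CacheDataset keeps only cache_rate * len(data) items in its internal _cache list, which is what the override checks via len(self._cache) outside of training. A minimal sketch of that behaviour (assumes MONAI is installed; the toy data and transform are illustrative, not taken from this repository):

# Sketch: partial caching in MONAI CacheDataset (illustrative data/transform).
from monai.data import CacheDataset
from monai.transforms import Lambdad

data_dicts = [{"image": i} for i in range(10)]
ds = CacheDataset(
    data=data_dicts,
    transform=Lambdad(keys="image", func=lambda x: x * 2),
    cache_rate=0.5,  # pre-cache only half of the items
)

print(len(ds.data))    # 10 -> full data list
print(len(ds._cache))  # 5  -> items actually cached (internal attribute)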

connectomics/decoding/optuna_tuner.py

Lines changed: 75 additions & 35 deletions
@@ -125,12 +125,22 @@ def _load_data(self, data: np.ndarray | str | Path, name: str) -> np.ndarray:

     def _validate_data(self):
         """Validate data shapes and types."""
+        # Handle 2D data: (C, H, W) → (C, 1, H, W)
+        if self.predictions.ndim == 3:
+            print(f" 📐 2D data detected, expanding predictions: {self.predictions.shape}{self.predictions.shape[:1] + (1,) + self.predictions.shape[1:]}")
+            self.predictions = self.predictions[:, np.newaxis, :, :]
+
         # Predictions should be (C, D, H, W)
         if self.predictions.ndim != 4:
             raise ValueError(
                 f"Predictions should be 4D (C, D, H, W), got shape {self.predictions.shape}"
             )

+        # Handle 2D ground truth: (H, W) → (1, H, W)
+        if self.ground_truth.ndim == 2:
+            print(f" 📐 2D ground truth detected, expanding: {self.ground_truth.shape}{(1,) + self.ground_truth.shape}")
+            self.ground_truth = self.ground_truth[np.newaxis, :, :]
+
         # Ground truth should be (D, H, W)
         if self.ground_truth.ndim != 3:
             raise ValueError(
@@ -145,8 +155,12 @@ def _validate_data(self):
                 f"ground_truth {self.ground_truth.shape}"
             )

-        # Check mask if provided
+        # Handle 2D mask if provided
         if self.mask is not None:
+            if self.mask.ndim == 2:
+                print(f" 📐 2D mask detected, expanding: {self.mask.shape}{(1,) + self.mask.shape}")
+                self.mask = self.mask[np.newaxis, :, :]
+
             if self.mask.shape != self.ground_truth.shape:
                 raise ValueError(
                     f"Mask shape {self.mask.shape} doesn't match "
@@ -170,7 +184,7 @@ def optimize(self) -> optuna.Study:
         direction = self._get_optimization_direction()

         # Create storage directory if using SQLite
-        storage = self.tune_cfg.get("storage", None)
+        storage = getattr(self.tune_cfg, "storage", None)
         if storage and storage.startswith("sqlite:///"):
             # Extract database file path from SQLite URL
             db_path = storage.replace("sqlite:///", "")
@@ -204,7 +218,7 @@ def optimize(self) -> optuna.Study:
             self._objective,
             n_trials=n_trials,
             timeout=timeout,
-            show_progress_bar=self.tune_cfg.logging.get("show_progress_bar", True),
+            show_progress_bar=getattr(self.tune_cfg.logging, "show_progress_bar", True),
         )

         # Print results
@@ -219,7 +233,7 @@ def _create_sampler(self) -> optuna.samplers.BaseSampler:
         """Create Optuna sampler from config."""
         sampler_cfg = self.tune_cfg.sampler
         sampler_name = sampler_cfg["name"]
-        sampler_kwargs = sampler_cfg.get("kwargs", {})
+        sampler_kwargs = getattr(sampler_cfg, "kwargs", {})

         # Convert OmegaConf to dict
         if isinstance(sampler_kwargs, DictConfig):
@@ -236,13 +250,13 @@ def _create_sampler(self) -> optuna.samplers.BaseSampler:

     def _create_pruner(self) -> Optional[optuna.pruners.BasePruner]:
         """Create Optuna pruner from config."""
-        pruner_cfg = self.tune_cfg.get("pruner", None)
+        pruner_cfg = getattr(self.tune_cfg, "pruner", None)

-        if pruner_cfg is None or not pruner_cfg.get("enabled", False):
+        if pruner_cfg is None or not getattr(pruner_cfg, "enabled", False):
             return None

-        pruner_name = pruner_cfg.get("name", "Median")
-        pruner_kwargs = pruner_cfg.get("kwargs", {})
+        pruner_name = getattr(pruner_cfg, "name", "Median")
+        pruner_kwargs = getattr(pruner_cfg, "kwargs", {})

         # Convert OmegaConf to dict
         if isinstance(pruner_kwargs, DictConfig):
@@ -338,7 +352,7 @@ def _objective(self, trial: optuna.Trial) -> float:
         )

         # Print progress
-        if self.tune_cfg.logging.get("verbose", True):
+        if getattr(self.tune_cfg.logging, "verbose", True):
             print(f"Trial {self.trial_count:3d}: {metric_name}={metric_value:.4f}")

         return metric_value
@@ -546,7 +560,7 @@ def _print_results(self, study: optuna.Study):
         for key, value in best_decoding_params.items():
             print(f" {key}: {value}")

-        if self.param_space_cfg.get("postprocessing", {}).get("enabled", False):
+        if getattr(self.param_space_cfg, "postprocessing", None) and getattr(self.param_space_cfg.postprocessing, "enabled", False):
             best_postproc_params = self._reconstruct_postproc_params(study.best_params)
             if best_postproc_params:
                 print(f"\n Post-processing params:")
@@ -662,8 +676,13 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
     print("\n[1/4] Running inference on tuning dataset...")

     # Get tune config sections (used later for loading predictions, ground truth, masks)
-    tune_data = cfg.tune.get("data", {})
-    tune_output = cfg.tune.get("output", {})
+    tune_data = getattr(cfg.tune, "data", None)
+    tune_output = getattr(cfg.tune, "output", None)
+
+    if tune_data is None:
+        raise ValueError("Missing tune.data in configuration")
+    if tune_output is None:
+        raise ValueError("Missing tune.output in configuration")

     # Create datamodule with tune mode (reads from cfg.tune.data)
     # Uses inference settings from cfg.inference (sliding window, TTA, save_predictions, etc.)
@@ -677,8 +696,8 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):

     # Step 2: Load predictions from saved files
     print("\n[2/4] Loading predictions from saved files...")
-    output_pred_dir = tune_output.get("output_pred", str(output_dir.parent / "results"))
-    cache_suffix = tune_output.get("cache_suffix", "_tta_prediction.h5")
+    output_pred_dir = getattr(tune_output, "output_pred", str(output_dir.parent / "results"))
+    cache_suffix = getattr(tune_output, "cache_suffix", "_tta_prediction.h5")
     predictions_dir = Path(output_pred_dir)

     # Find all prediction files using cache_suffix from config
@@ -711,13 +730,21 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):

     # Step 3: Load ground truth
     print("\n[3/4] Loading ground truth labels...")
-    tune_label_pattern = tune_data.get("tune_label", None)
+    tune_label_pattern = getattr(tune_data, "tune_label", None)

     if tune_label_pattern is None:
         raise ValueError("Missing tune.data.tune_label in configuration")

-    # Handle glob patterns (can match multiple files)
-    label_files = sorted(glob.glob(tune_label_pattern))
+    # Handle both string patterns and pre-resolved lists
+    if isinstance(tune_label_pattern, list):
+        # Already resolved to list of files
+        label_files = sorted(tune_label_pattern)
+    elif isinstance(tune_label_pattern, str):
+        # Glob pattern - expand it
+        label_files = sorted(glob.glob(tune_label_pattern))
+    else:
+        raise TypeError(f"tune_label must be string or list, got {type(tune_label_pattern)}")
+
     if not label_files:
         raise FileNotFoundError(f"No label files found matching pattern: {tune_label_pattern}")

@@ -740,10 +767,16 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):

     # Load mask if available
     mask = None
-    tune_mask_pattern = tune_data.get("tune_mask", None)
+    tune_mask_pattern = getattr(tune_data, "tune_mask", None)
     if tune_mask_pattern:
-        # Handle glob patterns
-        mask_files = sorted(glob.glob(tune_mask_pattern))
+        # Handle both string patterns and pre-resolved lists
+        if isinstance(tune_mask_pattern, list):
+            mask_files = sorted(tune_mask_pattern)
+        elif isinstance(tune_mask_pattern, str):
+            mask_files = sorted(glob.glob(tune_mask_pattern))
+        else:
+            raise TypeError(f"tune_mask must be string or list, got {type(tune_mask_pattern)}")
+
         if not mask_files:
             print(f" ⚠️ No mask files found matching pattern: {tune_mask_pattern}")
         else:
@@ -820,12 +853,11 @@ def load_and_apply_best_params(cfg):
     print(OmegaConf.to_yaml(best_params))

     # Apply to test.decoding config
-    # Note: test is Dict[str, Any], so we need to handle it carefully
     if cfg.test is None:
-        cfg.test = {}
+        cfg.test = OmegaConf.create({})

-    if "decoding" not in cfg.test:
-        cfg.test["decoding"] = []
+    if not hasattr(cfg.test, "decoding") or cfg.test.decoding is None:
+        cfg.test.decoding = []

     # Find the decoding function in test.decoding that matches the tuned function
     decoding_function = best_params.get("decoding_function", None)
@@ -836,24 +868,32 @@ def load_and_apply_best_params(cfg):
     else:
         # Find decoder with matching function name
         decoder_idx = None
-        for idx, decoder in enumerate(cfg.test["decoding"]):
-            if decoder.get("name") == decoding_function:
+        for idx, decoder in enumerate(cfg.test.decoding):
+            decoder_name = decoder.get("name") if isinstance(decoder, dict) else getattr(decoder, "name", None)
+            if decoder_name == decoding_function:
                 decoder_idx = idx
                 break

         if decoder_idx is None:
             # Create new decoder entry
-            decoder_idx = len(cfg.test["decoding"])
-            cfg.test["decoding"].append({"name": decoding_function, "kwargs": {}})
+            decoder_idx = len(cfg.test.decoding)
+            cfg.test.decoding.append({"name": decoding_function, "kwargs": {}})

         # Update parameters
-        if decoder_idx < len(cfg.test["decoding"]):
-            decoder = cfg.test["decoding"][decoder_idx]
-            if "kwargs" not in decoder:
-                decoder["kwargs"] = {}
-
-            # Apply best parameters
-            decoder["kwargs"].update(OmegaConf.to_container(best_params["parameters"]))
+        if decoder_idx < len(cfg.test.decoding):
+            decoder = cfg.test.decoding[decoder_idx]
+
+            # Handle both dict and config object
+            if isinstance(decoder, dict):
+                if "kwargs" not in decoder:
+                    decoder["kwargs"] = {}
+                decoder["kwargs"].update(OmegaConf.to_container(best_params["decoding_params"]))
+            else:
+                if not hasattr(decoder, "kwargs") or decoder.kwargs is None:
+                    decoder.kwargs = {}
+                # Update kwargs with best parameters
+                for key, value in best_params["decoding_params"].items():
+                    decoder.kwargs[key] = value

     print(f"✓ Applied best parameters to test.decoding[{decoder_idx}]")
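
The recurring change in this file replaces dict-style .get(...) with getattr(...), which reads a field the same way whether the config node is a dataclass-backed structured config or an OmegaConf DictConfig. A small sketch of the difference (the TuneConfig dataclass below is a hypothetical stand-in, not a class from this repository):

from dataclasses import dataclass
from typing import Optional

from omegaconf import OmegaConf

@dataclass
class TuneConfig:  # hypothetical stand-in for the structured tune config
    storage: Optional[str] = None

structured = TuneConfig(storage="sqlite:///study.db")
dict_cfg = OmegaConf.create({"storage": "sqlite:///study.db"})

print(getattr(structured, "storage", None))  # works: sqlite:///study.db
print(getattr(dict_cfg, "storage", None))    # works: sqlite:///study.db
print(dict_cfg.get("storage", None))         # .get() exists only on DictConfig
# structured.get("storage", None)            # AttributeError: dataclasses have no .get()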

connectomics/training/lit/config.py

Lines changed: 5 additions & 1 deletion
@@ -637,7 +637,11 @@ def setup(self, stage=None):
             )
         else:
             # Standard data module
-            use_cache = cfg.data.use_cache
+            # Disable caching for test/tune modes to avoid issues with partial cache returning 0 length
+            use_cache = cfg.data.use_cache and mode == "train"
+
+            if mode in ["test", "tune"] and cfg.data.use_cache:
+                print(" ⚠️ Caching disabled for test/tune mode (incompatible with partial cache)")

             # Note: transpose_axes is now handled in the transform builders (build_train/val/test_transforms)
             # which embed the transpose in LoadVolumed, so no need to pass it here

connectomics/training/lit/model.py

Lines changed: 53 additions & 13 deletions
@@ -214,9 +214,12 @@ def _invert_save_prediction_transform(self, data: np.ndarray) -> np.ndarray:
         data = data.astype(np.float32)

         # Invert the scaling if it was applied
-        if intensity_scale is not None and intensity_scale != 1.0:
+        # Note: intensity_scale < 0 means scaling was disabled, so no inversion needed
+        if intensity_scale is not None and intensity_scale > 0 and intensity_scale != 1.0:
             data = data / float(intensity_scale)
             print(f" 🔄 Inverted intensity scaling by {intensity_scale}")
+        elif intensity_scale is not None and intensity_scale < 0:
+            print(f" ℹ️ Intensity scaling was disabled (scale={intensity_scale}), no inversion needed")

         return data

@@ -296,14 +299,26 @@ def _compute_test_metrics(self, decoded_predictions: np.ndarray, labels: torch.T
         pred_tensor = torch.from_numpy(decoded_predictions).float().to(self.device)
         labels_tensor = labels.float()

+        # Remove batch and channel dimensions
         pred_tensor = pred_tensor.squeeze()
         labels_tensor = labels_tensor.squeeze()

-        if pred_tensor.ndim != labels_tensor.ndim:
-            if pred_tensor.ndim == labels_tensor.ndim - 1:
-                pred_tensor = pred_tensor.unsqueeze(0)
-            elif labels_tensor.ndim == pred_tensor.ndim - 1:
-                labels_tensor = labels_tensor.unsqueeze(0)
+        # Ensure both tensors have the same shape
+        if pred_tensor.shape != labels_tensor.shape:
+            print(f" ⚠️ Shape mismatch: pred={pred_tensor.shape}, labels={labels_tensor.shape}")
+
+            # Try to align dimensions
+            if pred_tensor.ndim != labels_tensor.ndim:
+                if pred_tensor.ndim == labels_tensor.ndim - 1:
+                    pred_tensor = pred_tensor.unsqueeze(0)
+                elif labels_tensor.ndim == pred_tensor.ndim - 1:
+                    labels_tensor = labels_tensor.unsqueeze(0)
+
+            # If still mismatched after dimension alignment, skip metrics
+            if pred_tensor.shape != labels_tensor.shape:
+                print(f" ❌ Cannot compute metrics: incompatible shapes after alignment")
+                print(f" pred={pred_tensor.shape}, labels={labels_tensor.shape}")
+                return

         if pred_tensor.max() <= 1.0:
             pred_binary = (pred_tensor > 0.5).long()
@@ -548,17 +563,42 @@ def configure_optimizers(self) -> Dict[str, Any]:
         """Configure optimizers and learning rate schedulers."""
         optimizer = build_optimizer(self.cfg, self.model)

-        # Build scheduler if configured
-        if hasattr(self.cfg, 'scheduler') and self.cfg.scheduler is not None:
+        # Build scheduler if configured (check both cfg.scheduler and cfg.optimization.scheduler)
+        has_scheduler = (
+            (hasattr(self.cfg, 'scheduler') and self.cfg.scheduler is not None) or
+            (hasattr(self.cfg, 'optimization') and hasattr(self.cfg.optimization, 'scheduler') and self.cfg.optimization.scheduler is not None)
+        )
+
+        if has_scheduler:
             scheduler = build_lr_scheduler(self.cfg, optimizer)

+            # Check if this is ReduceLROnPlateau (requires metric monitoring)
+            scheduler_config = {
+                'scheduler': scheduler,
+                'interval': 'epoch',
+                'frequency': 1,
+            }
+
+            # ReduceLROnPlateau requires the 'monitor' key to pass the metric value
+            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+                # Get monitor metric from scheduler config
+                monitor_metric = None
+                if hasattr(self.cfg, 'optimization') and hasattr(self.cfg.optimization, 'scheduler'):
+                    monitor_metric = getattr(self.cfg.optimization.scheduler, 'monitor', None)
+                elif hasattr(self.cfg, 'scheduler'):
+                    monitor_metric = getattr(self.cfg.scheduler, 'monitor', None)
+
+                if monitor_metric:
+                    scheduler_config['monitor'] = monitor_metric
+                    print(f" ✅ ReduceLROnPlateau will monitor: {monitor_metric}")
+                else:
+                    # Default to validation loss
+                    scheduler_config['monitor'] = 'val_loss_total'
+                    print(f" ⚠️ ReduceLROnPlateau will monitor: val_loss_total (default, no monitor specified in config)")
+
             return {
                 'optimizer': optimizer,
-                'lr_scheduler': {
-                    'scheduler': scheduler,
-                    'interval': 'epoch',
-                    'frequency': 1,
-                },
+                'lr_scheduler': scheduler_config,
             }
         else:
             return optimizer
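
For reference, the scheduler dictionary assembled above follows the standard PyTorch Lightning convention for configure_optimizers. A standalone sketch of that convention with placeholder model, optimizer, and metric name (not the project's actual classes; training and validation steps omitted for brevity):

import torch
import lightning.pytorch as pl  # or pytorch_lightning, depending on the install

class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
                # ReduceLROnPlateau needs a logged metric to track
                "monitor": "val_loss_total",
            },
        }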
