Commit f1b35ec

Merge pull request #179 from mansouralawi/neurons_axons_examples
Fixes neurons and axons examples
2 parents: 627fecd + bcf8310 · commit f1b35ec

9 files changed: +574 −776 lines changed

connectomics/data/dataset/dataset_base.py

Lines changed: 18 additions & 0 deletions
@@ -214,6 +214,24 @@ def __init__(
         self.dataset_length = len(data_dicts)
 
     def __len__(self) -> int:
+        """
+        Return dataset length.
+
+        For CacheDataset with cache_rate < 1.0, we must return the actual
+        number of cached items, not the requested iter_num, to avoid IndexError.
+        """
+        # If using partial caching, return the actual cached data length
+        # CacheDataset stores cached indices in self._cache
+        if hasattr(self, '_cache') and len(self._cache) < len(self.data):
+            # Partial caching: return cached length for validation
+            # For training with iter_num, we still want to iterate iter_num times
+            if self.mode == 'train' and self.iter_num > 0:
+                return self.dataset_length
+            else:
+                # For validation/test, only iterate over cached items
+                return len(self._cache)
+
+        # Full caching or no caching: use dataset_length
         return self.dataset_length

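For context, a minimal sketch of the partial-caching behavior this guard targets, assuming MONAI's CacheDataset (the private _cache list is the same attribute the patch inspects; the toy dataset and identity transform are illustrative only):

    from monai.data import CacheDataset
    from monai.transforms import Lambdad

    # Ten items, but cache_rate=0.5 caches only the first five at construction.
    data_dicts = [{"image": i} for i in range(10)]
    ds = CacheDataset(
        data=data_dicts,
        transform=Lambdad(keys="image", func=lambda x: x),
        cache_rate=0.5,
    )
    print(len(ds.data), len(ds._cache))  # -> 10 5

A subclass that reports iter_num as its length can then be asked for indices beyond the cached range, which is the IndexError scenario the docstring describes.
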
connectomics/decoding/optuna_tuner.py

Lines changed: 188 additions & 270 deletions
Large diffs are not rendered by default.

connectomics/training/lit/config.py

Lines changed: 5 additions & 1 deletion
@@ -638,7 +638,11 @@ def setup(self, stage=None):
             )
         else:
             # Standard data module
-            use_cache = cfg.data.use_cache
+            # Disable caching for test/tune modes to avoid issues with partial cache returning 0 length
+            use_cache = cfg.data.use_cache and mode == "train"
+
+            if mode in ["test", "tune"] and cfg.data.use_cache:
+                print(" ⚠️ Caching disabled for test/tune mode (incompatible with partial cache)")
 
         # Note: transpose_axes handled in transform builders (build_train/val/test_transforms)
         # They embed the transpose in LoadVolumed, so no need to pass it here

connectomics/training/lit/model.py

Lines changed: 143 additions & 415 deletions
Large diffs are not rendered by default.

install.py

Lines changed: 15 additions & 8 deletions
@@ -423,17 +423,22 @@ def install_pytorch_connectomics(
         print_success(f"Core packages installed: {', '.join(to_install)}")
     else:
         print_success("All core packages already installed")
-    print_info("Ensuring numpy and h5py are installed from conda-forge (force reinstall)...")
+
+    # CRITICAL: Reinstall cc3d to match current numpy version
+    # This prevents "numpy.dtype size changed" binary incompatibility errors
+    print_info("Reinstalling cc3d to match current numpy version...")
     code, _, stderr = run_command(
-        f"conda install -n {env_name} -c conda-forge numpy h5py -y --force-reinstall",
-        check=False,
+        f"conda run -n {env_name} pip uninstall -y connected-components-3d", check=False
+    )
+    code, _, stderr = run_command(
+        f"conda run -n {env_name} pip install --no-cache-dir connected-components-3d", check=False
     )
     if code != 0:
-        print_warning("conda reinstall of numpy/h5py failed; please verify the environment manually")
+        print_warning("Failed to reinstall cc3d; may have binary incompatibility issues")
         if stderr.strip():
             print_warning(stderr.strip())
     else:
-        print_success("numpy and h5py verified via conda-forge")
+        print_success("cc3d reinstalled successfully")
 
     # Group 2: Optional scientific packages (nice to have, but slow to install)
     optional_packages = ["scipy", "scikit-learn", "scikit-image", "opencv"]

@@ -507,10 +512,12 @@ def install_pytorch_connectomics(
     if pip_options:
         pip_cmd += f" {pip_options}"
 
-    code, _, stderr = run_command(f"{pip_cmd} --no-build-isolation", check=False)
+    # First try without --no-build-isolation to ensure dependencies are installed
+    print_info("Installing with full dependency resolution...")
+    code, _, stderr = run_command(pip_cmd, check=False)
     if code != 0:
-        print_warning("Installation with --no-build-isolation failed, retrying without it...")
-        code, _, stderr = run_command(pip_cmd, check=False)
+        print_warning("Standard installation failed, trying with --no-build-isolation...")
+        code, _, stderr = run_command(f"{pip_cmd} --no-build-isolation", check=False)
     if code != 0:
         print_error(f"Failed to install PyTorch Connectomics: {stderr}")
         return False

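As a sanity check after the reinstall, the import itself is the test: a cc3d wheel built against a mismatched numpy ABI fails at import time with the "numpy.dtype size changed" error the comment mentions. A minimal sketch (the toy volume is illustrative):

    import numpy as np
    import cc3d  # PyPI package: connected-components-3d

    # A clean import is the real test; labeling a toy volume confirms the binding works.
    volume = np.ones((4, 4, 4), dtype=np.uint8)
    labels = cc3d.connected_components(volume, connectivity=6)
    print(labels.max())  # -> 1 (a single connected component)
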
justfile

Lines changed: 3 additions & 3 deletions
@@ -88,17 +88,17 @@ train-cellmap dataset *ARGS='':
 # Shows all runs (timestamped directories) for comparison
 # Usage: just tensorboard experiment [port] (default port: 6006)
 tensorboard experiment port='6006':
-    tensorboard --logdir outputs/{{experiment}} --port {{port}}
+    tensorboard --logdir /orcd/scratch/bcs/002/mansour/zebrafish_seg_dataset_training/outputs/{{experiment}} --port {{port}}
 
 # Launch TensorBoard for all experiments
 # Usage: just tensorboard-all [port] (default port: 6006)
 tensorboard-all port='6006':
-    tensorboard --logdir outputs/ --port {{port}}
+    tensorboard --logdir /orcd/scratch/bcs/002/mansour/zebrafish_seg_dataset_training/outputs/ --port {{port}}
 
 # Launch TensorBoard for a specific run (e.g., just tensorboard-run lucchi_monai_unet 20250203_143052)
 # Usage: just tensorboard-run experiment timestamp [port] (default port: 6006)
 tensorboard-run experiment timestamp port='6006':
-    tensorboard --logdir outputs/{{experiment}}/{{timestamp}}/logs --port {{port}}
+    tensorboard --logdir /orcd/scratch/bcs/002/mansour/zebrafish_seg_dataset_training/outputs/{{experiment}}/{{timestamp}} --port {{port}}
 
 # Launch any just command on SLURM (e.g., just slurm short 8 4 "train lucchi")
 # Optional 5th parameter: GPU type (vr80g, vr40g, vr16g for V100s)

scripts/main.py

Lines changed: 2 additions & 2 deletions
@@ -250,8 +250,8 @@ def main():
 
     # Handle tune modes
     if args.mode in ["tune", "tune-test"]:
-        # Check if tune config exists (tune is TuneConfig dataclass)
-        if cfg.tune is None or cfg.tune.parameter_space is None:
+        # Check if tune config exists and has parameter_space
+        if cfg.tune is None or not hasattr(cfg.tune, "parameter_space"):
            raise ValueError("Missing tune or tune.parameter_space configuration")
 
     from connectomics.decoding import run_tuning

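One nuance of the new check, sketched with a stand-in dataclass (only the field name is taken from the diff; the real TuneConfig lives in the repo's config module):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class TuneConfig:  # hypothetical stand-in for the real config class
        parameter_space: Optional[dict] = None

    cfg_tune = TuneConfig()
    print(hasattr(cfg_tune, "parameter_space"))  # True: dataclass fields always exist
    print(cfg_tune.parameter_space is None)      # True: the field can still be unset

So the hasattr form accepts configs whose parameter_space is present but None, a case the old `is None` check rejected.
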
tutorials/monai_tsai.yaml

Lines changed: 40 additions & 26 deletions
@@ -23,7 +23,7 @@ description: 3D axon segmentation using MONAI Residual UNet with paired data tra
 # System - Optimized for 2D training
 system:
   training:
-    num_gpus: 4  # Single GPU
+    num_gpus: 1  # Single GPU per task (SLURM handles multi-GPU via DDP)
     num_cpus: 8  # Increase for better data loading
     num_workers: 8  # Parallel data loading (2D slices are lighter)
     batch_size: 8  # Higher batch size for 2D (vs 4 for 3D)

@@ -50,25 +50,32 @@ model:
   dropout: 0.1  # Dropout for regularization
 
   # Loss configuration - Dice for overlap, BCE for pixel-wise accuracy
-  loss_functions: [WeightedBCEWithLogitsLoss, DiceLoss]
+  loss_functions: [WeightedBCE, DiceLoss]
   loss_weights: [1.0, 1.0]  # Equal weighting for BCE and Dice
   loss_kwargs:
-    - {reduction: mean}  # WeightedBCEWithLogitsLoss: average over batch
-    - {include_background: true, sigmoid: true, smooth_nr: 1e-5, smooth_dr: 1e-5}  # DiceLoss with sigmoid
-
-# Data - Using automatic 80/20 train/val split (DeepEM-style)
-data:
-  # Volume configuration
-  train_image: datasets/axon_data_30pc_subset/training/training-original/volumes/*.tiff
-  train_label: datasets/axon_data_30pc_subset/training/training-original/labels/*.tiff
-  train_resolution: [5, 5]  # Lucchi EM: 5nm isotropic resolution
+    - {reduction: mean}
+    - {include_background: true, sigmoid: true, smooth_nr: 1e-5, smooth_dr: 1e-5}
+
+# Data - Separate training and validation datasets
+data:
+  # Training data
+  train_image: /orcd/scratch/bcs/002/mansour/trailmap_data/training/training-original/volumes/*.tiff
+  train_label: /orcd/scratch/bcs/002/mansour/trailmap_data/training/training-original/labels-original-backup/*.tiff
+  train_resolution: [2, 0.8, 0.8]  # Resolution: z, y, x (applies to both train and val)
+
+  # Validation data (separate from training)
+  val_image: /orcd/scratch/bcs/002/mansour/trailmap_data/validation/validation-original/volumes/*.tiff
+  val_label: /orcd/scratch/bcs/002/mansour/trailmap_data/validation/validation-original/labels-original-backup/*.tiff
 
   use_preloaded_cache: true  # Load volumes into memory for fast training
+  # train_val_split: 0.8  # Not needed when using separate val_image/val_label
 
   # Patch configuration
   patch_size: [64, 64, 64]  # Larger patches for better context
   pad_size: [0, 0, 0]  # Padding for valid convolutions
   pad_mode: reflect  # Reflection padding at boundaries
-  iter_num_per_epoch: 1280  # 1280 random crops per epoch
+  iter_num_per_epoch: 1280  # 1280 random crops per epoch (training)
+

@@ -105,6 +112,7 @@ data:
 # Optimizer - AdamW with optimized hyperparameters
 optimization:
   max_epochs: 1000
+  val_check_interval: 1.0
   gradient_clip_val: 1.0  # Higher clip (0.5 was too aggressive)
   accumulate_grad_batches: 1
   precision: "bf16-mixed"  # BFloat16 mixed precision

@@ -116,9 +124,10 @@ optimization:
     betas: [0.9, 0.999]  # Standard Adam betas (momentum terms)
     eps: 1.0e-8  # Numerical stability
 
-  # Scheduler - Cosine annealing with warmup for smooth convergence
+  # Scheduler - Reduce LR when validation loss plateaus
   scheduler:
     name: ReduceLROnPlateau  # Reduce LR when validation loss plateaus
+    monitor: val_loss_total  # Monitor validation loss
     mode: min  # Monitor minimum loss
     factor: 0.5  # Reduce LR by 50%
     patience: 50  # Wait 50 epochs before reducing

@@ -147,18 +156,19 @@ monitor:
 
   # Checkpointing
   checkpoint:
+    monitor: val_loss_total  # Save best model based on validation loss
     mode: min
     save_top_k: 1
     save_last: true
     save_every_n_epochs: 10
-    dirpath: outputs/monai_tsai/checkpoints/  # Will be dynamically set to outputs/{yaml_filename}/YYYYMMDD_HHMMSS/checkpoints/
+    dirpath: /orcd/scratch/bcs/002/mansour/trailmap_data/outputs/monai_tsai/checkpoints/  # Will be dynamically set to outputs/{yaml_filename}/YYYYMMDD_HHMMSS/checkpoints/
     # checkpoint_filename: auto-generated from monitor metric (epoch={epoch:03d}-{monitor}={value:.4f})
     use_timestamp: true  # Enable timestamped subdirectories (YYYYMMDD_HHMMSS)
 
   # Early stopping - More patient for better convergence
-  early_stopping:
+  early_stopping:
     enabled: true
-    monitor: train_loss_total_epoch
+    monitor: val_loss_total  # Monitor validation loss
     patience: 300  # Increased patience (was 200)
     mode: min
     min_delta: 1.0e-5  # Smaller threshold for finer convergence

@@ -169,10 +179,10 @@ monitor:
 # Inference - MONAI SlidingWindowInferer
 inference:
   data:
-    test_image: datasets/axon_data_30pc_subset/validation/validation-original/volumes/*.tiff
-    test_label: datasets/axon_data_30pc_subset/validation/validation-original/labels/*.tiff
-    test_resolution: [5, 5]
-    output_path: outputs/monai_tsai/results/
+    test_image: /orcd/scratch/bcs/002/mansour/trailmap_data/testing/testing-original/volumes/*.tiff
+    test_label: /orcd/scratch/bcs/002/mansour/trailmap_data/testing/testing-original/labels-original-backup/*.tiff
+    test_resolution: [2, 0.8, 0.8]
+    output_path: /orcd/scratch/bcs/002/mansour/trailmap_data/outputs/monai_tsai/results/
 
   # MONAI SlidingWindowInferer parameters
   sliding_window:

@@ -195,11 +205,6 @@ inference:
   # NOTE: tta_act and tta_channel are applied even with null flip_axes (no ensemble, just activation + channel selection)
   # NOTE: If tta_channel selects specific channels, loss computation will be skipped (loss needs all class channels)
 
-  # Save intermediate predictions (before decoding/postprocessing)
-  save_prediction:
-    enabled: true
-    intensity_scale: 255  # Scale predictions to [0, 255] for saving
-    intensity_dtype: uint8  # Save as uint8
 
   # Decoding: predicted feature maps to segmentation mask (semantic or instance segmentation)
   decoding:

@@ -210,7 +215,16 @@ inference:
     connected_components:
       enabled: true  # Enable connected components filtering
       remove_small: 10  # Remove small objects with size less than 10 pixels
-      connectivity: 26  # Face connectivity (4=4-connected for 2D, 6=6-connected for 3D)
+      connectivity: 6  # Face connectivity (4=4-connected for 2D, 6=6-connected for 3D)
+
+
+  # Postprocessing configuration (applied AFTER decoding)
+  postprocessing:
+
+    # Output format (intensity scaling and dtype conversion)
+    intensity_scale: 255  # Scale predictions to [0, 255] for saving
+    intensity_dtype: uint8  # Save as uint8
+
 
 # Evaluation
 evaluation:

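For reference, a minimal sketch of the MONAI SlidingWindowInferer that the inference block configures; the roi_size mirrors the config's 64x64x64 patch size, while the network, overlap, and batch values are illustrative stand-ins:

    import torch
    from monai.inferers import SlidingWindowInferer
    from monai.networks.nets import UNet

    # Stand-in residual UNet (num_res_units=2 enables residual blocks).
    model = UNet(spatial_dims=3, in_channels=1, out_channels=1,
                 channels=(16, 32, 64), strides=(2, 2), num_res_units=2).eval()
    inferer = SlidingWindowInferer(roi_size=(64, 64, 64), sw_batch_size=4, overlap=0.25)

    with torch.no_grad():
        volume = torch.zeros(1, 1, 128, 128, 128)  # (batch, channel, z, y, x)
        prediction = inferer(inputs=volume, network=model)
    print(prediction.shape)  # -> torch.Size([1, 1, 128, 128, 128])

The inferer tiles the full volume into overlapping 64-cubed windows, runs the network on each, and blends the overlaps back into a full-size prediction.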