Skip to content

Commit 627fecd

Browse files
author
Donglai Wei
committed
cellmap slurm training
1 parent 18ff577 commit 627fecd

File tree

2 files changed

+76
-3
lines changed

2 files changed

+76
-3
lines changed

justfile

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ test-with-params dataset ckpt params *ARGS='':
7676
infer dataset ckpt *ARGS='':
7777
python scripts/main.py --config tutorials/{{dataset}}.yaml --mode infer --checkpoint {{ckpt}} {{ARGS}}
7878

79+
# Train CellMap models (e.g., just train-cellmap cos7)
80+
train-cellmap dataset *ARGS='':
81+
python scripts/cellmap/train_cellmap.py --config tutorials/cellmap_{{dataset}}.yaml {{ARGS}}
82+
7983
# ============================================================================
8084
# Monitoring Commands
8185
# ============================================================================
@@ -96,8 +100,14 @@ tensorboard-all port='6006':
96100
tensorboard-run experiment timestamp port='6006':
97101
tensorboard --logdir outputs/{{experiment}}/{{timestamp}}/logs --port {{port}}
98102

99-
# Launch any just command on SLURM (e.g., just slurm weilab 8 4 "train lucchi")
103+
# Launch any just command on SLURM (e.g., just slurm short 8 4 "train lucchi")
104+
# Optional 5th parameter: GPU type constraint (vr80g/vr40g for A100s, vr16g for V100s)
105+
# Examples:
106+
# just slurm short 8 4 "train lucchi" # Any available GPU
107+
# just slurm short 8 4 "train lucchi" vr80g # A100 80GB
108+
# just slurm short 8 4 "train lucchi" vr40g # A100 40GB
100109
# Automatically uses srun for distributed training when num_gpu > 1
110+
# Time limits: short=12h, medium=2d, long=5d
101111
slurm partition num_cpu num_gpu cmd constraint='':
102112
#!/usr/bin/env bash
103113
# Configure for multi-GPU training with PyTorch Lightning DDP
@@ -109,6 +119,28 @@ slurm partition num_cpu num_gpu cmd constraint='':
109119
constraint_flag="--constraint={{constraint}}"
110120
fi
111121

122+
# Set time limit to partition maximum
123+
# Query SLURM for the partition's max time limit
124+
time_limit=$(sinfo -p {{partition}} -h -o "%l" | head -1)
125+
126+
# If sinfo fails or returns empty, use safe defaults
127+
if [ -z "$time_limit" ] || [ "$time_limit" = "infinite" ]; then
128+
case "{{partition}}" in
129+
short|interactive)
130+
time_limit="12:00:00"
131+
;;
132+
medium)
133+
time_limit="2-00:00:00"
134+
;;
135+
long)
136+
time_limit="5-00:00:00"
137+
;;
138+
*)
139+
time_limit="7-00:00:00" # 7 days for private partitions
140+
;;
141+
esac
142+
fi
143+
112144
sbatch --job-name="pytc_{{cmd}}" \
113145
--partition={{partition}} \
114146
--output=slurm_outputs/slurm-%j.out \
@@ -118,13 +150,54 @@ slurm partition num_cpu num_gpu cmd constraint='':
118150
--gpus-per-node={{num_gpu}} \
119151
--cpus-per-task={{num_cpu}} \
120152
--mem=32G \
121-
--time=48:00:00 \
153+
--time=$time_limit \
122154
$constraint_flag \
123155
--wrap="mkdir -p \$HOME/.just && export JUST_TEMPDIR=\$HOME/.just TMPDIR=\$HOME/.just && source /projects/weilab/weidf/lib/miniconda3/bin/activate pytc && cd $PWD && srun --ntasks={{num_gpu}} --ntasks-per-node={{num_gpu}} just {{cmd}}"
124156

125157
# Launch parameter sweep from config (e.g., just sweep tutorials/sweep_example.yaml)
126158
sweep config:
127159
python scripts/slurm_launcher.py --config {{config}}
160+
161+
# Launch arbitrary shell command on SLURM (e.g., just slurm-sh short 8 4 "python train.py" vr40g)
162+
# Unlike 'slurm', this runs the command directly without wrapping in 'just'
163+
slurm-sh partition num_cpu num_gpu cmd constraint='':
164+
#!/usr/bin/env bash
165+
constraint_flag=""
166+
if [ -n "{{constraint}}" ]; then
167+
constraint_flag="--constraint={{constraint}}"
168+
fi
169+
170+
# Set time limit to partition maximum
171+
time_limit=$(sinfo -p {{partition}} -h -o "%l" | head -1)
172+
if [ -z "$time_limit" ] || [ "$time_limit" = "infinite" ]; then
173+
case "{{partition}}" in
174+
short|interactive)
175+
time_limit="12:00:00"
176+
;;
177+
medium)
178+
time_limit="2-00:00:00"
179+
;;
180+
long)
181+
time_limit="5-00:00:00"
182+
;;
183+
*)
184+
time_limit="7-00:00:00"
185+
;;
186+
esac
187+
fi
188+
189+
sbatch --job-name="pytc_{{cmd}}" \
190+
--partition={{partition}} \
191+
--output=slurm_outputs/slurm-%j.out \
192+
--error=slurm_outputs/slurm-%j.err \
193+
--nodes=1 \
194+
--ntasks={{num_gpu}} \
195+
--gpus-per-node={{num_gpu}} \
196+
--cpus-per-task={{num_cpu}} \
197+
--mem=32G \
198+
--time=$time_limit \
199+
$constraint_flag \
200+
--wrap="source /projects/weilab/weidf/lib/miniconda3/bin/activate pytc && cd $PWD && srun --ntasks={{num_gpu}} --ntasks-per-node={{num_gpu}} {{cmd}}"
128201
# ============================================================================
129202
# Visualization Commands
130203
# ============================================================================

tutorials/cellmap_cos7.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ system:
2727
num_gpus: 4
2828
num_cpus: 8
2929
num_workers: 8
30-
batch_size: 8 # Per GPU batch size
30+
batch_size: 8 # Per GPU (effective = 32)
3131
inference:
3232
num_gpus: 1
3333
num_cpus: 1

0 commit comments

Comments
 (0)