# ImgX-DiffSeg

ImgX-DiffSeg is a Jax-based deep learning toolkit using Flax for biomedical image segmentation.

This repository includes the implementation of the following work:

- [A Recycling Training Strategy for Medical Image Segmentation with Diffusion Denoising Models](https://arxiv.org/abs/2308.16355)
- [Importance of Aligning Training Strategy with Evaluation for Diffusion Models in 3D Multiclass Segmentation](https://arxiv.org/abs/2303.06040)

<div>
<img src="images/diffusion_training_strategy_diagram.png" width="600" alt="diffusion_training_strategy_diagram">
</div>

## Features

Current supported functionalities are summarized as follows.

**Data sets**

See the [readme](imgx_datasets/README.md) for further details.

- Muscle ultrasound from [Marzola et al. 2021](https://data.mendeley.com/datasets/3jykz7wz8d/1).
- Male pelvic MR from [Li et al. 2022](https://zenodo.org/record/7013610#.Y1U95-zMKrM).
- AMOS CT from [Ji et al. 2022](https://zenodo.org/record/7155725#.ZAN4BuzP2rO).
- Brain MR from [Baid et al. 2021](https://arxiv.org/abs/2107.02314).

**Algorithms**

- Supervised segmentation.
- Diffusion-based segmentation.
  - [Gaussian noise based diffusion](https://arxiv.org/abs/2211.00611).
  - Noise prediction ([epsilon-parameterization](https://arxiv.org/abs/2006.11239)) or ground truth prediction ([x0-parameterization](https://arxiv.org/abs/2102.09672)).
  - [Importance sampling](https://arxiv.org/abs/2102.09672) for timesteps.
  - Recycling training strategies, including [xt-recycling](https://arxiv.org/abs/2303.06040) and [xT-recycling](https://arxiv.org/abs/2308.16355).
  - Self-conditioning training strategies, including [Chen et al. 2022](https://arxiv.org/abs/2208.04202) and [Watson et al. 2023](https://www.nature.com/articles/s41586-023-06415-8).
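The recycling idea can be sketched in a few lines of NumPy. This is a rough illustration only: `model`, the noise schedule values, and the squared-error loss below are made-up stand-ins, not the toolkit's API, and a real training step would stop gradients through the first prediction.

```python
import numpy as np

rng = np.random.default_rng(0)

def add_noise(x0, eps, alpha_bar):
    """Forward diffusion q(x_t | x_0) for a cumulative schedule value alpha_bar."""
    return np.sqrt(alpha_bar) * x0 + np.sqrt(1.0 - alpha_bar) * eps

def model(x_t, cond):
    """Stand-in denoiser: a real network predicts x_0 from the noisy input
    and an optional conditioning segmentation (here just a toy function)."""
    return np.tanh(x_t + cond)

def recycling_loss(x0, alpha_bar_t, alpha_bar_T):
    """One recycling training step (sketch): obtain a first x_0 prediction
    from a maximally noised input, re-noise that prediction to timestep t,
    and train the second prediction against the ground truth."""
    eps = rng.standard_normal(x0.shape)
    x_T = add_noise(x0, eps, alpha_bar_T)          # close to pure noise
    x0_first = model(x_T, cond=np.zeros_like(x0))  # first pass, no gradient in practice
    x_t = add_noise(x0_first, eps, alpha_bar_t)    # recycle the prediction
    x0_second = model(x_t, cond=x0_first)          # trained pass
    return np.mean((x0_second - x0) ** 2)          # a segmentation loss in practice

loss = recycling_loss(rng.standard_normal((2, 4, 4)), alpha_bar_t=0.5, alpha_bar_T=1e-4)
```

Self-conditioning differs mainly in how the first prediction is obtained (from the same noisy input rather than from a re-noised one); see the linked papers for the exact formulations.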

**Models**

- [U-Net](https://arxiv.org/abs/1505.04597) with [Transformers](https://arxiv.org/abs/1706.03762) supporting 2D and 3D images.
- [Efficient attention](https://arxiv.org/abs/2112.05682).
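The core trick behind memory-efficient attention can be sketched with NumPy: process key/value chunks with a running softmax maximum so the full attention matrix is never materialized. This is a minimal sketch of the technique, not the toolkit's implementation.

```python
import numpy as np

def chunked_attention(q, k, v, chunk_size=2):
    """Memory-efficient attention (sketch): combine partial softmax results
    over key/value chunks using a running maximum, so the full
    (len_q, len_k) score matrix is never materialized."""
    m = np.full(q.shape[0], -np.inf)  # running max of logits per query
    denom = np.zeros(q.shape[0])      # running softmax normalizer
    acc = np.zeros_like(q)            # running weighted sum of values
    for start in range(0, k.shape[0], chunk_size):
        kc, vc = k[start:start + chunk_size], v[start:start + chunk_size]
        s = q @ kc.T / np.sqrt(q.shape[-1])   # scores for this chunk only
        m_new = np.maximum(m, s.max(axis=-1))
        scale = np.exp(m - m_new)             # rescale previous accumulators
        p = np.exp(s - m_new[:, None])
        denom = denom * scale + p.sum(axis=-1)
        acc = acc * scale[:, None] + p @ vc
        m = m_new
    return acc / denom[:, None]

def naive_attention(q, k, v):
    """Reference implementation materializing the full score matrix."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

rng = np.random.default_rng(0)
q, k, v = [rng.standard_normal((5, 8)) for _ in range(3)]
print(np.allclose(chunked_attention(q, k, v), naive_attention(q, k, v)))  # → True
```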

**Training**

- Patch-based training.
- Multi-device training (one model per device) with [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html).
- Mixed precision training.
- Gradient clipping and accumulation.
- [Early stopping](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html).
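The one-model-per-device pattern with `pmap` can be sketched as below. The linear model, learning rate, and shapes are toy choices for illustration, not the toolkit's training loop; each replica computes gradients on its own shard and synchronizes them with `pmean`.

```python
import functools

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

def loss_fn(params, x, y):
    """Toy loss: mean squared error of a linear model."""
    return jnp.mean((x @ params - y) ** 2)

@functools.partial(jax.pmap, axis_name="batch")
def update(params, x, y):
    """One SGD step, replicated across devices."""
    grads = jax.grad(loss_fn)(params, x, y)
    # average gradients across devices so every replica stays in sync
    grads = jax.lax.pmean(grads, axis_name="batch")
    return params - 0.1 * grads

# replicate parameters once per device; shard the batch along the leading axis
params = jax.device_put_replicated(jnp.zeros((3,)), jax.local_devices())
x = jnp.ones((n_dev, 8, 3))
y = jnp.ones((n_dev, 8))
params = update(params, x, y)
print(params.shape)  # (n_dev, 3)
```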

**Changelog**

- October 2023: Migrated from [Haiku](https://github.com/google-deepmind/dm-haiku) to [Flax](https://github.com/google/flax) following Google DeepMind's recommendation.

## Installation

### TPU with Docker

The following instructions have been tested only for TPU-v3-8. The docker container uses the root user.

1. TPU often has limited disk space.
   [RAM disk](https://www.linuxbabe.com/command-line/create-ramdisk-linux) can be used to help.

   ```bash
   sudo mkdir /tmp/ramdisk
   sudo chmod 777 /tmp/ramdisk
   sudo mount -t tmpfs -o size=256G imgxramdisk /tmp/ramdisk
   cd /tmp/ramdisk/
   ```

2. Build the docker image inside the repository.

   ```bash
   sudo docker build --build-arg USER_ID=$(id -u) --build-arg GROUP_ID=$(id -g) -f docker/Dockerfile.tpu -t imgx .
   ```

   - `-f` provides the docker file.
   - `-t` tags the docker image.

3. Run the Docker container.

   ```bash
   mkdir -p $(cd ../ && pwd)/tensorflow_datasets
   imgx bash
   ```

4. Install the package inside the container.

   ```bash
   make pip
   ```

### GPU with Docker

The following instructions have been tested only for CUDA == 11.4.1 and CUDNN == 8.2.0. The docker container uses a non-root user.
[The Docker image used may be removed.](https://gitlab.com/nvidia/container-images/cuda/blob/master/doc/support-policy.md)

1. Build the docker image inside the repository.

#### Install Conda for Linux / Mac Intel

[Install Conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html) and then create the environment.

```bash
conda install -y -n base conda-libmamba-solver
```

## Build Data Sets

Use the following commands to (re)build all data sets. Check the [README](imgx_datasets/README.md) of imgx_datasets for details. In particular, manual downloading is required for the BraTS 2021 dataset.

```bash
make build_dataset
make rebuild_dataset
```

### Training and Testing

Example commands to use two GPUs for training, validation and testing. The outputs are stored under `wandb/latest-run/files/`, where

- `ckpt` stores the model checkpoints and corresponding validation metrics.
- `test_evaluation` stores the predictions on the test set and corresponding metrics.

```bash
# limit to two GPUs if using NVIDIA GPUs
export CUDA_VISIBLE_DEVICES="0,1"

# select data set to use
export DATASET_NAME="male_pelvic_mr"
export DATASET_NAME="amos_ct"

imgx_valid --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDIM
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDIM
```
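The `--sampler DDIM --num_timesteps 5` options run a short deterministic sampling trajectory. A minimal NumPy sketch of an x0-parameterized DDIM step is shown below; the `tanh` model and the linear noise schedule are toy stand-ins, not the toolkit's actual sampler.

```python
import numpy as np

def ddim_step(x_t, x0_pred, alpha_bar_t, alpha_bar_prev):
    """Deterministic DDIM update (eta = 0): re-derive the implied noise from
    the x_0 prediction, then jump to the less-noisy timestep."""
    eps = (x_t - np.sqrt(alpha_bar_t) * x0_pred) / np.sqrt(1.0 - alpha_bar_t)
    return np.sqrt(alpha_bar_prev) * x0_pred + np.sqrt(1.0 - alpha_bar_prev) * eps

def sample(model, shape, alpha_bars, rng):
    """Run a short trajectory, e.g. 5 steps as in the commands above."""
    x = rng.standard_normal(shape)               # start from pure noise
    for t in range(len(alpha_bars) - 1, 0, -1):  # t = T-1, ..., 1
        x = ddim_step(x, model(x), alpha_bars[t], alpha_bars[t - 1])
    return x

rng = np.random.default_rng(0)
alpha_bars = np.linspace(0.99, 0.01, 5)  # toy schedule: less noise at small t
mask_logits = sample(lambda z: np.tanh(z), (2, 3), alpha_bars, rng)
```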

Optionally, for debug purposes, use the flag `debug=True` to run the experiment with a small dataset and smaller models.

```bash
imgx_train --config-name config_${DATASET_NAME}_seg debug=True
imgx_test --log_dir wandb/latest-run/
imgx_train --config-name config_${DATASET_NAME}_diff_seg debug=True
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDPM
```

## Code Quality

### Code Test

Run the commands below to test and get a coverage report. As JAX tests require two CPUs, `-n 4` uses 4 threads and therefore requires 8 CPUs in total.

```bash
pytest --cov=imgx -n 4 imgx
pytest --cov=imgx_datasets -n 4 imgx_datasets
```

## References
- [Scenic (JAX)](https://github.com/google-research/scenic/)
- [DeepMind Research (JAX)](https://github.com/deepmind/deepmind-research/tree/master/ogb_lsc/)
- [Haiku (JAX)](https://github.com/deepmind/dm-haiku/)
- [Flax (JAX)](https://github.com/google/flax)

## Acknowledgement

This work was supported by the EPSRC grant (EP/T029404/1), the Wellcome/EPSRC Centre for Interventional and Surgical Sciences (203145Z/16/Z), the International Alliance for Cancer Early Detection, an alliance between Cancer Research UK (C28070/A30912, C73666/A31378), Canary Center at Stanford University, the University of Cambridge, OHSU Knight Cancer Institute, University College London and the University of Manchester, and Cloud TPUs from Google's TPU Research Cloud (TRC).

## Citation

If you find the code base and method useful in your research, please cite the relevant paper:

```bibtex
@article{fu2023recycling,
  title={A Recycling Training Strategy for Medical Image Segmentation with Diffusion Denoising Models},
  author={Fu, Yunguan and others},
  journal={arXiv preprint arXiv:2308.16355},
  year={2023}
}
```