Commit 0aa88c3
Migrate to Flax (#17)
1 parent 4dcac5f commit 0aa88c3

162 files changed: +5404 additions, −8372 deletions

.github/workflows/unit-test.yml

Lines changed: 4 additions & 6 deletions
```diff
@@ -32,12 +32,10 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install tensorflow-cpu==2.12.0
-        pip install jax==0.4.8
-        pip install jaxlib==0.4.7
+        pip install jax==0.4.14
+        pip install jaxlib==0.4.14
         pip install -r docker/requirements.txt
-        pip install -e imgx
-        pip install -e imgx_datasets
+        pip install -e .
     - name: Test with pytest
       run: |
-        pytest --splits 4 --group ${{ matrix.group }} --randomly-seed=0 -k "not slow" imgx/tests/unit
-        pytest --splits 4 --group ${{ matrix.group }} --randomly-seed=0 -k "not slow" imgx_datasets/tests
+        pytest --splits 4 --group ${{ matrix.group }} --randomly-seed=0 -k "not slow"
```
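The `--splits`/`--group` flags in the workflow come from the pytest-split plugin, which partitions the test suite across CI workers so each matrix job runs only its share. A minimal sketch of the partitioning idea (round-robin here for illustration; the real plugin can also split by recorded test durations, and `split_tests` is a hypothetical name, not the plugin's API):

```python
def split_tests(tests, splits, group):
    """Return the subset of `tests` for 1-indexed `group` out of `splits` groups."""
    # Round-robin assignment: group 1 gets tests 0, splits, 2*splits, ...
    return tests[group - 1 :: splits]


tests = [f"test_{i}" for i in range(10)]
groups = [split_tests(tests, 4, g) for g in range(1, 5)]
# Every test lands in exactly one of the 4 groups.
assert sorted(t for g in groups for t in g) == sorted(tests)
print(groups[0])  # first worker's share: ['test_0', 'test_4', 'test_8']
```

With `--randomly-seed=0` pinned as in the workflow, every worker sees the same test ordering, so the groups are disjoint and together cover the whole suite.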

.isort.cfg

Lines changed: 0 additions & 7 deletions
This file was deleted.

.pre-commit-config.yaml

Lines changed: 8 additions & 8 deletions
```diff
@@ -2,7 +2,7 @@ default_language_version:
   python: python3
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v4.5.0
     hooks:
       - id: check-added-large-files
       - id: check-ast
@@ -27,13 +27,13 @@ repos:
     hooks:
       - id: isort
   - repo: https://github.com/psf/black
-    rev: 23.7.0
+    rev: 23.10.0
     hooks:
       - id: black
         args:
-          - --line-length=80
+          - --line-length=100
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.4.1
+    rev: v1.6.1
     hooks: # https://github.com/python/mypy/issues/4008#issuecomment-582458665
       - id: mypy
         name: mypy-imgx
@@ -72,23 +72,23 @@ repos:
           --warn-unreachable,
         ]
   - repo: https://github.com/pre-commit/mirrors-prettier
-    rev: v3.0.0
+    rev: v3.0.3
     hooks:
       - id: prettier
         args:
-          - --print-width=80
+          - --print-width=100
           - --prose-wrap=always
           - --tab-width=2
   - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: "v0.0.280"
+    rev: "v0.1.1"
     hooks:
       - id: ruff
   - repo: https://github.com/pre-commit/mirrors-pylint
     rev: v3.0.0a5
     hooks:
       - id: pylint
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.9.0
+    rev: v3.15.0
     hooks:
       - id: pyupgrade
         args:
```

.pylintrc

Lines changed: 1 addition & 1 deletion
```diff
@@ -173,7 +173,7 @@ generated-members=
 [FORMAT]

 # Maximum number of characters on a single line.
-max-line-length=80
+max-line-length=100

 # TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
 # lines made too long by directives to pytype.
```

Makefile

Lines changed: 7 additions & 14 deletions
```diff
@@ -1,19 +1,12 @@
 pip:
-	pip install -e imgx
-	pip install -e imgx_datasets
+	pip install -e .

 test:
-	pytest --cov=imgx -n 4 imgx/tests -x
-	pytest --cov=imgx_datasets -n 4 imgx_datasets/tests -x
+	pytest --cov=imgx -n 4 imgx
+	pytest --cov=imgx_datasets -n 4 imgx_datasets

 build_dataset:
-	tfds build imgx_datasets/imgx_datasets/male_pelvic_mr &
-	tfds build imgx_datasets/imgx_datasets/amos_ct &
-	tfds build imgx_datasets/imgx_datasets/muscle_us &
-	tfds build imgx_datasets/imgx_datasets/brats2021_mr &
-
-rebuild_dataset:
-	tfds build imgx_datasets/imgx_datasets/male_pelvic_mr --overwrite &
-	tfds build imgx_datasets/imgx_datasets/amos_ct --overwrite &
-	tfds build imgx_datasets/imgx_datasets/muscle_us --overwrite &
-	tfds build imgx_datasets/imgx_datasets/brats2021_mr --overwrite &
+	tfds build imgx_datasets/male_pelvic_mr
+	tfds build imgx_datasets/amos_ct
+	tfds build imgx_datasets/muscle_us
+	tfds build imgx_datasets/brats2021_mr
```

README.md

Lines changed: 74 additions & 76 deletions
````diff
@@ -1,68 +1,80 @@
-# A Recycling Training Strategy for Medical Image Segmentation with Diffusion Denoising Models
+# ImgX-DiffSeg

-:tada: This is a follow-up work of Importance of Aligning Training Strategy with
-Evaluation for Diffusion Models in 3D Multiclass Segmentation
-([paper](https://arxiv.org/abs/2303.06040),
-[code](https://github.com/mathpluscode/ImgX-DiffSeg/tree/v0.1.0)), with better
-recycling method, better network, more baseline training methods (including
-self-conditioning) on four data sets (muscle ultrasound, male pelvic MR,
-abdominal CT, brain MR).
+ImgX-DiffSeg is a Jax-based deep learning toolkit using Flax for biomedical image segmentations.

-:bookmark_tabs: The preprint is available on
-[arXiv](https://arxiv.org/abs/2308.16355).
+This repository includes the implementation of the following works:
+
+- [A Recycling Training Strategy for Medical Image Segmentation with Diffusion Denoising Models](https://arxiv.org/abs/2308.16355)
+- [Importance of Aligning Training Strategy with Evaluation for Diffusion Models in 3D Multiclass Segmentation](https://arxiv.org/abs/2303.06040)

 <div>
 <img src="images/diffusion_training_strategy_diagram.png" width="600" alt="diffusion_training_strategy_diagram"></img>
 </div>

----
-
-ImgX is a Jax-based deep learning toolkit for biomedical image segmentations.
+## Features

 Current supported functionalities are summarized as follows.

 **Data sets**

-See the [readme](imgx_datasets/README.md) for details on training, validation,
-and test splits.
+See the [readme](imgx_datasets/README.md) for further details.

-- [x] Muscle ultrasound from
-  [Marzola et al. 2021](https://data.mendeley.com/datasets/3jykz7wz8d/1).
-- [x] Male pelvic MR from
-  [Li et al. 2022](https://zenodo.org/record/7013610#.Y1U95-zMKrM).
-- [x] AMOS CT from
-  [Ji et al. 2022](https://zenodo.org/record/7155725#.ZAN4BuzP2rO).
-- [x] Brain MR from [Baid et al. 2021](https://arxiv.org/abs/2107.02314).
+- Muscle ultrasound from [Marzola et al. 2021](https://data.mendeley.com/datasets/3jykz7wz8d/1).
+- Male pelvic MR from [Li et al. 2022](https://zenodo.org/record/7013610#.Y1U95-zMKrM).
+- AMOS CT from [Ji et al. 2022](https://zenodo.org/record/7155725#.ZAN4BuzP2rO).
+- Brain MR from [Baid et al. 2021](https://arxiv.org/abs/2107.02314).

 **Algorithms**

-- [x] Supervised segmentation.
-- [x] Diffusion-based segmentation.
-  - [x] Gaussian noise based diffusion.
-  - [x] Prediction of noise or ground truth.
-  - [x] Training with recycling or self-conditioning.
+- Supervised segmentation.
+- Diffusion-based segmentation.
+  - [Gaussian noise based diffusion](https://arxiv.org/abs/2211.00611).
+  - Noise prediction ([epsilon-parameterization](https://arxiv.org/abs/2006.11239)) or ground truth
+    prediction ([x0-parameterization](https://arxiv.org/abs/2102.09672)).
+  - [Importance sampling](https://arxiv.org/abs/2102.09672) for timestep.
+  - Recycling training strategies, including [xt-recycling](https://arxiv.org/abs/2303.06040) and
+    [xT-recycling](https://arxiv.org/abs/2308.16355).
+  - Self-conditioning training strategies, including
+    [Chen et al. 2022](https://arxiv.org/abs/2208.04202) and
+    [Watson et al. 2023](https://www.nature.com/articles/s41586-023-06415-8).

 **Models**

-- [x] U-Net with Transformers supporting 2D and 3D images.
+- [U-Net](https://arxiv.org/abs/1505.04597) with [Transformers](https://arxiv.org/abs/1706.03762)
+  supporting 2D and 3D images.
+- [Efficient attention](https://arxiv.org/abs/2112.05682).

 **Training**

-- [x] Patch-based training.
-- [x] Multi-device training (one model per device).
-- [x] Mixed precision training.
-- [x] Gradient clipping and accumulation.
+- Patch-based training.
+- Multi-device training (one model per device) with
+  [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html).
+- Mixed precision training.
+- Gradient clipping and accumulation.
+- [Early stopping](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html).
+
+**Changelog**

----
+- October 2023: Migrated from [Haiku](https://github.com/google-deepmind/dm-haiku) to
+  [Flax](https://github.com/google/flax) following Google DeepMind's recommendation.

 ## Installation

 ### TPU with Docker

-The following instructions have been tested only for TPU-v3-8. The docker
-container uses root user.
+The following instructions have been tested only for TPU-v3-8. The docker container uses root user.

-1. Build the docker image inside the repository.
+1. TPU often has limited disk space.
+   [RAM disk](https://www.linuxbabe.com/command-line/create-ramdisk-linux) can be used to help.
+
+   ```bash
+   sudo mkdir /tmp/ramdisk
+   sudo chmod 777 /tmp/ramdisk
+   sudo mount -t tmpfs -o size=256G imgxramdisk /tmp/ramdisk
+   cd /tmp/ramdisk/
+   ```
+
+2. Build the docker image inside the repository.

    ```bash
    sudo docker build --build-arg USER_ID=$(id -u) --build-arg GROUP_ID=$(id -g) -f docker/Dockerfile.tpu -t imgx .
````
````diff
@@ -74,7 +86,7 @@ container uses root user.
    - `-f` provides the docker file.
    - `-t` tag the docker image.

-2. Run the Docker container.
+3. Run the Docker container.

    ```bash
    mkdir -p $(cd ../ && pwd)/tensorflow_datasets
````
````diff
@@ -84,27 +96,16 @@ container uses root user.
    imgx bash
    ```

-3. Install the package inside container.
+4. Install the package inside container.

    ```bash
    make pip
    ```

-   TPU often has limited disk space.
-   [RAM disk](https://www.linuxbabe.com/command-line/create-ramdisk-linux) can be
-   used to help.
-
-   ```bash
-   sudo mkdir /tmp/ramdisk
-   sudo chmod 777 /tmp/ramdisk
-   sudo mount -t tmpfs -o size=256G imgxramdisk /tmp/ramdisk
-   cd /tmp/ramdisk/
-   ```
-
 ### GPU with Docker

-The following instructions have been tested only for CUDA == 11.4.1 and CUDNN ==
-8.2.0. The docker container uses non-root user.
+The following instructions have been tested only for CUDA == 11.4.1 and CUDNN == 8.2.0. The docker
+container uses non-root user.
 [Docker image used may be removed.](https://gitlab.com/nvidia/container-images/cuda/blob/master/doc/support-policy.md)

 1. Build the docker image inside the repository.
````
````diff
@@ -155,8 +156,8 @@ conda env update -f docker/environment_mac_m1.yml

 #### Install Conda for Linux / Mac Intel

-[Install Conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html)
-and then create the environment.
+[Install Conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html) and
+then create the environment.

 ```bash
 conda install -y -n base conda-libmamba-solver
````
````diff
@@ -175,9 +176,8 @@ make pip

 ## Build Data Sets

-Use the following commands to (re)build all data sets. Check the
-[README](imgx_datasets/README.md) of imgx_datasets for details. Especially,
-manual downloading is required for the BraTS 2021 dataset.
+Use the following commands to (re)build all data sets. Check the [README](imgx_datasets/README.md)
+of imgx_datasets for details. Especially, manual downloading is required for the BraTS 2021 dataset.

 ```bash
 make build_dataset
````
````diff
@@ -188,15 +188,16 @@ make rebuild_dataset

 ### Training and Testing

-Example command to use two GPUs for training, validation and testing. The
-outputs are stored under `wandb/latest-run/files/`, where
+Example command to use two GPUs for training, validation and testing. The outputs are stored under
+`wandb/latest-run/files/`, where

 - `ckpt` stores the model checkpoints and corresponding validation metrics.
 - `test_evaluation` stores the prediction on test set and corresponding metrics.

 ```bash
 # limit to two GPUs if using NVIDIA GPUs
 export CUDA_VISIBLE_DEVICES="0,1"
+
 # select data set to use
 export DATASET_NAME="male_pelvic_mr"
 export DATASET_NAME="amos_ct"
````
````diff
@@ -216,16 +217,14 @@ imgx_valid --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDIM
 imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDIM
 ```

-```bash
-imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --num_seeds 3
-```
-
-Optionally, for debug purposes, use flag `debug=True` to run the experiment with
-a small dataset and smaller models.
+Optionally, for debug purposes, use flag `debug=True` to run the experiment with a small dataset and
+smaller models.

 ```bash
 imgx_train --config-name config_${DATASET_NAME}_seg debug=True
+imgx_test --log_dir wandb/latest-run/
 imgx_train --config-name config_${DATASET_NAME}_diff_seg debug=True
+imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDPM
 ```

 ## Code Quality
````
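The DDPM/DDIM samplers and the epsilon-/x0-parameterizations mentioned in the README are tied together by the standard forward-diffusion identity x_t = sqrt(ᾱ_t)·x_0 + sqrt(1−ᾱ_t)·ε, so predicting the noise or predicting the clean target are interchangeable. A pure-Python sketch of that relationship on a single scalar "voxel" (function names are illustrative, not imgx's API):

```python
import math
import random


def diffuse(x0, eps, alpha_bar):
    """Forward process: x_t = sqrt(abar) * x0 + sqrt(1 - abar) * eps."""
    return math.sqrt(alpha_bar) * x0 + math.sqrt(1.0 - alpha_bar) * eps


def x0_from_eps(xt, eps, alpha_bar):
    """Recover x0 given a noise prediction (epsilon-parameterization)."""
    return (xt - math.sqrt(1.0 - alpha_bar) * eps) / math.sqrt(alpha_bar)


def eps_from_x0(xt, x0, alpha_bar):
    """Recover eps given a clean-target prediction (x0-parameterization)."""
    return (xt - math.sqrt(alpha_bar) * x0) / math.sqrt(1.0 - alpha_bar)


random.seed(0)
x0, eps, alpha_bar = 0.7, random.gauss(0.0, 1.0), 0.5
xt = diffuse(x0, eps, alpha_bar)
# Either parameterization determines the other exactly.
assert abs(x0_from_eps(xt, eps, alpha_bar) - x0) < 1e-12
assert abs(eps_from_x0(xt, x0, alpha_bar) - eps) < 1e-12
```

This is why the codebase can expose `--sampler DDPM`/`--sampler DDIM` independently of how the network was trained: the sampler only needs an x0 (or ε) estimate at each of the `--num_timesteps` steps.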
````diff
@@ -248,11 +247,12 @@ pre-commit run --all-files

 ### Code Test

-Run the command below to test and get coverage report. As JAX tests requires two
-CPUs, `-n 4` uses 4 threads, therefore requires 8 CPUs in total.
+Run the command below to test and get coverage report. As JAX tests requires two CPUs, `-n 4` uses 4
+threads, therefore requires 8 CPUs in total.

 ```bash
-pytest --cov=imgx -n 4 tests
+pytest --cov=imgx -n 4 imgx
+pytest --cov=imgx_datasets -n 4 imgx_datasets
 ```
````
````diff
@@ -266,21 +266,19 @@ pytest --cov=imgx -n 4 tests
 - [Scenic (JAX)](https://github.com/google-research/scenic/)
 - [DeepMind Research (JAX)](https://github.com/deepmind/deepmind-research/tree/master/ogb_lsc/)
 - [Haiku (JAX)](https://github.com/deepmind/dm-haiku/)
+- [Flax (JAX)](https://github.com/google/flax)

 ## Acknowledgement

-This work was supported by the EPSRC grant (EP/T029404/1), the Wellcome/EPSRC
-Centre for Interventional and Surgical Sciences (203145Z/16/Z), the
-International Alliance for Cancer Early Detection, an alliance between Cancer
-Research UK (C28070/A30912, C73666/A31378), Canary Center at Stanford
-University, the University of Cambridge, OHSU Knight Cancer Institute,
-University College London and the University of Manchester, and Cloud TPUs from
-Google's TPU Research Cloud (TRC).
+This work was supported by the EPSRC grant (EP/T029404/1), the Wellcome/EPSRC Centre for
+Interventional and Surgical Sciences (203145Z/16/Z), the International Alliance for Cancer Early
+Detection, an alliance between Cancer Research UK (C28070/A30912, C73666/A31378), Canary Center at
+Stanford University, the University of Cambridge, OHSU Knight Cancer Institute, University College
+London and the University of Manchester, and Cloud TPUs from Google's TPU Research Cloud (TRC).

 ## Citation

-If you find the code base and method useful in your research, please cite the
-relevant paper:
+If you find the code base and method useful in your research, please cite the relevant paper:

 ```bibtex
 @article{fu2023recycling,
````

docker/Dockerfile

Lines changed: 2 additions & 2 deletions
```diff
@@ -75,8 +75,8 @@ COPY docker/requirements.txt /${USER}/requirements.txt

 RUN /${USER}/conda/bin/pip3 install --upgrade pip \
     && /${USER}/conda/bin/pip3 install \
-    jax==0.4.8 \
-    jaxlib==0.4.7+cuda11.cudnn86 \
+    jax==0.4.14 \
+    jaxlib==0.4.14+cuda11.cudnn86 \
     -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
     && /${USER}/conda/bin/pip3 install tensorflow-cpu==2.12.0 \
     && /${USER}/conda/bin/pip3 install -r /${USER}/requirements.txt
```
