This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit 7640ce3 (parent: da46f92)

Make gradient checkpointing work properly

1 file changed: train.py (6 additions, 1 deletion)
@@ -722,7 +722,12 @@ def finetune_unet(batch, train_encoder=False):
             )
         cast_to_gpu_and_type([text_encoder], accelerator, torch.float32)
 
-
+        # Fixes gradient checkpointing training.
+        # See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb
+        if gradient_checkpointing or text_encoder_gradient_checkpointing:
+            unet.eval()
+            text_encoder.eval()
+
         # Encode text embeddings
         token_ids = batch['prompt_ids']
         encoder_hidden_states = text_encoder(token_ids)[0]
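For context on why the commit switches the modules to eval() when checkpointing is enabled: torch.utils.checkpoint discards intermediate activations during the forward pass and recomputes the checkpointed segment during backward, so train-mode layers with side effects or randomness (BatchNorm running statistics, Dropout masks) execute more than once; the linked memonger tutorial discusses these pitfalls. Importantly, eval() only changes layer behavior, it does not freeze parameters, so gradients still flow. The sketch below is a minimal, self-contained illustration of that point and is not taken from this repository's train.py; Block, model, and x are made-up names, and the use_reentrant argument assumes a reasonably recent PyTorch.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Toy module with a Dropout layer, standing in for the UNet or text encoder."""

    def __init__(self, dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Dropout(p=0.1),
            nn.Linear(dim, dim),
        )

    def forward(self, x):
        return self.net(x)


model = Block()
model.eval()  # deterministic forward: Dropout becomes a no-op, mirroring the commit's eval() calls

x = torch.randn(8, 64, requires_grad=True)       # grad must flow into the checkpointed segment
y = checkpoint(model, x, use_reentrant=False)    # activations are dropped and recomputed in backward
y.sum().backward()

print(model.net[0].weight.grad is not None)      # True: eval() does not disable gradient computation

A side effect of this approach is that dropout regularization is turned off during fine-tuning, which is presumably why the commit gates the eval() calls on the gradient-checkpointing flags rather than applying them unconditionally.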

0 commit comments
