Skip to content
This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit 2a48e2d

Browse files
Update train.py
Sometimes the smallest changes make the biggest difference. Also, add text encoder training params.
1 parent eeb5cc4 commit 2a48e2d

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

train.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,9 @@ def finetune_unet(batch, train_encoder=False):
377377

378378
# Set noise scheduler to cosine (this can be done via config, but this ensures it's enabled)
379379
#noise_scheduler.beta_schedule = "squaredcos_cap_v2"
380-
380+
381+
unet.train()
382+
381383
# Convert videos to latent space
382384
pixel_values = batch["pixel_values"].to(weight_dtype)
383385

@@ -398,6 +400,10 @@ def finetune_unet(batch, train_encoder=False):
398400
# Enable text encoder training
399401
if train_encoder:
400402
text_encoder.train()
403+
cast_to_gpu_and_type([text_encoder], accelerator, torch.float32)
404+
text_encoder.requires_grad_(True)
405+
else:
406+
text_encoder.requires_grad_(False)
401407

402408
enable_trainable_unet_modules(unet, trainable_modules, is_enabled=True)
403409

@@ -420,7 +426,6 @@ def finetune_unet(batch, train_encoder=False):
420426

421427
for epoch in range(first_epoch, num_train_epochs):
422428
train_loss = 0.0
423-
unet.train()
424429

425430
for step, batch in enumerate(train_dataloader):
426431
# Skip steps until we reach the resumed step
@@ -483,7 +488,9 @@ def finetune_unet(batch, train_encoder=False):
483488
if global_step == 1: print("Performing validation prompt.")
484489
if accelerator.is_main_process:
485490
with accelerator.autocast():
486-
491+
unet.eval()
492+
text_encoder.eval()
493+
487494
pipeline = TextToVideoSDPipeline.from_pretrained(
488495
pretrained_model_path,
489496
text_encoder=text_encoder,

0 commit comments

Comments
 (0)