@@ -377,7 +377,9 @@ def finetune_unet(batch, train_encoder=False):
377377
378378 # Set noise scheduler to cosine (this can be done via config, but this ensures it's enabled)
379379 #noise_scheduler.beta_schedule = "squaredcos_cap_v2"
380-
380+
381+ unet .train ()
382+
381383 # Convert videos to latent space
382384 pixel_values = batch ["pixel_values" ].to (weight_dtype )
383385
@@ -398,6 +400,10 @@ def finetune_unet(batch, train_encoder=False):
398400 # Enable text encoder training
399401 if train_encoder :
400402 text_encoder .train ()
403+ cast_to_gpu_and_type ([text_encoder ], accelerator , torch .float32 )
404+ text_encoder .requires_grad_ (True )
405+ else :
406+ text_encoder .requires_grad_ (False )
401407
402408 enable_trainable_unet_modules (unet , trainable_modules , is_enabled = True )
403409
@@ -420,7 +426,6 @@ def finetune_unet(batch, train_encoder=False):
420426
421427 for epoch in range (first_epoch , num_train_epochs ):
422428 train_loss = 0.0
423- unet .train ()
424429
425430 for step , batch in enumerate (train_dataloader ):
426431 # Skip steps until we reach the resumed step
@@ -483,7 +488,9 @@ def finetune_unet(batch, train_encoder=False):
483488 if global_step == 1 : print ("Performing validation prompt." )
484489 if accelerator .is_main_process :
485490 with accelerator .autocast ():
486-
491+ unet .eval ()
492+ text_encoder .eval ()
493+
487494 pipeline = TextToVideoSDPipeline .from_pretrained (
488495 pretrained_model_path ,
489496 text_encoder = text_encoder ,
0 commit comments