This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit ab143eb

Add flag for Torch 2 attention

1 parent 0707da9

File tree

4 files changed: +14 -3 lines changed

  configs/my_config.yaml
  configs/offset_noise_finetune.yaml
  configs/single_video_config.yaml
  train.py

configs/my_config.yaml (3 additions, 0 deletions)

@@ -38,3 +38,6 @@ seed: 64
 mixed_precision: "fp16"
 use_8bit_adam: False # This seems to be incompatible at the moment.
 enable_xformers_memory_efficient_attention: False
+
+# Use scaled dot product attention (Only available with >= Torch 2.0)
+enable_torch_2_attn: True
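
The comment above notes the option is only available with Torch >= 2.0; train.py (diff below) gates it on a hasattr probe of torch.nn.functional. A quick standalone check of whether your environment can honor the flag (illustrative, not code from this repository):

import torch
import torch.nn.functional as F

# Same capability probe that train.py performs before switching attention paths.
print(torch.__version__)
print(hasattr(F, "scaled_dot_product_attention"))  # True on torch >= 2.0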

configs/offset_noise_finetune.yaml (3 additions, 0 deletions)

@@ -45,3 +45,6 @@ use_8bit_adam: False # This seems to be incompatible at the moment.
 
 # Xformers must be installed
 enable_xformers_memory_efficient_attention: True
+
+# Use scaled dot product attention (Only available with >= Torch 2.0)
+enable_torch_2_attn: True

configs/single_video_config.yaml (3 additions, 0 deletions)

@@ -38,3 +38,6 @@ seed: 64
 mixed_precision: "fp16"
 use_8bit_adam: False # This seems to be incompatible at the moment.
 enable_xformers_memory_efficient_attention: False
+
+# Use scaled dot product attention (Only available with >= Torch 2.0)
+enable_torch_2_attn: True
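
All three configs now expose the same switch. Under the hood it selects PyTorch 2.0's fused scaled dot product attention; a minimal standalone demonstration of the op itself (illustrative only, not the repository's wrapper code):

import torch
import torch.nn.functional as F

# Toy tensors shaped (batch, heads, sequence_length, head_dim).
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Available on torch >= 2.0; dispatches to fused kernels (e.g. FlashAttention) when supported.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 16, 64])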

train.py (5 additions, 3 deletions)

@@ -101,7 +101,7 @@ def set_torch_2_attn(unet):
     if optim_count > 0:
         print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")
 
-def handle_memory_attention(enable_xformers_memory_efficient_attention, unet):
+def handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet):
     try:
         is_torch_2 = hasattr(F, 'scaled_dot_product_attention')
 
@@ -111,7 +111,8 @@ def handle_memory_attention(enable_xformers_memory_efficient_attention, unet):
                 unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
             else:
                 raise ValueError("xformers is not available. Make sure it is installed correctly")
-        else:
+
+        if enable_torch_2_attn and is_torch_2:
             set_torch_2_attn(unet)
     except:
         print("Could not enable memory efficient attention for xformers or Torch 2.0.")
@@ -230,6 +231,7 @@ def main(
     mixed_precision: Optional[str] = "fp16",
     use_8bit_adam: bool = False,
     enable_xformers_memory_efficient_attention: bool = True,
+    enable_torch_2_attn: bool = False,
     seed: Optional[int] = None,
     train_text_encoder: bool = False,
     use_offset_noise: bool = False,
@@ -268,7 +270,7 @@ def main(
     freeze_models([vae, text_encoder, unet])
 
     # Enable xformers if available
-    handle_memory_attention(enable_xformers_memory_efficient_attention, unet)
+    handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet)
 
     if scale_lr:
         learning_rate = (
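
Pieced together from the hunks above, the updated handle_memory_attention plausibly reads as follows. The outer xformers condition, the is_xformers_available check, and the imports sit outside the diff context, so they are assumptions rather than the file's exact contents:

import torch.nn.functional as F
from diffusers.utils.import_utils import is_xformers_available  # assumed import

def handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet):
    try:
        is_torch_2 = hasattr(F, 'scaled_dot_product_attention')

        if enable_xformers_memory_efficient_attention:  # assumed condition (outside diff context)
            if is_xformers_available():  # assumed check (outside diff context)
                from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
                unet.enable_xformers_memory_efficient_attention(
                    attention_op=MemoryEfficientAttentionFlashAttentionOp
                )
            else:
                raise ValueError("xformers is not available. Make sure it is installed correctly")

        if enable_torch_2_attn and is_torch_2:
            # set_torch_2_attn is defined earlier in train.py (first hunk above);
            # its body is not part of this diff.
            set_torch_2_attn(unet)
    except:
        print("Could not enable memory efficient attention for xformers or Torch 2.0.")

The net effect of the change: the old implicit else-branch fallback to set_torch_2_attn is replaced by an explicit opt-in through the new enable_torch_2_attn flag, which main() now forwards from the configs.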

0 commit comments
