Commit 364be35

Fix block size handling for cross attention in TPU flash attention
1 parent: 7d25dc9

File tree: 1 file changed (+6, -8 lines)

src/maxdiffusion/models/attention_flax.py

Lines changed: 6 additions & 8 deletions
@@ -186,17 +186,19 @@ def _tpu_flash_attention(
     kv_max_block_size = key.shape[1]
   else:
     kv_max_block_size = q_max_block_size
-  if flash_block_sizes:
+  # ensure that for cross attention we override the block sizes.
+  if flash_block_sizes and key.shape[1] == query.shape[1]:
     block_sizes = flash_block_sizes
   else:
+    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
     block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=min(q_max_block_size, query.shape[2]),
+        block_q=block_size_q,
         block_kv_compute=min(kv_max_block_size, key.shape[2]),
         block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=min(q_max_block_size, query.shape[2]),
+        block_q_dkv=block_size_q,
         block_kv_dkv=min(kv_max_block_size, key.shape[2]),
         block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=None if attention_kernel == "tokamax_flash" else block_sizes.block_q_dq,
+        block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
         block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
         use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
     )
@@ -215,7 +217,6 @@ def _tpu_flash_attention(
       check_rep=False,
   )
   def wrap_flash_attention(query, key, value):
-
     uses_fused_kernel = block_sizes.use_fused_bwd_kernel
     block_q_sizes = (
         block_sizes.block_q,
@@ -1042,7 +1043,6 @@ def setup(self):
     )
 
   def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
-
     qkv_proj = self.qkv(hidden_states)
     B, L = hidden_states.shape[:2]
     H, D, K = self.heads, qkv_proj.shape[-1] // (self.heads * 3), 3
@@ -1054,7 +1054,6 @@ def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
     key_proj = self.key_norm(key_proj)
 
     if encoder_hidden_states is not None:
-
       encoder_qkv_proj = self.encoder_qkv(encoder_hidden_states)
       B, L = encoder_hidden_states.shape[:2]
       H, D, K = self.heads, encoder_qkv_proj.shape[-1] // (self.heads * 3), 3
@@ -1148,7 +1147,6 @@ class FlaxAttention(nn.Module):
   quant: Quant = None
 
   def setup(self):
-
     if self.attention_kernel == "flash" and self.mesh is None:
       raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
     inner_dim = self.dim_head * self.heads
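
For readers following the first hunk, below is a minimal standalone sketch of the new block-size selection logic. It is not part of the commit: FlashBlockSizes, choose_block_sizes, and the example lengths are hypothetical, a plain dict stands in for splash_attention_kernel.BlockSizes, and the sketch assumes the compared dimension (key.shape[1] versus query.shape[1]) is the sequence length, so equal lengths mean self attention. It shows that a user-supplied override is now taken wholesale only for self attention; for cross attention only its block_q carries over, while the KV-side blocks are clamped to the KV length.

# Illustrative sketch only -- not part of this commit. FlashBlockSizes and
# choose_block_sizes are hypothetical stand-ins, and a dict stands in for
# splash_attention_kernel.BlockSizes.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FlashBlockSizes:
  block_q: int
  block_kv: int


def choose_block_sizes(
    q_seq_len: int,
    kv_seq_len: int,
    q_max_block_size: int,
    kv_max_block_size: int,
    flash_block_sizes: Optional[FlashBlockSizes] = None,
) -> dict:
  """Mirrors the patched branch: the override is used wholesale only for
  self attention; for cross attention only block_q is reused."""
  if flash_block_sizes and q_seq_len == kv_seq_len:
    # Self attention: trust the user-provided block sizes as-is.
    return {"block_q": flash_block_sizes.block_q, "block_kv": flash_block_sizes.block_kv}
  # Cross attention (or no override): reuse only block_q, clamp KV-side blocks.
  block_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
  block_kv = min(kv_max_block_size, kv_seq_len)
  return {"block_q": block_q, "block_kv": block_kv, "block_kv_compute": block_kv}


# An override tuned for a long self-attention sequence is unsafe for a short
# cross-attention KV, so only block_q carries over in the second call.
print(choose_block_sizes(4096, 4096, 1024, 1024, FlashBlockSizes(512, 512)))
print(choose_block_sizes(4096, 77, 1024, 77, FlashBlockSizes(512, 512)))

The practical effect is that block sizes tuned for a long self-attention sequence can no longer overflow a much shorter cross-attention key/value sequence.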
