Commit ea75256

Fix block size handling for cross attention in TPU flash attention
1 parent 7d25dc9 commit ea75256

File tree

1 file changed (+6, -4 lines)

src/maxdiffusion/models/attention_flax.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -186,17 +186,19 @@ def _tpu_flash_attention(
       kv_max_block_size = key.shape[1]
     else:
       kv_max_block_size = q_max_block_size
-    if flash_block_sizes:
+    # ensure that for cross attention we override the block sizes.
+    if flash_block_sizes and key.shape[1] == query.shape[1]:
       block_sizes = flash_block_sizes
     else:
+      block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
       block_sizes = splash_attention_kernel.BlockSizes(
-          block_q=min(q_max_block_size, query.shape[2]),
+          block_q=block_size_q,
           block_kv_compute=min(kv_max_block_size, key.shape[2]),
           block_kv=min(kv_max_block_size, key.shape[2]),
-          block_q_dkv=min(q_max_block_size, query.shape[2]),
+          block_q_dkv=block_size_q,
           block_kv_dkv=min(kv_max_block_size, key.shape[2]),
           block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-          block_q_dq=None if attention_kernel == "tokamax_flash" else block_sizes.block_q_dq,
+          block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
           block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
           use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
       )
```
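For readers skimming the diff: previously, any user-supplied `flash_block_sizes` was taken verbatim, even for cross attention, where the key sequence length differs from the query length; in addition, the old `block_q_dq` argument referred to `block_sizes.block_q_dq` inside the very expression that binds `block_sizes`. After the change, user sizes are only taken verbatim for self attention (`key.shape[1] == query.shape[1]`); for cross attention the kv-side blocks are rebuilt from the actual key shape, while a user-supplied `block_q` is still honored via the new `block_size_q` variable.

Below is a minimal, self-contained sketch of that selection logic pulled out into a standalone helper. The helper name `select_block_sizes`, its signature, and the demo shapes are illustrative assumptions, not code from the repository; only `splash_attention_kernel.BlockSizes` is the real JAX API used by the commit.

```python
# A hedged sketch of the commit's block-size selection, isolated for clarity.
# Axis indices (shape[1], shape[2]) mirror the commit exactly.
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel


def select_block_sizes(query_shape, key_shape, q_max_block_size,
                       kv_max_block_size, flash_block_sizes, attention_kernel):
  """Mirrors the fixed logic: user sizes pass through only for self attention;
  cross attention rebuilds kv-side blocks from the actual key shape."""
  if flash_block_sizes and key_shape[1] == query_shape[1]:
    # Self attention: the caller's tuned sizes are safe to use as-is.
    return flash_block_sizes
  # Cross attention (or no user sizes): honor a user block_q if present,
  # otherwise fall back to the hardware maximum.
  block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
  return splash_attention_kernel.BlockSizes(
      block_q=block_size_q,
      block_kv_compute=min(kv_max_block_size, key_shape[2]),
      block_kv=min(kv_max_block_size, key_shape[2]),
      block_q_dkv=block_size_q,
      block_kv_dkv=min(kv_max_block_size, key_shape[2]),
      block_kv_dkv_compute=min(kv_max_block_size, query_shape[2]),
      # tokamax_flash uses the fused backward kernel, which requires the
      # dq block sizes to be left unset.
      block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
      block_kv_dq=None if attention_kernel == "tokamax_flash"
      else min(kv_max_block_size, query_shape[2]),
      use_fused_bwd_kernel=attention_kernel == "tokamax_flash",
  )


if __name__ == "__main__":
  # Cross attention: key length at index 1 differs from the query's, so the
  # sizes are rebuilt rather than taken from the user. Shapes are illustrative.
  q_shape = (1, 4096, 4096, 128)
  k_shape = (1, 77, 77, 128)
  sizes = select_block_sizes(q_shape, k_shape, q_max_block_size=1024,
                             kv_max_block_size=77, flash_block_sizes=None,
                             attention_kernel="flash")
  print(sizes)
```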
