@@ -186,17 +186,19 @@ def _tpu_flash_attention(
     kv_max_block_size = key.shape[1]
   else:
     kv_max_block_size = q_max_block_size
-  if flash_block_sizes:
+  # ensure that for cross attention we override the block sizes.
+  if flash_block_sizes and key.shape[1] == query.shape[1]:
     block_sizes = flash_block_sizes
   else:
+    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
     block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=min(q_max_block_size, query.shape[2]),
+        block_q=block_size_q,
         block_kv_compute=min(kv_max_block_size, key.shape[2]),
         block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=min(q_max_block_size, query.shape[2]),
+        block_q_dkv=block_size_q,
         block_kv_dkv=min(kv_max_block_size, key.shape[2]),
         block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=None if attention_kernel == "tokamax_flash" else block_sizes.block_q_dq,
+        block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
         block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
         use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
     )
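For readers skimming the hunk, here is a minimal standalone sketch of the selection logic it introduces: a user-supplied flash_block_sizes is honored only for self attention (query and key sequence lengths match), while cross attention falls through to derived block sizes, reusing the user's block_q when one was given. The BlockSizes dataclass, the select_block_sizes helper, and the simplified sequence-length arguments below are illustrative assumptions, not the real splash_attention_kernel API.

# Illustrative sketch only: `BlockSizes` and `select_block_sizes` are hypothetical
# stand-ins, not the real splash_attention_kernel API, and the sequence-length
# arguments simplify the tensor-shape indexing done in the actual code.
from dataclasses import dataclass
from typing import Optional


@dataclass
class BlockSizes:
  block_q: int
  block_kv: int
  block_kv_compute: int
  block_q_dkv: int
  block_kv_dkv: int
  block_kv_dkv_compute: int
  block_q_dq: Optional[int]
  block_kv_dq: Optional[int]
  use_fused_bwd_kernel: bool = False


def select_block_sizes(q_seq_len, kv_seq_len, q_max_block_size, kv_max_block_size,
                       flash_block_sizes=None, attention_kernel="flash"):
  # Self attention: query and key lengths match, so the user-supplied block
  # sizes can be used unchanged.
  if flash_block_sizes and kv_seq_len == q_seq_len:
    return flash_block_sizes
  # Cross attention (or no user override): derive block sizes, but keep the
  # user's block_q if one was provided.
  block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
  is_tokamax = attention_kernel == "tokamax_flash"
  return BlockSizes(
      block_q=block_size_q,
      block_kv=min(kv_max_block_size, kv_seq_len),
      block_kv_compute=min(kv_max_block_size, kv_seq_len),
      block_q_dkv=block_size_q,
      block_kv_dkv=min(kv_max_block_size, kv_seq_len),
      block_kv_dkv_compute=min(kv_max_block_size, q_seq_len),
      block_q_dq=None if is_tokamax else block_size_q,
      block_kv_dq=None if is_tokamax else min(kv_max_block_size, q_seq_len),
      use_fused_bwd_kernel=is_tokamax,
  )


# Cross attention example: 4096 query tokens attending over 77 key tokens
# ignores the user's block sizes and derives new ones, keeping block_q.
user_sizes = BlockSizes(256, 256, 256, 256, 256, 256, 256, 256)
derived = select_block_sizes(4096, 77, 1024, 1024, flash_block_sizes=user_sizes)
assert derived is not user_sizes and derived.block_q == 256 and derived.block_kv == 77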