@@ -186,17 +186,19 @@ def _tpu_flash_attention(
     kv_max_block_size = key.shape[1]
   else:
     kv_max_block_size = q_max_block_size
-  if flash_block_sizes:
+  # ensure that for cross attention we override the block sizes.
+  if flash_block_sizes and key.shape[1] == query.shape[1]:
     block_sizes = flash_block_sizes
   else:
+    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
     block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=min(q_max_block_size, query.shape[2]),
+        block_q=block_size_q,
         block_kv_compute=min(kv_max_block_size, key.shape[2]),
         block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=min(q_max_block_size, query.shape[2]),
+        block_q_dkv=block_size_q,
         block_kv_dkv=min(kv_max_block_size, key.shape[2]),
         block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=None if attention_kernel == "tokamax_flash" else block_sizes.block_q_dq,
+        block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
         block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
         use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
     )
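As the new comment in the hunk above notes, a user-supplied `flash_block_sizes` is now reused verbatim only for self attention (query and key sequence lengths match); for cross attention the KV blocks are rebuilt so they never exceed the (usually shorter) encoder sequence, while the configured query block size is kept. The following is a minimal, self-contained sketch of that selection logic only; the `choose_block_sizes` helper, the stand-in `BlockSizes` NamedTuple (with just three of the real kernel's fields), and the example lengths are illustrative assumptions, not the actual `splash_attention_kernel` API.

```python
from typing import NamedTuple, Optional


class BlockSizes(NamedTuple):
  # Stand-in for splash_attention_kernel.BlockSizes; real class has more fields.
  block_q: int
  block_kv: int
  block_kv_compute: int


def choose_block_sizes(q_len: int, kv_len: int,
                       q_max_block_size: int, kv_max_block_size: int,
                       configured: Optional[BlockSizes]) -> BlockSizes:
  # Reuse the user-configured sizes only for self attention (q_len == kv_len);
  # for cross attention, rebuild them so KV blocks never exceed the KV length,
  # while keeping the configured query block size if one was given.
  if configured and q_len == kv_len:
    return configured
  block_q = configured.block_q if configured else q_max_block_size
  return BlockSizes(
      block_q=block_q,
      block_kv=min(kv_max_block_size, kv_len),
      block_kv_compute=min(kv_max_block_size, kv_len),
  )


# Hypothetical cross-attention case: 4096 query tokens, 512 encoder KV tokens.
print(choose_block_sizes(4096, 512, 1024, 1024,
                         BlockSizes(block_q=512, block_kv=1024, block_kv_compute=1024)))
# -> BlockSizes(block_q=512, block_kv=512, block_kv_compute=512)
```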
@@ -215,7 +217,6 @@ def _tpu_flash_attention(
       check_rep=False,
   )
   def wrap_flash_attention(query, key, value):
-
     uses_fused_kernel = block_sizes.use_fused_bwd_kernel
     block_q_sizes = (
         block_sizes.block_q,
@@ -1042,7 +1043,6 @@ def setup(self):
     )
 
   def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
-
     qkv_proj = self.qkv(hidden_states)
     B, L = hidden_states.shape[:2]
     H, D, K = self.heads, qkv_proj.shape[-1] // (self.heads * 3), 3
@@ -1054,7 +1054,6 @@ def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=Non
     key_proj = self.key_norm(key_proj)
 
     if encoder_hidden_states is not None:
-
       encoder_qkv_proj = self.encoder_qkv(encoder_hidden_states)
       B, L = encoder_hidden_states.shape[:2]
       H, D, K = self.heads, encoder_qkv_proj.shape[-1] // (self.heads * 3), 3
@@ -1148,7 +1147,6 @@ class FlaxAttention(nn.Module):
  quant: Quant = None
 
  def setup(self):
-
    if self.attention_kernel == "flash" and self.mesh is None:
      raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
    inner_dim = self.dim_head * self.heads