@@ -99,8 +99,8 @@ def forward(
9999 self ,
100100 x ,
101101 context = None ,
102- return_attn = False ,
103- attn = None
102+ return_qk_sim = False ,
103+ qk_sim = None
104104 ):
105105 x = self .norm (x )
106106
@@ -119,20 +119,21 @@ def forward(
119119 q , k = tuple (self .split_heads (t ) for t in qk )
120120
121121 q = q * self .scale
122- sim = einsum (q , k , 'b h i d, b h j d -> b h i j' )
122+ qk_sim = einsum (q , k , 'b h i d, b h j d -> b h i j' )
123123
124- attn = self .attend (sim )
125- attn = self .dropout (attn )
126124 else :
127- assert exists (attn ), 'attention matrix must be passed in for reusing previous attention'
125+ assert exists (qk_sim ), 'qk sim matrix must be passed in for reusing previous attention'
126+
127+ attn = self .attend (qk_sim )
128+ attn = self .dropout (attn )
128129
129130 out = einsum (attn , v , 'b h i j, b h j d -> b h i d' )
130131 out = self .to_out (out )
131132
132- if not return_attn :
133+ if not return_qk_sim :
133134 return out
134135
135- return out , attn
136+ return out , qk_sim
136137
137138# LookViT
138139
@@ -228,17 +229,17 @@ def forward(self, img):
228229
229230 # main tokens cross attends (lookup) on the high res tokens
230231
231- lookup_out , lookup_attn = lookup_cross_attn (tokens , highres_tokens , return_attn = True ) # return attention as they reuse the attention matrix
231+ lookup_out , qk_sim = lookup_cross_attn (tokens , highres_tokens , return_qk_sim = True ) # return the qk similarity (taken before attend / dropout) so the reverse cross attention can reuse it
232233 tokens = lookup_out + tokens
233234
234235 tokens = attn (tokens ) + tokens
235236 tokens = mlp (tokens ) + tokens
236237
237238 # attention-reuse
238239
239- lookup_attn = rearrange (lookup_attn , 'b h i j -> b h j i' ) # transpose for reverse cross attention
240+ qk_sim = rearrange (qk_sim , 'b h i j -> b h j i' ) # transpose for reverse cross attention
240241
241- highres_tokens = highres_attn (highres_tokens , tokens , attn = lookup_attn ) + highres_tokens
242+ highres_tokens = highres_attn (highres_tokens , tokens , qk_sim = qk_sim ) + highres_tokens
242243 highres_tokens = highres_norm (highres_tokens )
243244
244245 highres_tokens = highres_mlp (highres_tokens ) + highres_tokens
0 commit comments