@@ -28,7 +28,7 @@ def __init__(
         pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]

         # decoder parameters
-
+        self.decoder_dim = decoder_dim
         self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
         self.mask_token = nn.Parameter(torch.randn(decoder_dim))
         self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
@@ -73,21 +73,23 @@ def forward(self, img):

         # reapply decoder position embedding to unmasked tokens

-        decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
+        unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)

         # repeat mask tokens for number of masked, and add the positions using the masked indices derived above

         mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked)
         mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)

         # concat the masked tokens to the decoder tokens and attend with decoder
-
-        decoder_tokens = torch.cat((mask_tokens, decoder_tokens), dim = 1)
+
+        decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device = device)
+        decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
+        decoder_tokens[batch_range, masked_indices] = mask_tokens
         decoded_tokens = self.decoder(decoder_tokens)

         # splice out the mask tokens and project to pixel values

-        mask_tokens = decoded_tokens[:, :num_masked]
+        mask_tokens = decoded_tokens[batch_range, masked_indices]
         pred_pixel_values = self.to_pixels(mask_tokens)

         # calculate reconstruction loss
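
Note on the change above: instead of concatenating the mask tokens in front of the unmasked tokens before decoding, the decoder input is now a full-length sequence in which every token is written back to its original patch position via per-sample advanced indexing with `batch_range`, and the decoded tokens at the masked positions are gathered with the same indexing before being projected to pixels. Below is a minimal, self-contained sketch of that scatter/gather pattern; the tensor sizes, the random index split, and the skipped decoder call are illustrative assumptions, not values from the repository.

```python
import torch

# Illustrative sizes (assumptions, not values from the repo)
batch, num_patches, decoder_dim = 2, 8, 16
num_masked = 6
device = 'cpu'

# Stand-ins for the projected encoder output and the repeated mask tokens
unmasked_decoder_tokens = torch.randn(batch, num_patches - num_masked, decoder_dim)
mask_tokens = torch.randn(batch, num_masked, decoder_dim)

# Per-sample permutation of patch indices, split into masked / unmasked,
# mirroring how the MAE forward pass derives them
rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]

# batch_range has shape (batch, 1) so it broadcasts against the (batch, n)
# index tensors: each sample scatters into its own row of the output
batch_range = torch.arange(batch, device = device)[:, None]

# Scatter both token sets back to their original patch positions
decoder_tokens = torch.zeros(batch, num_patches, decoder_dim, device = device)
decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
decoder_tokens[batch_range, masked_indices] = mask_tokens

# After the decoder runs, the masked positions are gathered with the same indexing
decoded_tokens = decoder_tokens  # decoder call omitted in this sketch
pred_inputs = decoded_tokens[batch_range, masked_indices]
assert pred_inputs.shape == (batch, num_masked, decoder_dim)
```

Because each token keeps its original patch index, the final gather `decoded_tokens[batch_range, masked_indices]` lines up one-to-one with the masked patches used in the reconstruction loss, which is why the old slice `decoded_tokens[:, :num_masked]` is no longer appropriate.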