Skip to content

Commit 9a95e79

Browse files
authored
Update mae.py (#242)
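Update MAE so that the unmasked tokens and mask tokens are written back to their original patch positions before decoding (instead of concatenating the mask tokens in front), so the decoded tokens can be easily reshaped back to visualize the reconstruction.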
1 parent b4853d3 commit 9a95e79

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

vit_pytorch/mae.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ def __init__(
2828
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
2929

3030
# decoder parameters
31-
31+
self.decoder_dim = decoder_dim
3232
self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
3333
self.mask_token = nn.Parameter(torch.randn(decoder_dim))
3434
self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
@@ -73,21 +73,23 @@ def forward(self, img):
7373

7474
# reapply decoder position embedding to unmasked tokens
7575

76-
decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
76+
unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
7777

7878
# repeat mask tokens for number of masked, and add the positions using the masked indices derived above
7979

8080
mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked)
8181
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
8282

8383
# concat the masked tokens to the decoder tokens and attend with decoder
84-
85-
decoder_tokens = torch.cat((mask_tokens, decoder_tokens), dim = 1)
84+
85+
decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device)
86+
decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
87+
decoder_tokens[batch_range, masked_indices] = mask_tokens
8688
decoded_tokens = self.decoder(decoder_tokens)
8789

8890
# splice out the mask tokens and project to pixel values
8991

90-
mask_tokens = decoded_tokens[:, :num_masked]
92+
mask_tokens = decoded_tokens[batch_range, masked_indices]
9193
pred_pixel_values = self.to_pixels(mask_tokens)
9294

9395
# calculate reconstruction loss

0 commit comments

Comments (0)