Commit 4e6a42a

correct need for post-attention dropout
1 parent 6d7298d · commit 4e6a42a

20 files changed: +61, -2 lines changed
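
Every change in this commit follows the same pattern: an attention module that already accepts a dropout argument now instantiates an nn.Dropout(dropout) layer next to its softmax and applies it to the attention weights right after the softmax, before the weighted sum over the values. Below is a minimal sketch of that pattern as an illustrative ViT-style attention block; names, shapes, and the usage lines are assumptions for illustration, not the exact code of any file in this commit.

import torch
from torch import nn, einsum
from einops import rearrange

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)    # added alongside the softmax

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)             # post-attention (post-softmax) dropout

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# illustrative usage (assumed shapes, not repository code)
x = torch.randn(1, 65, 128)                                     # (batch, tokens, dim)
attn_block = Attention(128, heads = 8, dim_head = 64, dropout = 0.1)
out = attn_block(x)                                             # (1, 65, 128)

Dropping entries of the normalized attention matrix in this way is what is commonly called attention dropout in transformer implementations.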

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.29.1',
+  version = '0.30.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',

vit_pytorch/ats_vit.py

Lines changed: 3 additions & 0 deletions
@@ -139,6 +139,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., output_num_token
         self.scale = dim_head ** -0.5

         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
+
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

         self.output_num_tokens = output_num_tokens
@@ -163,6 +165,7 @@ def forward(self, x, *, mask):
         dots = dots.masked_fill(~dots_mask, mask_value)

         attn = self.attend(dots)
+        attn = self.dropout(attn)

         sampled_token_ids = None

vit_pytorch/cait.py

Lines changed: 4 additions & 0 deletions
@@ -76,6 +76,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)

         self.mix_heads_pre_attn = nn.Parameter(torch.randn(heads, heads))
         self.mix_heads_post_attn = nn.Parameter(torch.randn(heads, heads))
@@ -96,7 +97,10 @@ def forward(self, x, context = None):
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

         dots = einsum('b h i j, h g -> b g i j', dots, self.mix_heads_pre_attn) # talking heads, pre-softmax
+
         attn = self.attend(dots)
+        attn = self.dropout(attn)
+
         attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post_attn) # talking heads, post-softmax

         out = einsum('b h i j, b h j d -> b h i d', attn, v)

vit_pytorch/cross_vit.py

Lines changed: 3 additions & 0 deletions
@@ -48,6 +48,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.scale = dim_head ** -0.5

         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
+
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

@@ -69,6 +71,7 @@ def forward(self, x, context = None, kv_include_self = False):
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

         attn = self.attend(dots)
+        attn = self.dropout(attn)

         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')

vit_pytorch/crossformer.py

Lines changed: 4 additions & 0 deletions
@@ -95,6 +95,9 @@ def __init__(
         self.window_size = window_size

         self.norm = LayerNorm(dim)
+
+        self.dropout = nn.Dropout(dropout)
+
         self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
         self.to_out = nn.Conv2d(inner_dim, dim, 1)

@@ -151,6 +154,7 @@ def forward(self, x):
         # attend

         attn = sim.softmax(dim = -1)
+        attn = self.dropout(attn)

         # merge heads

vit_pytorch/cvt.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, d
7676
self.scale = dim_head ** -0.5
7777

7878
self.attend = nn.Softmax(dim = -1)
79+
self.dropout = nn.Dropout(dropout)
7980

8081
self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
8182
self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)
@@ -94,6 +95,7 @@ def forward(self, x):
9495
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
9596

9697
attn = self.attend(dots)
98+
attn = self.dropout(attn)
9799

98100
out = einsum('b i j, b j d -> b i d', attn, v)
99101
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)

vit_pytorch/deepvit.py

Lines changed: 3 additions & 0 deletions
@@ -42,6 +42,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):

         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

+        self.dropout = nn.Dropout(dropout)
+
         self.reattn_weights = nn.Parameter(torch.randn(heads, heads))

         self.reattn_norm = nn.Sequential(
@@ -64,6 +66,7 @@ def forward(self, x):

         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
         attn = dots.softmax(dim=-1)
+        attn = self.dropout(attn)

         # re-attention

vit_pytorch/levit.py

Lines changed: 2 additions & 0 deletions
@@ -52,6 +52,7 @@ def __init__(self, dim, fmap_size, heads = 8, dim_key = 32, dim_value = 64, drop
         self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))

         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)

         out_batch_norm = nn.BatchNorm2d(dim_out)
         nn.init.zeros_(out_batch_norm.weight)
@@ -100,6 +101,7 @@ def forward(self, x):
         dots = self.apply_pos_bias(dots)

         attn = self.attend(dots)
+        attn = self.dropout(attn)

         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)

vit_pytorch/local_vit.py

Lines changed: 2 additions & 0 deletions
@@ -78,6 +78,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.scale = dim_head ** -0.5

         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

         self.to_out = nn.Sequential(
@@ -93,6 +94,7 @@ def forward(self, x):
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

         attn = self.attend(dots)
+        attn = self.dropout(attn)

         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')

vit_pytorch/mobile_vit.py

Lines changed: 5 additions & 0 deletions
@@ -54,6 +54,8 @@ def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
         self.scale = dim_head ** -0.5

         self.attend = nn.Softmax(dim=-1)
+        self.dropout = nn.Dropout(dropout)
+
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

         self.to_out = nn.Sequential(
@@ -67,7 +69,10 @@ def forward(self, x):
             t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)

         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
         attn = self.attend(dots)
+        attn = self.dropout(attn)
+
         out = torch.matmul(attn, v)
         out = rearrange(out, 'b p h n d -> b p n (h d)')
         return self.to_out(out)
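
Since nn.Dropout is only active in training mode, the added layers change nothing at inference time: after model.eval() they act as the identity. A quick standalone check of that behaviour, as an illustrative snippet rather than repository code:

import torch
from torch import nn

attend = nn.Softmax(dim = -1)
dropout = nn.Dropout(0.1)

dots = torch.randn(1, 8, 4, 4)        # (batch, heads, queries, keys) attention logits

dropout.train()
attn_train = dropout(attend(dots))    # some attention weights zeroed, the rest rescaled by 1 / (1 - 0.1)

dropout.eval()
attn_eval = dropout(attend(dots))     # identity: attention weights unchanged
assert torch.equal(attn_eval, attend(dots))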
