Commit e3256d7

fix t2t vit having two layernorms, and make final layernorm in distillation wrapper configurable, default to False for vit
1 parent 90be723 commit e3256d7

3 files changed, +17 −14 lines changed

setup.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.6.9',
+  version = '1.7.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,

vit_pytorch/distill.py

Lines changed: 15 additions & 9 deletions
@@ -1,6 +1,8 @@
 import torch
-import torch.nn.functional as F
 from torch import nn
+from torch.nn import Module
+import torch.nn.functional as F
+
 from vit_pytorch.vit import ViT
 from vit_pytorch.t2t import T2TViT
 from vit_pytorch.efficient import ViT as EfficientViT
@@ -12,6 +14,9 @@
 def exists(val):
     return val is not None
 
+def default(val, d):
+    return val if exists(val) else d
+
 # classes
 
 class DistillMixin:
@@ -20,12 +25,12 @@ def forward(self, img, distill_token = None):
         x = self.to_patch_embedding(img)
         b, n, _ = x.shape
 
-        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
+        cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
         x = torch.cat((cls_tokens, x), dim = 1)
         x += self.pos_embedding[:, :(n + 1)]
 
         if distilling:
-            distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
+            distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
             x = torch.cat((x, distill_tokens), dim = 1)
 
         x = self._attend(x)
@@ -97,15 +102,16 @@ def _attend(self, x):
 
 # knowledge distillation wrapper
 
-class DistillWrapper(nn.Module):
+class DistillWrapper(Module):
     def __init__(
         self,
         *,
         teacher,
         student,
         temperature = 1.,
         alpha = 0.5,
-        hard = False
+        hard = False,
+        mlp_layernorm = False
     ):
         super().__init__()
         assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
@@ -122,14 +128,14 @@ def __init__(
         self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
 
         self.distill_mlp = nn.Sequential(
-            nn.LayerNorm(dim),
+            nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
            nn.Linear(dim, num_classes)
         )
 
     def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
-        b, *_ = img.shape
-        alpha = alpha if exists(alpha) else self.alpha
-        T = temperature if exists(temperature) else self.temperature
+
+        alpha = default(alpha, self.alpha)
+        T = default(temperature, self.temperature)
 
         with torch.no_grad():
             teacher_logits = self.teacher(img)
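
Below is a minimal usage sketch of the new mlp_layernorm flag on DistillWrapper. It mirrors the README-style distillation example; the teacher model and the DistillableViT hyperparameters are illustrative assumptions, not part of this commit.

# Sketch only: teacher choice and hyperparameters are assumptions mirroring
# the repository's README example, not part of this diff.
import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50()  # untrained here, just to keep the sketch self-contained

student = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

distiller = DistillWrapper(
    student = student,
    teacher = teacher,
    temperature = 3,
    alpha = 0.5,
    hard = False,
    mlp_layernorm = False  # new flag in this commit; True restores a LayerNorm before the distillation head
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)  # combined label / distillation loss
loss.backward()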

vit_pytorch/t2t.py

Lines changed: 1 addition & 4 deletions
@@ -61,10 +61,7 @@ def __init__(self, *, image_size, num_classes, dim, depth = None, heads = None,
         self.pool = pool
         self.to_latent = nn.Identity()
 
-        self.mlp_head = nn.Sequential(
-            nn.LayerNorm(dim),
-            nn.Linear(dim, num_classes)
-        )
+        self.mlp_head = nn.Linear(dim, num_classes)
 
     def forward(self, img):
         x = self.to_patch_embedding(img)
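
For context, the commit message indicates T2T-ViT was applying LayerNorm twice, so the classification head is reduced to a single nn.Linear. A short construction sketch follows; the hyperparameters (including t2t_layers) mirror the README example and are assumptions, not part of this diff.

# Sketch only: hyperparameters are assumptions mirroring the README example.
import torch
from vit_pytorch.t2t import T2TViT

v = T2TViT(
    dim = 512,
    image_size = 224,
    depth = 5,
    heads = 8,
    mlp_dim = 512,
    num_classes = 1000,
    t2t_layers = ((7, 4), (3, 2), (3, 2))  # (kernel size, stride) of each tokens-to-token stage
)

img = torch.randn(1, 3, 224, 224)
preds = v(img)  # (1, 1000); mlp_head is now a plain nn.Linear after this commit removes the duplicate LayerNorm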
