@@ -1,6 +1,8 @@
import torch
- import torch.nn.functional as F
from torch import nn
+ from torch.nn import Module
+ import torch.nn.functional as F
+
from vit_pytorch.vit import ViT
from vit_pytorch.t2t import T2TViT
from vit_pytorch.efficient import ViT as EfficientViT
@@ -12,6 +14,9 @@
def exists(val):
    return val is not None

+ def default(val, d):
+     return val if exists(val) else d
+
# classes

class DistillMixin:
@@ -20,12 +25,12 @@ def forward(self, img, distill_token = None):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

-         cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
+         cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :(n + 1)]

        if distilling:
-             distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
+             distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
            x = torch.cat((x, distill_tokens), dim = 1)

        x = self._attend(x)
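Note (annotation, not part of the commit): repeat here comes from einops, imported earlier in the file outside this diff. The pattern change swaps the older '()' spelling of an anonymous unit axis for '1'; both mean "an axis of length one to broadcast to the batch size b". A minimal stand-alone check of the new pattern:

import torch
from einops import repeat

token = torch.randn(1, 1, 64)                  # cls-token-shaped tensor: (1, n, d)
out = repeat(token, '1 n d -> b n d', b = 8)   # broadcast the unit axis to 8 copies
print(out.shape)                               # torch.Size([8, 1, 64])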
@@ -97,15 +102,16 @@ def _attend(self, x):

# knowledge distillation wrapper

- class DistillWrapper(nn.Module):
+ class DistillWrapper(Module):
    def __init__(
        self,
        *,
        teacher,
        student,
        temperature = 1.,
        alpha = 0.5,
-         hard = False
+         hard = False,
+         mlp_layernorm = False
    ):
        super().__init__()
        assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))), 'student must be a vision transformer'
@@ -122,14 +128,14 @@ def __init__(
        self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))

        self.distill_mlp = nn.Sequential(
-             nn.LayerNorm(dim),
+             nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
-         b, *_ = img.shape
-         alpha = alpha if exists(alpha) else self.alpha
-         T = temperature if exists(temperature) else self.temperature
+
+         alpha = default(alpha, self.alpha)
+         T = default(temperature, self.temperature)

        with torch.no_grad():
            teacher_logits = self.teacher(img)
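Taken together, the commit makes the LayerNorm in the distillation head opt-in via the new mlp_layernorm flag (default False, so nn.Identity() is used unless requested) and routes the per-call temperature / alpha overrides through the new default helper. A minimal usage sketch, assuming the constructor arguments shown in the vit_pytorch README; a plain ViT stands in for a pretrained teacher and the batch of images and labels is random data for illustration:

import torch
from vit_pytorch.vit import ViT
from vit_pytorch.distill import DistillableViT, DistillWrapper

# stand-in teacher; any module that maps images to class logits works
teacher = ViT(image_size = 256, patch_size = 32, num_classes = 1000,
              dim = 1024, depth = 6, heads = 8, mlp_dim = 2048)

student = DistillableViT(image_size = 256, patch_size = 32, num_classes = 1000,
                         dim = 1024, depth = 6, heads = 8, mlp_dim = 2048)

distiller = DistillWrapper(
    teacher = teacher,
    student = student,
    temperature = 3,
    alpha = 0.5,
    hard = False,
    mlp_layernorm = False      # new flag: skip LayerNorm in the distill head
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)                      # uses self.temperature / self.alpha
loss = distiller(img, labels, temperature = 5.)    # per-call override via default()
loss.backward()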