
Commit d27721a

committed: add scalable vit, from bytedance AI
1 parent cb22cbb commit d27721a

File tree

5 files changed: +347 −1 lines changed

- README.md
- images/scalable-vit-1.png
- images/scalable-vit-2.png
- setup.py
- vit_pytorch/scalable_vit.py

README.md

Lines changed: 44 additions & 0 deletions
@@ -18,6 +18,7 @@
 - [Twins SVT](#twins-svt)
 - [CrossFormer](#crossformer)
 - [RegionViT](#regionvit)
+- [ScalableViT](#scalablevit)
 - [NesT](#nest)
 - [MobileViT](#mobilevit)
 - [Masked Autoencoder](#masked-autoencoder)
@@ -525,6 +526,38 @@ img = torch.randn(1, 3, 224, 224)
 pred = model(img) # (1, 1000)
 ```
 
+## ScalableViT
+
+<img src="./images/scalable-vit-1.png" width="400px"></img>
+
+<img src="./images/scalable-vit-2.png" width="400px"></img>
+
+This Bytedance AI <a href="https://arxiv.org/abs/2203.10790">paper</a> proposes the Scalable Self Attention (SSA) and the Interactive Windowed Self Attention (IWSA) modules. SSA reduces the computation needed at the earlier stages by downsampling the key / value feature map by some factor (`reduction_factor`), while modulating the dimension of the queries and keys (`ssa_dim_key`). IWSA performs self attention within local windows, similar to other vision transformer papers. However, it also adds a residual of the values, passed through a convolution of kernel size 3, which the authors name the Local Interactive Module (LIM).
+
+The paper claims that this scheme outperforms Swin Transformer, and also demonstrates competitive performance against CrossFormer.
+
+You can use it as follows (e.g. ScalableViT-S)
+
+```python
+import torch
+from vit_pytorch.scalable_vit import ScalableViT
+
+model = ScalableViT(
+    num_classes = 1000,
+    dim = 64,                            # starting model dimension. at every stage, dimension is doubled
+    heads = (2, 4, 8, 16),               # number of attention heads at each stage
+    depth = (2, 2, 20, 2),               # number of transformer blocks at each stage
+    ssa_dim_key = (40, 40, 40, 32),      # the dimension of the attention keys (and queries) for SSA. in the paper, this was expressed as a scale factor on the base dimension per key (ssa_dim_key / dim_key)
+    reduction_factor = (8, 4, 2, 1),     # downsampling of the keys / values in SSA. in the paper, this was expressed as (reduction_factor ** -2)
+    window_size = (64, 32, None, None),  # window size of the IWSA at each stage. None means no windowing is needed
+    dropout = 0.1,                       # attention and feedforward dropout
+).cuda()
+
+img = torch.randn(1, 3, 256, 256).cuda()
+
+preds = model(img) # (1, 1000)
+```
+
 ## NesT
 
 <img src="./images/nest.png" width="400px"></img>
@@ -1352,6 +1385,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
 }
 ```
 
+```bibtex
+@misc{yang2022scalablevit,
+    title   = {ScalableViT: Rethinking the Context-oriented Generalization of Vision Transformer},
+    author  = {Rui Yang and Hailong Ma and Jie Wu and Yansong Tang and Xuefeng Xiao and Min Zheng and Xiu Li},
+    year    = {2022},
+    eprint  = {2203.10790},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
+
 ```bibtex
 @misc{vaswani2017attention,
     title   = {Attention Is All You Need},
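To make the SSA `reduction_factor` and the LIM residual described in the README section above concrete, here is a minimal standalone sketch. It is an illustration, not part of this commit: the numbers are the stage-1 settings of the ScalableViT-S example, and the projections mirror the strided `to_k` and the 3 × 3 LIM convolution defined in `vit_pytorch/scalable_vit.py` below.

```python
import torch
from torch import nn

# illustrative stage-1 numbers from the ScalableViT-S example above
dim, heads, dim_key, reduction_factor = 64, 2, 40, 8
dim_inner = dim_key * heads

# SSA: queries keep the full resolution, while keys / values are shrunk by a
# strided convolution, so attention is (h * w) x (h * w / reduction_factor ** 2)
to_q = nn.Conv2d(dim, dim_inner, 1, bias = False)
to_k = nn.Conv2d(dim, dim_inner, reduction_factor, stride = reduction_factor, bias = False)

feat = torch.randn(1, dim, 64, 64)   # stage-1 feature map of a 256 x 256 image (7x7 conv, stride 4)

print(to_q(feat).shape)              # torch.Size([1, 80, 64, 64]) -> 4096 queries
print(to_k(feat).shape)              # torch.Size([1, 80, 8, 8])   -> only 64 keys / values

# LIM: a 3x3 convolution over the values, added back as a residual
lim = nn.Conv2d(dim_inner, dim_inner, 3, padding = 1)
values = torch.randn(1, dim_inner, 64, 64)
print((lim(values) + values).shape)  # torch.Size([1, 80, 64, 64]) -> same shape as the values
```

Reducing 4096 keys / values down to 64 shrinks the attention matrix by `reduction_factor ** 2`, which is the saving the README paragraph describes; the query resolution, and therefore the output resolution, is untouched.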

images/scalable-vit-1.png

78.9 KB

images/scalable-vit-2.png

62 KB

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.27.1',
+  version = '0.28.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',

vit_pytorch/scalable_vit.py

Lines changed: 302 additions & 0 deletions
from functools import partial
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# helper classes

class ChanLayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = ChanLayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x))

class Downsample(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)

    def forward(self, x):
        return self.conv(x)

class PEG(nn.Module):
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)

    def forward(self, x):
        return self.proj(x) + x

# feedforward

class FeedForward(nn.Module):
    def __init__(self, dim, expansion_factor = 4, dropout = 0.):
        super().__init__()
        inner_dim = dim * expansion_factor
        self.net = nn.Sequential(
            nn.Conv2d(dim, inner_dim, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# attention

class ScalableSelfAttention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_key = 64,
        dim_value = 64,
        dropout = 0.,
        reduction_factor = 1
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_key ** -0.5
        self.attend = nn.Softmax(dim = -1)

        self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
        self.to_k = nn.Conv2d(dim, dim_key * heads, reduction_factor, stride = reduction_factor, bias = False)
        self.to_v = nn.Conv2d(dim, dim_value * heads, reduction_factor, stride = reduction_factor, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(dim_value * heads, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        height, width, heads = *x.shape[-2:], self.heads

        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

        # split out heads

        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))

        # similarity

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # attention

        attn = self.attend(dots)

        # aggregate values

        out = torch.matmul(attn, v)

        # merge back heads

        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = height, y = width)
        return self.to_out(out)

class InteractiveWindowedSelfAttention(nn.Module):
    def __init__(
        self,
        dim,
        window_size,
        heads = 8,
        dim_key = 64,
        dim_value = 64,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_key ** -0.5
        self.window_size = window_size
        self.attend = nn.Softmax(dim = -1)

        self.local_interactive_module = nn.Conv2d(dim_value * heads, dim_value * heads, 3, padding = 1)

        self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
        self.to_k = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
        self.to_v = nn.Conv2d(dim, dim_value * heads, 1, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(dim_value * heads, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        height, width, heads, wsz = *x.shape[-2:], self.heads, self.window_size

        wsz = default(wsz, height) # take height as window size if not given
        assert (height % wsz) == 0 and (width % wsz) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz})'

        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

        # get output of LIM

        local_out = self.local_interactive_module(v)

        # divide into windows (and split out heads) for efficient self attention

        q, k, v = map(lambda t: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h = heads, w1 = wsz, w2 = wsz), (q, k, v))

        # similarity

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # attention

        attn = self.attend(dots)

        # aggregate values

        out = torch.matmul(attn, v)

        # reshape the windows back to full feature map (and merge heads)

        out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)

        # add LIM output

        out = out + local_out

        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        heads = 8,
        ff_expansion_factor = 4,
        dropout = 0.,
        ssa_dim_key = 64,
        ssa_dim_value = 64,
        ssa_reduction_factor = 1,
        iwsa_dim_key = 64,
        iwsa_dim_value = 64,
        iwsa_window_size = 64,
        norm_output = True
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for ind in range(depth):
            is_first = ind == 0

            # block order: SSA -> FF -> (PEG on first block only) -> IWSA -> FF,
            # matching the unpacking in forward below

            self.layers.append(nn.ModuleList([
                PreNorm(dim, ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
                PEG(dim) if is_first else None,
                PreNorm(dim, InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout))
            ]))

        self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        for ssa, ff1, peg, iwsa, ff2 in self.layers:
            x = ssa(x) + x
            x = ff1(x) + x

            if exists(peg):
                x = peg(x)

            x = iwsa(x) + x
            x = ff2(x) + x

        return self.norm(x)

class ScalableViT(nn.Module):
    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        heads,
        reduction_factor,
        ff_expansion_factor = 4,
        iwsa_dim_key = 64,
        iwsa_dim_value = 64,
        window_size = 64,
        ssa_dim_key = 64,
        ssa_dim_value = 64,
        channels = 3,
        dropout = 0.
    ):
        super().__init__()
        self.to_patches = nn.Conv2d(channels, dim, 7, stride = 4, padding = 3)

        assert isinstance(depth, tuple), 'depth needs to be a tuple of integers indicating the number of transformer blocks at each stage'

        num_stages = len(depth)
        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))

        hyperparams_per_stage = [
            heads,
            ssa_dim_key,
            ssa_dim_value,
            reduction_factor,
            iwsa_dim_key,
            iwsa_dim_value,
            window_size,
        ]

        hyperparams_per_stage = list(map(partial(cast_tuple, length = num_stages), hyperparams_per_stage))
        assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage)))

        self.layers = nn.ModuleList([])

        for ind, (layer_dim, layer_depth, layer_heads, layer_ssa_dim_key, layer_ssa_dim_value, layer_ssa_reduction_factor, layer_iwsa_dim_key, layer_iwsa_dim_value, layer_window_size) in enumerate(zip(dims, depth, *hyperparams_per_stage)):
            is_last = ind == (num_stages - 1)

            self.layers.append(nn.ModuleList([
                Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, ff_expansion_factor = ff_expansion_factor, dropout = dropout, ssa_dim_key = layer_ssa_dim_key, ssa_dim_value = layer_ssa_dim_value, ssa_reduction_factor = layer_ssa_reduction_factor, iwsa_dim_key = layer_iwsa_dim_key, iwsa_dim_value = layer_iwsa_dim_value, iwsa_window_size = layer_window_size),
                Downsample(layer_dim, layer_dim * 2) if not is_last else None
            ]))

        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )

    def forward(self, img):
        x = self.to_patches(img)

        for transformer, downsample in self.layers:
            x = transformer(x)

            if exists(downsample):
                x = downsample(x)

        return self.mlp_head(x)
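For readers who want to see how the stages compose, here is a short sketch (again illustrative, not part of the commit) that runs the patch embedding and the stage loop by hand, mirroring `ScalableViT.forward` above; the shallow depths are chosen only to keep the walk-through fast.

```python
import torch
from vit_pytorch.scalable_vit import ScalableViT

# shallow depths, purely to keep the shape walk-through quick on CPU
model = ScalableViT(
    num_classes = 1000,
    dim = 64,
    heads = (2, 4, 8, 16),
    depth = (1, 1, 1, 1),
    ssa_dim_key = (40, 40, 40, 32),
    reduction_factor = (8, 4, 2, 1),
    window_size = (64, 32, None, None)
)

x = model.to_patches(torch.randn(1, 3, 256, 256))   # (1, 64, 64, 64) - 7x7 conv, stride 4

for transformer, downsample in model.layers:
    x = transformer(x)
    print(x.shape)           # (1, 64, 64, 64) -> (1, 128, 32, 32) -> (1, 256, 16, 16) -> (1, 512, 8, 8)
    if downsample is not None:
        x = downsample(x)    # halve the grid and double the channels, except after the last stage

preds = model.mlp_head(x)    # (1, 1000)
```

At each downsample the channel dimension doubles and the spatial grid halves, which is why `window_size` and `reduction_factor` shrink across the stages in the README example.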
