Skip to content

Commit 29fbf0a

Browse files
committed
begin extending some of the architectures over to 3d, starting with basic ViT
1 parent 4b8f5bc commit 29fbf0a

File tree

3 files changed

+162
-1
lines changed

3 files changed

+162
-1
lines changed

README.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
- [Adaptive Token Sampling](#adaptive-token-sampling)
3131
- [Patch Merger](#patch-merger)
3232
- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
33+
- [3D Vit](#3d-vit)
3334
- [Parallel ViT](#parallel-vit)
3435
- [Learnable Memory ViT](#learnable-memory-vit)
3536
- [Dino](#dino)
@@ -967,6 +968,37 @@ img = torch.randn(4, 3, 256, 256)
967968
tokens = spt(img) # (4, 256, 1024)
968969
```
969970

971+
## 3D ViT
972+
973+
By popular request, I will start extending a few of the architectures in this repository to 3D ViTs, for use with video, medical imaging, etc.
974+
975+
You will need to pass in two additional hyperparameters: (1) the number of frames `frames` and (2) patch size along the frame dimension `frame_patch_size`
976+
977+
For starters, with the most basic ViT
978+
979+
```python
980+
import torch
981+
from vit_pytorch.vit_3d import ViT
982+
983+
v = ViT(
984+
image_size = 128, # image size
985+
frames = 16, # number of frames
986+
image_patch_size = 16, # image patch size
987+
frame_patch_size = 2, # frame patch size
988+
num_classes = 1000,
989+
dim = 1024,
990+
depth = 6,
991+
heads = 8,
992+
mlp_dim = 2048,
993+
dropout = 0.1,
994+
emb_dropout = 0.1
995+
)
996+
997+
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
998+
999+
preds = v(video) # (4, 1000)
1000+
```
1001+
9701002
## Parallel ViT
9711003

9721004
<img src="./images/parallel-vit.png" width="350px"></img>

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vit-pytorch',
55
packages = find_packages(exclude=['examples']),
6-
version = '0.35.8',
6+
version = '0.36.0',
77
license='MIT',
88
description = 'Vision Transformer (ViT) - Pytorch',
99
long_description_content_type = 'text/markdown',

vit_pytorch/vit_3d.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import torch
2+
from torch import nn
3+
4+
from einops import rearrange, repeat
5+
from einops.layers.torch import Rearrange
6+
7+
# helpers
8+
9+
def pair(t):
10+
return t if isinstance(t, tuple) else (t, t)
11+
12+
# classes
13+
14+
class PreNorm(nn.Module):
15+
def __init__(self, dim, fn):
16+
super().__init__()
17+
self.norm = nn.LayerNorm(dim)
18+
self.fn = fn
19+
def forward(self, x, **kwargs):
20+
return self.fn(self.norm(x), **kwargs)
21+
22+
class FeedForward(nn.Module):
23+
def __init__(self, dim, hidden_dim, dropout = 0.):
24+
super().__init__()
25+
self.net = nn.Sequential(
26+
nn.Linear(dim, hidden_dim),
27+
nn.GELU(),
28+
nn.Dropout(dropout),
29+
nn.Linear(hidden_dim, dim),
30+
nn.Dropout(dropout)
31+
)
32+
def forward(self, x):
33+
return self.net(x)
34+
35+
class Attention(nn.Module):
36+
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
37+
super().__init__()
38+
inner_dim = dim_head * heads
39+
project_out = not (heads == 1 and dim_head == dim)
40+
41+
self.heads = heads
42+
self.scale = dim_head ** -0.5
43+
44+
self.attend = nn.Softmax(dim = -1)
45+
self.dropout = nn.Dropout(dropout)
46+
47+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
48+
49+
self.to_out = nn.Sequential(
50+
nn.Linear(inner_dim, dim),
51+
nn.Dropout(dropout)
52+
) if project_out else nn.Identity()
53+
54+
def forward(self, x):
55+
qkv = self.to_qkv(x).chunk(3, dim = -1)
56+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
57+
58+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
59+
60+
attn = self.attend(dots)
61+
attn = self.dropout(attn)
62+
63+
out = torch.matmul(attn, v)
64+
out = rearrange(out, 'b h n d -> b n (h d)')
65+
return self.to_out(out)
66+
67+
class Transformer(nn.Module):
68+
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
69+
super().__init__()
70+
self.layers = nn.ModuleList([])
71+
for _ in range(depth):
72+
self.layers.append(nn.ModuleList([
73+
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
74+
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
75+
]))
76+
def forward(self, x):
77+
for attn, ff in self.layers:
78+
x = attn(x) + x
79+
x = ff(x) + x
80+
return x
81+
82+
class ViT(nn.Module):
83+
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
84+
super().__init__()
85+
image_height, image_width = pair(image_size)
86+
patch_height, patch_width = pair(image_patch_size)
87+
88+
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
89+
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
90+
91+
num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
92+
patch_dim = channels * patch_height * patch_width * frame_patch_size
93+
94+
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
95+
96+
self.to_patch_embedding = nn.Sequential(
97+
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
98+
nn.Linear(patch_dim, dim),
99+
)
100+
101+
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
102+
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
103+
self.dropout = nn.Dropout(emb_dropout)
104+
105+
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
106+
107+
self.pool = pool
108+
self.to_latent = nn.Identity()
109+
110+
self.mlp_head = nn.Sequential(
111+
nn.LayerNorm(dim),
112+
nn.Linear(dim, num_classes)
113+
)
114+
115+
def forward(self, img):
116+
x = self.to_patch_embedding(img)
117+
b, n, _ = x.shape
118+
119+
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
120+
x = torch.cat((cls_tokens, x), dim=1)
121+
x += self.pos_embedding[:, :(n + 1)]
122+
x = self.dropout(x)
123+
124+
x = self.transformer(x)
125+
126+
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
127+
128+
x = self.to_latent(x)
129+
return self.mlp_head(x)

0 commit comments

Comments
 (0)