
Commit b4853d3

add the 3d simple vit
1 parent 29fbf0a commit b4853d3

File tree

3 files changed: +153 -2 lines


README.md

Lines changed: 24 additions & 1 deletion
@@ -974,7 +974,7 @@ By popular request, I will start extending a few of the architectures in this re

You will need to pass in two additional hyperparameters: (1) the number of frames `frames` and (2) the patch size along the frame dimension `frame_patch_size`

-For starters, with the most basic ViT
+For starters, 3D ViT

```python
import torch
@@ -999,6 +999,29 @@ video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
preds = v(video) # (4, 1000)
```

+3D Simple ViT
+
+```python
+import torch
+from vit_pytorch.simple_vit_3d import SimpleViT
+
+v = SimpleViT(
+    image_size = 128,          # image size
+    frames = 16,               # number of frames
+    image_patch_size = 16,     # image patch size
+    frame_patch_size = 2,      # frame patch size
+    num_classes = 1000,
+    dim = 1024,
+    depth = 6,
+    heads = 8,
+    mlp_dim = 2048
+)
+
+video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
+
+preds = v(video) # (4, 1000)
+```

## Parallel ViT

<img src="./images/parallel-vit.png" width="350px"></img>
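As a quick sanity check on how the two video-specific hyperparameters interact with the spatial ones, here is a minimal sketch (for illustration only, not part of this commit); the arithmetic mirrors the `num_patches` and `patch_dim` computation in `SimpleViT.__init__` of the new `vit_pytorch/simple_vit_3d.py` below:

```python
# README settings: 128 x 128 frames, 16 frames per clip,
# 16 x 16 spatial patches, 2 frames per frame patch
image_size, image_patch_size = 128, 16
frames, frame_patch_size = 16, 2
channels = 3

# each token is a (frame_patch_size, image_patch_size, image_patch_size) tubelet
num_patches = (image_size // image_patch_size) ** 2 * (frames // frame_patch_size)
patch_dim = channels * image_patch_size ** 2 * frame_patch_size

print(num_patches)  # 8 * 8 * 8 = 512 tokens per video
print(patch_dim)    # 3 * 16 * 16 * 2 = 1536 values per token, before the linear projection to dim
```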

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
setup(
  name = 'vit-pytorch',
  packages = find_packages(exclude=['examples']),
-  version = '0.36.0',
+  version = '0.36.1',
  license='MIT',
  description = 'Vision Transformer (ViT) - Pytorch',
  long_description_content_type = 'text/markdown',

vit_pytorch/simple_vit_3d.py

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
import torch
import torch.nn.functional as F
from torch import nn

from einops import rearrange
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
    _, f, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    z, y, x = torch.meshgrid(
        torch.arange(f, device = device),
        torch.arange(h, device = device),
        torch.arange(w, device = device),
    indexing = 'ij')

    fourier_dim = dim // 6

    omega = torch.arange(fourier_dim, device = device) / (fourier_dim - 1)
    omega = 1. / (temperature ** omega)

    z = z.flatten()[:, None] * omega[None, :]
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]

    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)

    pe = F.pad(pe, (0, dim - (fourier_dim * 6))) # pad if feature dimension not cleanly divisible by 6
    return pe.type(dtype)

# classes

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class SimpleViT(nn.Module):
    def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(image_patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        assert frames % frame_patch_size == 0, 'Frames must be divisible by the frame patch size'

        num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
        patch_dim = channels * patch_height * patch_width * frame_patch_size

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
            nn.Linear(patch_dim, dim),
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.to_latent = nn.Identity()
        self.linear_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        *_, h, w, dtype = *img.shape, img.dtype

        x = self.to_patch_embedding(img)
        pe = posemb_sincos_3d(x)
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        x = self.transformer(x)
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return self.linear_head(x)
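A small shape check of the new module (a hedged sketch, not part of the commit; it assumes the package above is installed and uses the README hyperparameters). The patch embedding keeps the (frame, height, width) grid so that `posemb_sincos_3d` can assign one sin/cos pair per axis; with `dim = 1024`, `fourier_dim = 1024 // 6 = 170`, and the remaining `1024 - 170 * 6 = 4` features are zero-padded.

```python
import torch
from vit_pytorch.simple_vit_3d import SimpleViT, posemb_sincos_3d

v = SimpleViT(
    image_size = 128, image_patch_size = 16,
    frames = 16, frame_patch_size = 2,
    num_classes = 1000, dim = 1024, depth = 6, heads = 8, mlp_dim = 2048
)

video = torch.randn(2, 3, 16, 128, 128)

# the Rearrange layer keeps the 3d patch grid: 16 / 2 = 8 frame patches, 128 / 16 = 8 patches per spatial side
tokens = v.to_patch_embedding(video)
assert tokens.shape == (2, 8, 8, 8, 1024)

# one positional embedding per (z, y, x) grid location, flattened to 8 * 8 * 8 = 512 tokens
pe = posemb_sincos_3d(tokens)
assert pe.shape == (512, 1024)

preds = v(video)
assert preds.shape == (2, 1000)
```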
