
Commit 9cd56ff

CCT allow for rectangular images

This commit lets `CCT` accept a rectangular `img_size`, given as a `(height, width)` tuple, in addition to the existing square int.

1 parent 2aae406 · commit 9cd56ff

3 files changed: +44 −35 lines

README.md

Lines changed: 32 additions & 29 deletions

````diff
@@ -253,22 +253,25 @@ You can use this with two methods
 import torch
 from vit_pytorch.cct import CCT
 
-model = CCT(
-    img_size=224,
-    embedding_dim=384,
-    n_conv_layers=2,
-    kernel_size=7,
-    stride=2,
-    padding=3,
-    pooling_kernel_size=3,
-    pooling_stride=2,
-    pooling_padding=1,
-    num_layers=14,
-    num_heads=6,
-    mlp_radio=3.,
-    num_classes=1000,
-    positional_embedding='learnable', # ['sine', 'learnable', 'none']
-)
+cct = CCT(
+    img_size = (224, 448),
+    embedding_dim = 384,
+    n_conv_layers = 2,
+    kernel_size = 7,
+    stride = 2,
+    padding = 3,
+    pooling_kernel_size = 3,
+    pooling_stride = 2,
+    pooling_padding = 1,
+    num_layers = 14,
+    num_heads = 6,
+    mlp_ratio = 3.,
+    num_classes = 1000,
+    positional_embedding = 'learnable', # ['sine', 'learnable', 'none']
+)
+
+img = torch.randn(1, 3, 224, 448)
+pred = cct(img) # (1, 1000)
 ```
 
 Alternatively you can use one of several pre-defined models `[2,4,6,7,8,14,16]`
@@ -279,23 +282,23 @@ and the embedding dimension.
 import torch
 from vit_pytorch.cct import cct_14
 
-model = cct_14(
-    img_size=224,
-    n_conv_layers=1,
-    kernel_size=7,
-    stride=2,
-    padding=3,
-    pooling_kernel_size=3,
-    pooling_stride=2,
-    pooling_padding=1,
-    num_classes=1000,
-    positional_embedding='learnable', # ['sine', 'learnable', 'none']
-)
+cct = cct_14(
+    img_size = 224,
+    n_conv_layers = 1,
+    kernel_size = 7,
+    stride = 2,
+    padding = 3,
+    pooling_kernel_size = 3,
+    pooling_stride = 2,
+    pooling_padding = 1,
+    num_classes = 1000,
+    positional_embedding = 'learnable', # ['sine', 'learnable', 'none']
+)
 ```
+
 <a href="https://github.com/SHI-Labs/Compact-Transformers">Official
 Repository</a> includes links to pretrained model checkpoints.
 
-
 ## Cross ViT
 
 <img src="./images/cross_vit.png" width="400px"></img>
````
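Since `_cct` forwards extra keyword arguments (including `img_size`) through to `CCT`, the pre-defined variants should accept a rectangular size as well. A minimal sketch, reusing the hyperparameters from the README example above; this usage is an inference from the diff, not part of the commit itself:

```python
import torch
from vit_pytorch.cct import cct_14

cct = cct_14(
    img_size = (224, 448),  # rectangular (height, width)
    n_conv_layers = 1,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    pooling_kernel_size = 3,
    pooling_stride = 2,
    pooling_padding = 1,
    num_classes = 1000,
    positional_embedding = 'learnable'
)

img = torch.randn(1, 3, 224, 448)  # batch of one rectangular image
pred = cct(img)                    # (1, 1000)
```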

setup.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.29.0',
+  version = '0.29.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
```

vit_pytorch/cct.py

Lines changed: 11 additions & 5 deletions

```diff
@@ -2,7 +2,13 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-# Pre-defined CCT Models
+# helpers
+
+def pair(t):
+    return t if isinstance(t, tuple) else (t, t)
+
+# CCT Models
+
 __all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16']
 
 
```
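The new `pair` helper is what makes both calling conventions work: it normalizes an int `img_size` to a square `(height, width)` tuple and passes tuples through unchanged. For example:

```python
pair(224)         # (224, 224) -- a square image
pair((224, 448))  # (224, 448) -- rectangular, passed through as-is
```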

```diff
@@ -55,8 +61,8 @@ def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
                padding=padding,
                *args, **kwargs)
 
+# modules
 
-# Modules
 class Attention(nn.Module):
     def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
         super().__init__()
@@ -308,6 +314,7 @@ def __init__(self,
                  pooling_padding=1,
                  *args, **kwargs):
         super(CCT, self).__init__()
+        img_height, img_width = pair(img_size)
 
         self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
                                    n_output_channels=embedding_dim,
@@ -324,8 +331,8 @@ def __init__(self,
 
         self.classifier = TransformerClassifier(
             sequence_length=self.tokenizer.sequence_length(n_channels=n_input_channels,
-                                                           height=img_size,
-                                                           width=img_size),
+                                                           height=img_height,
+                                                           width=img_width),
             embedding_dim=embedding_dim,
             seq_pool=True,
             dropout_rate=0.,
@@ -336,4 +343,3 @@ def __init__(self,
     def forward(self, x):
         x = self.tokenizer(x)
         return self.classifier(x)
-
```
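Because `sequence_length` is now computed from `img_height` and `img_width` separately, a rectangular input yields a rectangular token grid. A minimal sketch of the arithmetic, assuming each of the tokenizer's `n_conv_layers` blocks is a convolution (`kernel_size`, `stride`, `padding`) followed by a max-pool (`pooling_kernel_size`, `pooling_stride`, `pooling_padding`), using the README's CCT hyperparameters:

```python
def conv_out(size, kernel, stride, padding):
    # standard output-size formula for a conv or pooling layer
    return (size + 2 * padding - kernel) // stride + 1

h, w = 224, 448
for _ in range(2):  # n_conv_layers = 2
    h = conv_out(conv_out(h, 7, 2, 3), 3, 2, 1)  # conv, then max-pool
    w = conv_out(conv_out(w, 7, 2, 3), 3, 2, 1)

print(h, w, h * w)  # 14 28 392 -- tokens seen by the transformer
```

Under those assumptions, the `(224, 448)` example produces 14 × 28 = 392 tokens, versus the 14 × 14 = 196 a square 224 input would give.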
