
Commit 24196a3

allow for qk norm to be turned off for na vit nested tensor
1 parent f6d7287 commit 24196a3

File tree

3 files changed: +15 -13 lines changed


setup.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.8.7',
+  version = '1.8.8',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,

vit_pytorch/na_vit_nested_tensor.py

Lines changed: 7 additions & 6 deletions
@@ -41,7 +41,7 @@ def FeedForward(dim, hidden_dim, dropout = 0.):
     )

 class Attention(Module):
-    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
         super().__init__()
         self.norm = nn.LayerNorm(dim, bias = False)

@@ -56,8 +56,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         # in the paper, they employ qk rmsnorm, a way to stabilize attention
         # will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors

-        self.query_norm = nn.LayerNorm(dim_head, bias = False)
-        self.key_norm = nn.LayerNorm(dim_head, bias = False)
+        self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
+        self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()

         self.dropout = dropout

@@ -111,13 +111,13 @@ def transpose_head_seq(t):
         return self.to_out(out)

 class Transformer(Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
         super().__init__()
         self.layers = ModuleList([])

         for _ in range(depth):
             self.layers.append(ModuleList([
-                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))

@@ -146,6 +146,7 @@ def __init__(
         dim_head = 64,
         dropout = 0.,
         emb_dropout = 0.,
+        qk_rmsnorm = True,
         token_dropout_prob: float | None = None
     ):
         super().__init__()
@@ -184,7 +185,7 @@ def __init__(

         self.dropout = nn.Dropout(emb_dropout)

-        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)

         # final attention pooling queries

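The change here (mirrored in the 3D file below) is simply that, when qk_norm is disabled, the per-head query/key LayerNorm is swapped for nn.Identity, so queries and keys reach the attention dot product untouched. A minimal standalone sketch of that toggle, stripped of the nested-tensor plumbing (the QKNorm module name and the example shapes are illustrative, not from the commit; LayerNorm's bias argument needs a reasonably recent PyTorch):

import torch
from torch import nn

class QKNorm(nn.Module):
    # sketch of the qk_norm toggle: LayerNorm over the head dimension
    # when enabled, a no-op (nn.Identity) when turned off
    def __init__(self, dim_head = 64, qk_norm = True):
        super().__init__()
        self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
        self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()

    def forward(self, q, k):
        # normalize queries and keys before computing attention scores
        return self.query_norm(q), self.key_norm(k)

q = k = torch.randn(2, 8, 16, 64)        # (batch, heads, seq, dim_head)
qn, kn = QKNorm(qk_norm = True)(q, k)    # normalized queries and keys
qi, ki = QKNorm(qk_norm = False)(q, k)   # passed through unchanged
assert torch.equal(qi, q) and torch.equal(ki, k)
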
vit_pytorch/na_vit_nested_tensor_3d.py

Lines changed: 7 additions & 6 deletions
@@ -41,7 +41,7 @@ def FeedForward(dim, hidden_dim, dropout = 0.):
     )

 class Attention(Module):
-    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
         super().__init__()
         self.norm = nn.LayerNorm(dim, bias = False)

@@ -56,8 +56,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         # in the paper, they employ qk rmsnorm, a way to stabilize attention
         # will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors

-        self.query_norm = nn.LayerNorm(dim_head, bias = False)
-        self.key_norm = nn.LayerNorm(dim_head, bias = False)
+        self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
+        self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()

         self.dropout = dropout

@@ -123,13 +123,13 @@ def transpose_head_seq(t):
         return self.to_out(out)

 class Transformer(Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
         super().__init__()
         self.layers = ModuleList([])

         for _ in range(depth):
             self.layers.append(ModuleList([
-                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))

@@ -161,6 +161,7 @@ def __init__(
         dropout = 0.,
         emb_dropout = 0.,
         num_registers = 4,
+        qk_rmsnorm = True,
         token_dropout_prob: float | None = None
     ):
         super().__init__()
@@ -209,7 +210,7 @@ def __init__(

         self.dropout = nn.Dropout(emb_dropout)

-        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)

         # final attention pooling queries

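On the user side, the new option surfaces as qk_rmsnorm on the NaViT constructors in both files and is forwarded through Transformer to each Attention layer as qk_norm. A hedged usage sketch for the 2D nested-tensor variant: the qk_rmsnorm = False line is the new flag from this commit, while the remaining constructor arguments follow the repository's README example for na_vit_nested_tensor and are assumptions here, not part of the diff.

import torch
from vit_pytorch.na_vit_nested_tensor import NaViT

v = NaViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    token_dropout_prob = 0.1,
    qk_rmsnorm = False    # new in 1.8.8: skip the query/key LayerNorm entirely
)

# images of different resolutions, channels first; packed into a nested tensor internally
images = [
    torch.randn(3, 256, 256),
    torch.randn(3, 128, 128),
    torch.randn(3, 64, 256)
]

preds = v(images)    # one row of logits per image

The 3D variant in vit_pytorch.na_vit_nested_tensor_3d exposes the same qk_rmsnorm flag.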