|
| 1 | +import torch |
| 2 | +from torch import nn |
| 3 | +import torch.nn.functional as F |
| 4 | +from torch.nn import Module, ModuleList |
| 5 | + |
| 6 | +from einops import einsum, rearrange, repeat, reduce |
| 7 | +from einops.layers.torch import Rearrange |
| 8 | + |
| 9 | +# helpers |
| 10 | + |
| 11 | +def exists(val): |
| 12 | + return val is not None |
| 13 | + |
| 14 | +def default(val, d): |
| 15 | + return val if exists(val) else d |
| 16 | + |
| 17 | +def divisible_by(num, den): |
| 18 | + return (num % den) == 0 |
| 19 | + |
| 20 | +# simple vit sinusoidal pos emb |
| 21 | + |
| 22 | +def posemb_sincos_2d(t, temperature = 10000): |
| 23 | + h, w, d, device = *t.shape[1:], t.device |
| 24 | + y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij') |
| 25 | + assert (d % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" |
| 26 | + omega = torch.arange(d // 4, device = device) / (d // 4 - 1) |
| 27 | + omega = temperature ** -omega |
| 28 | + |
| 29 | + y = y.flatten()[:, None] * omega[None, :] |
| 30 | + x = x.flatten()[:, None] * omega[None, :] |
| 31 | + pos = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1) |
| 32 | + |
| 33 | + return pos.float() |
| 34 | + |
| 35 | +# bias-less layernorm with unit offset trick (discovered by Ohad Rubin) |
| 36 | + |
| 37 | +class LayerNorm(Module): |
| 38 | + def __init__(self, dim): |
| 39 | + super().__init__() |
| 40 | + self.ln = nn.LayerNorm(dim, elementwise_affine = False) |
| 41 | + self.gamma = nn.Parameter(torch.zeros(dim)) |
| 42 | + |
| 43 | + def forward(self, x): |
| 44 | + normed = self.ln(x) |
| 45 | + return normed * (self.gamma + 1) |
| 46 | + |
| 47 | +# mlp |
| 48 | + |
| 49 | +def MLP(dim, factor = 4, dropout = 0.): |
| 50 | + hidden_dim = int(dim * factor) |
| 51 | + return nn.Sequential( |
| 52 | + LayerNorm(dim), |
| 53 | + nn.Linear(dim, hidden_dim), |
| 54 | + nn.GELU(), |
| 55 | + nn.Dropout(dropout), |
| 56 | + nn.Linear(hidden_dim, dim), |
| 57 | + nn.Dropout(dropout) |
| 58 | + ) |
| 59 | + |
| 60 | +# attention |
| 61 | + |
| 62 | +class Attention(Module): |
| 63 | + def __init__( |
| 64 | + self, |
| 65 | + dim, |
| 66 | + heads = 8, |
| 67 | + dim_head = 64, |
| 68 | + dropout = 0., |
| 69 | + reuse_attention = False |
| 70 | + ): |
| 71 | + super().__init__() |
| 72 | + inner_dim = dim_head * heads |
| 73 | + |
| 74 | + self.scale = dim_head ** -0.5 |
| 75 | + self.heads = heads |
| 76 | + self.reuse_attention = reuse_attention |
| 77 | + |
| 78 | + self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads) |
| 79 | + |
| 80 | + self.norm = LayerNorm(dim) |
| 81 | + self.attend = nn.Softmax(dim = -1) |
| 82 | + self.dropout = nn.Dropout(dropout) |
| 83 | + |
| 84 | + self.to_q = nn.Linear(dim, inner_dim, bias = False) if not reuse_attention else None |
| 85 | + self.to_k = nn.Linear(dim, inner_dim, bias = False) if not reuse_attention else None |
| 86 | + self.to_v = nn.Linear(dim, inner_dim, bias = False) |
| 87 | + |
| 88 | + self.to_out = nn.Sequential( |
| 89 | + Rearrange('b h n d -> b n (h d)'), |
| 90 | + nn.Linear(inner_dim, dim, bias = False), |
| 91 | + nn.Dropout(dropout) |
| 92 | + ) |
| 93 | + |
| 94 | + def forward( |
| 95 | + self, |
| 96 | + x, |
| 97 | + context = None, |
| 98 | + return_attn = False, |
| 99 | + attn = None |
| 100 | + ): |
| 101 | + x = self.norm(x) |
| 102 | + context = default(context, x) |
| 103 | + |
| 104 | + v = self.to_v(context) |
| 105 | + v = self.split_heads(v) |
| 106 | + |
| 107 | + if not self.reuse_attention: |
| 108 | + qk = (self.to_q(x), self.to_k(context)) |
| 109 | + q, k = tuple(self.split_heads(t) for t in qk) |
| 110 | + |
| 111 | + q = q * self.scale |
| 112 | + sim = einsum(q, k, 'b h i d, b h j d -> b h i j') |
| 113 | + |
| 114 | + attn = self.attend(sim) |
| 115 | + attn = self.dropout(attn) |
| 116 | + else: |
| 117 | + assert exists(attn), 'attention matrix must be passed in for reusing previous attention' |
| 118 | + |
| 119 | + out = einsum(attn, v, 'b h i j, b h j d -> b h i d') |
| 120 | + out = self.to_out(out) |
| 121 | + |
| 122 | + if not return_attn: |
| 123 | + return out |
| 124 | + |
| 125 | + return out, attn |
| 126 | + |
| 127 | +# LookViT |
| 128 | + |
| 129 | +class LookViT(Module): |
| 130 | + def __init__( |
| 131 | + self, |
| 132 | + *, |
| 133 | + dim, |
| 134 | + image_size, |
| 135 | + num_classes, |
| 136 | + depth = 3, |
| 137 | + patch_size = 16, |
| 138 | + heads = 8, |
| 139 | + mlp_factor = 4, |
| 140 | + dim_head = 64, |
| 141 | + highres_patch_size = 12, |
| 142 | + highres_mlp_factor = 4, |
| 143 | + cross_attn_heads = 8, |
| 144 | + cross_attn_dim_head = 64, |
| 145 | + patch_conv_kernel_size = 7, |
| 146 | + dropout = 0.1, |
| 147 | + channels = 3 |
| 148 | + ): |
| 149 | + super().__init__() |
| 150 | + assert divisible_by(image_size, highres_patch_size) |
| 151 | + assert divisible_by(image_size, patch_size) |
| 152 | + assert patch_size > highres_patch_size, 'patch size of the main vision transformer should be smaller than the highres patch sizes (that does the `lookup`)' |
| 153 | + assert not divisible_by(patch_conv_kernel_size, 2) |
| 154 | + |
| 155 | + self.dim = dim |
| 156 | + self.image_size = image_size |
| 157 | + self.patch_size = patch_size |
| 158 | + |
| 159 | + kernel_size = patch_conv_kernel_size |
| 160 | + patch_dim = (highres_patch_size * highres_patch_size) * channels |
| 161 | + |
| 162 | + self.to_patches = nn.Sequential( |
| 163 | + Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = highres_patch_size, p2 = highres_patch_size), |
| 164 | + nn.Conv2d(patch_dim, dim, kernel_size, padding = kernel_size // 2), |
| 165 | + Rearrange('b c h w -> b h w c'), |
| 166 | + LayerNorm(dim), |
| 167 | + ) |
| 168 | + |
| 169 | + # absolute positions |
| 170 | + |
| 171 | + num_patches = (image_size // highres_patch_size) ** 2 |
| 172 | + self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim)) |
| 173 | + |
| 174 | + # lookvit blocks |
| 175 | + |
| 176 | + layers = ModuleList([]) |
| 177 | + |
| 178 | + for _ in range(depth): |
| 179 | + layers.append(ModuleList([ |
| 180 | + Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout), |
| 181 | + MLP(dim = dim, factor = mlp_factor, dropout = dropout), |
| 182 | + Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout), |
| 183 | + Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, reuse_attention = True), |
| 184 | + LayerNorm(dim), |
| 185 | + MLP(dim = dim, factor = highres_mlp_factor, dropout = dropout) |
| 186 | + ])) |
| 187 | + |
| 188 | + self.layers = layers |
| 189 | + |
| 190 | + self.norm = LayerNorm(dim) |
| 191 | + self.highres_norm = LayerNorm(dim) |
| 192 | + |
| 193 | + self.to_logits = nn.Linear(dim, num_classes, bias = False) |
| 194 | + |
| 195 | + def forward(self, img): |
| 196 | + assert img.shape[-2:] == (self.image_size, self.image_size) |
| 197 | + |
| 198 | + # to patch tokens and positions |
| 199 | + |
| 200 | + highres_tokens = self.to_patches(img) |
| 201 | + size = highres_tokens.shape[-2] |
| 202 | + |
| 203 | + pos_emb = posemb_sincos_2d(highres_tokens) |
| 204 | + highres_tokens = highres_tokens + rearrange(pos_emb, '(h w) d -> h w d', h = size) |
| 205 | + |
| 206 | + tokens = F.interpolate( |
| 207 | + rearrange(highres_tokens, 'b h w d -> b d h w'), |
| 208 | + img.shape[-1] // self.patch_size, |
| 209 | + mode = 'bilinear' |
| 210 | + ) |
| 211 | + |
| 212 | + tokens = rearrange(tokens, 'b c h w -> b (h w) c') |
| 213 | + highres_tokens = rearrange(highres_tokens, 'b h w c -> b (h w) c') |
| 214 | + |
| 215 | + # attention and feedforwards |
| 216 | + |
| 217 | + for attn, mlp, lookup_cross_attn, highres_attn, highres_norm, highres_mlp in self.layers: |
| 218 | + |
| 219 | + # main tokens cross attends (lookup) on the high res tokens |
| 220 | + |
| 221 | + lookup_out, lookup_attn = lookup_cross_attn(tokens, highres_tokens, return_attn = True) # return attention as they reuse the attention matrix |
| 222 | + tokens = lookup_out + tokens |
| 223 | + |
| 224 | + tokens = attn(tokens) + tokens |
| 225 | + tokens = mlp(tokens) + tokens |
| 226 | + |
| 227 | + # attention-reuse |
| 228 | + |
| 229 | + lookup_attn = rearrange(lookup_attn, 'b h i j -> b h j i') # transpose for reverse cross attention |
| 230 | + |
| 231 | + highres_tokens = highres_attn(highres_tokens, tokens, attn = lookup_attn) + highres_tokens |
| 232 | + highres_tokens = highres_norm(highres_tokens) |
| 233 | + |
| 234 | + highres_tokens = highres_mlp(highres_tokens) + highres_tokens |
| 235 | + |
| 236 | + # to logits |
| 237 | + |
| 238 | + tokens = self.norm(tokens) |
| 239 | + highres_tokens = self.highres_norm(highres_tokens) |
| 240 | + |
| 241 | + tokens = reduce(tokens, 'b n d -> b d', 'mean') |
| 242 | + highres_tokens = reduce(highres_tokens, 'b n d -> b d', 'mean') |
| 243 | + |
| 244 | + return self.to_logits(tokens + highres_tokens) |
| 245 | + |
| 246 | +# main |
| 247 | + |
| 248 | +if __name__ == '__main__': |
| 249 | + v = LookViT( |
| 250 | + image_size = 256, |
| 251 | + num_classes = 1000, |
| 252 | + dim = 512, |
| 253 | + depth = 2, |
| 254 | + heads = 8, |
| 255 | + dim_head = 64, |
| 256 | + patch_size = 32, |
| 257 | + highres_patch_size = 8, |
| 258 | + highres_mlp_factor = 2, |
| 259 | + cross_attn_heads = 8, |
| 260 | + cross_attn_dim_head = 64, |
| 261 | + dropout = 0.1 |
| 262 | + ).cuda() |
| 263 | + |
| 264 | + img = torch.randn(2, 3, 256, 256).cuda() |
| 265 | + pred = v(img) |
| 266 | + |
| 267 | + assert pred.shape == (2, 1000) |
0 commit comments