Commit 7a4619e

Trying out a differential + parallel block, and a fused output projection option for ParallelScalingBlock, to see if it's worthwhile now...

1 parent d66283d commit 7a4619e

File tree: 1 file changed (+200, -9 lines)

timm/models/vision_transformer.py: 200 additions & 9 deletions

@@ -303,8 +303,9 @@ def __init__(
             mlp_layer: Optional[Type[nn.Module]] = None,  # not used
             attn_layer: Optional[LayerType] = None,  # not used
             depth: int = 0,  # not used
-            device = None,
-            dtype = None,
+            fuse_out_proj: bool = False,
+            device=None,
+            dtype=None,
     ) -> None:
         super().__init__()
         dd = {'device': device, 'dtype': dtype}
@@ -330,11 +331,20 @@ def __init__(
         self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias, **dd)

         self.mlp_drop = nn.Dropout(proj_drop)
         self.mlp_act = act_layer()
-        self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim, bias=proj_bias, **dd)
+
+        if fuse_out_proj:
+            # Fused output projection for both attention and MLP
+            self.out_proj = nn.Linear(dim + mlp_hidden_dim, dim, bias=proj_bias, **dd)
+            self.attn_out_proj = None
+            self.mlp_out_proj = None
+        else:
+            # Separate output projections
+            self.out_proj = None
+            self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias, **dd)
+            self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim, bias=proj_bias, **dd)

         self.ls = LayerScale(dim, init_values=init_values, **dd) if init_values is not None else nn.Identity()
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
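
With fuse_out_proj=True, the two per-branch output projections collapse into a single Linear applied to the concatenated attention and MLP features, so the block runs one wider GEMM instead of two. The snippet below is a standalone sketch, not part of the commit (attn_proj, mlp_proj and fused are illustrative names): it checks that a fused linear whose weight is the column-wise concatenation of the two separate weights, with the two biases summed into one, reproduces the sum of the separate projections exactly.

# Standalone equivalence check for the fused output projection (illustrative).
import torch
import torch.nn as nn

dim, hidden, B, N = 64, 320, 2, 8
attn_proj = nn.Linear(dim, dim)        # per-branch projection for attention output
mlp_proj = nn.Linear(hidden, dim)      # per-branch projection for MLP hidden features

fused = nn.Linear(dim + hidden, dim)   # single projection over concatenated features
with torch.no_grad():
    # Column-wise concat of the two weights; the two biases fold into one.
    fused.weight.copy_(torch.cat([attn_proj.weight, mlp_proj.weight], dim=1))
    fused.bias.copy_(attn_proj.bias + mlp_proj.bias)

x_attn = torch.randn(B, N, dim)
x_mlp = torch.randn(B, N, hidden)
y_sep = attn_proj(x_attn) + mlp_proj(x_mlp)
y_fused = fused(torch.cat([x_attn, x_mlp], dim=-1))
print(torch.allclose(y_sep, y_fused, atol=1e-5))  # True
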
@@ -371,16 +381,184 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) ->
             x_attn = attn @ v

         x_attn = x_attn.transpose(1, 2).reshape(B, N, C)
-        x_attn = self.attn_out_proj(x_attn)

-        # MLP activation, dropout, fc2
+        # MLP activation & dropout
         x_mlp = self.mlp_act(x_mlp)
         x_mlp = self.mlp_drop(x_mlp)
-        x_mlp = self.mlp_out_proj(x_mlp)
+
+        # Output projection (fused or separate)
+        if self.out_proj is not None:
+            y = self.out_proj(torch.cat((x_attn, x_mlp), dim=-1))
+        else:
+            y = self.attn_out_proj(x_attn) + self.mlp_out_proj(x_mlp)

         # Add residual w/ drop path & layer scale applied
-        y = self.drop_path(self.ls(x_attn + x_mlp))
-        x = x + y
+        x = x + self.drop_path(self.ls(y))
+        return x
+
+
+class DiffParallelScalingBlock(nn.Module):
+    """ Parallel ViT block with Differential Attention (MLP & Attention in parallel).
+
+    Combines the parallel MLP+Attention structure from 'Scaling Vision Transformers to
+    22 Billion Parameters' (https://arxiv.org/abs/2302.05442) with differential attention
+    from 'Differential Transformer' (https://arxiv.org/abs/2410.05258).
+    """
+    fused_attn: Final[bool]
+
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = False,
+            qk_norm: bool = False,
+            scale_attn_norm: bool = False,
+            scale_mlp_norm: bool = False,
+            proj_bias: bool = True,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            init_values: Optional[float] = None,
+            drop_path: float = 0.,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = LayerNorm,
+            mlp_layer: Optional[Type[nn.Module]] = None,
+            attn_layer: Optional[LayerType] = None,
+            depth: int = 0,
+            dual_lambda: bool = False,
+            device=None,
+            dtype=None,
+    ) -> None:
+        super().__init__()
+        dd = {'device': device, 'dtype': dtype}
+        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+        assert not scale_attn_norm and not scale_mlp_norm, 'Scale norms not supported'
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads // 2  # Half head_dim for diff attention
+        self.scale = self.head_dim ** -0.5
+        self.fused_attn = use_fused_attn()
+        mlp_hidden_dim = int(mlp_ratio * dim)
+        in_proj_out_dim = mlp_hidden_dim + 3 * dim
+
+        self.in_norm = norm_layer(dim, **dd)
+        self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias, **dd)
+        self.in_split = [mlp_hidden_dim] + [dim] * 3
+        if qkv_bias:
+            self.register_buffer('qkv_bias', None)
+            self.register_parameter('mlp_bias', None)
+        else:
+            self.register_buffer('qkv_bias', torch.zeros(3 * dim, **dd), persistent=False)
+            self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim, **dd))
+
+        self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.attn_drop_p = attn_drop
+
+        # Differential attention specific
+        self.sub_norm = RmsNorm(2 * self.head_dim, eps=1e-5, **dd)
+        self.dual_lambda = dual_lambda
+        if dual_lambda:
+            self.lambda_a = nn.Parameter(torch.empty((), dtype=torch.float32, device=device))
+            self.lambda_b = nn.Parameter(torch.empty((), dtype=torch.float32, device=device))
+            self.lambda_q1 = self.lambda_k1 = self.lambda_q2 = self.lambda_k2 = None
+        else:
+            self.lambda_a = self.lambda_b = None
+            self.lambda_q1 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
+            self.lambda_k1 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
+            self.lambda_q2 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
+            self.lambda_k2 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
+
+        self.mlp_drop = nn.Dropout(proj_drop)
+        self.mlp_act = act_layer()
+
+        # Fused output projection for both attention and MLP
+        self.out_proj = nn.Linear(dim + mlp_hidden_dim, dim, bias=proj_bias, **dd)
+
+        self.ls = LayerScale(dim, init_values=init_values, **dd) if init_values is not None else nn.Identity()
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+        self.lambda_init = 0.8
+        self.set_lambda_init(depth)
+        self.reset_parameters()
+
+    def set_lambda_init(self, depth: int):
+        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)
+
+    def reset_parameters(self):
+        if self.dual_lambda:
+            nn.init.zeros_(self.lambda_a)
+            nn.init.zeros_(self.lambda_b)
+        else:
+            nn.init.normal_(self.lambda_q1, mean=0, std=0.1)
+            nn.init.normal_(self.lambda_k1, mean=0, std=0.1)
+            nn.init.normal_(self.lambda_q2, mean=0, std=0.1)
+            nn.init.normal_(self.lambda_k2, mean=0, std=0.1)
+
+    def _compute_lambda(self) -> torch.Tensor:
+        if self.lambda_a is not None:
+            lambda_1 = torch.exp(self.lambda_a)
+            lambda_2 = torch.exp(self.lambda_b)
+        else:
+            lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float())
+            lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float())
+        return lambda_1 - lambda_2 + self.lambda_init
+
+    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        B, N, C = x.shape
+
+        # Combined MLP fc1 & qkv projections
+        y = self.in_norm(x)
+        if self.mlp_bias is not None:
+            y = F.linear(y, self.in_proj.weight, torch.cat((self.qkv_bias, self.mlp_bias)))
+        else:
+            y = self.in_proj(y)
+        x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1)
+
+        # Reshape for differential attention (2x heads with half head_dim for q/k)
+        q = q.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
+        k = k.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
+        v = v.reshape(B, N, self.num_heads, 2 * self.head_dim).transpose(1, 2)
+
+        q, k = self.q_norm(q), self.k_norm(k)
+
+        lambda_full = self._compute_lambda().type_as(q)
+
+        if self.fused_attn:
+            q = q.reshape(B, self.num_heads, 2, N, self.head_dim)
+            k = k.reshape(B, self.num_heads, 2, N, self.head_dim)
+            q1, q2 = q.unbind(2)
+            k1, k2 = k.unbind(2)
+
+            dropout_p = self.attn_drop_p if self.training else 0.0
+            attn1 = F.scaled_dot_product_attention(q1, k1, v, attn_mask=attn_mask, dropout_p=dropout_p)
+            attn2 = F.scaled_dot_product_attention(q2, k2, v, attn_mask=attn_mask, dropout_p=dropout_p)
+
+            x_attn = attn1 - lambda_full * attn2
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = maybe_add_mask(attn, attn_mask)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+
+            attn = attn.view(B, self.num_heads, 2, N, N)
+            attn = attn[:, :, 0] - lambda_full * attn[:, :, 1]
+            x_attn = attn @ v
+
+        x_attn = self.sub_norm(x_attn)
+        x_attn = x_attn * (1 - self.lambda_init)
+        x_attn = x_attn.transpose(1, 2).reshape(B, N, C)
+
+        # MLP activation & dropout
+        x_mlp = self.mlp_act(x_mlp)
+        x_mlp = self.mlp_drop(x_mlp)
+
+        # Fused output projection
+        y = self.out_proj(torch.cat((x_attn, x_mlp), dim=-1))
+
+        # Add residual w/ drop path & layer scale applied
+        x = x + self.drop_path(self.ls(y))
         return x


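The new block follows the Differential Transformer re-parameterization of the mixing weight: lambda_full = exp(lambda_q1 · lambda_k1) - exp(lambda_q2 · lambda_k2) + lambda_init (or exp(lambda_a) - exp(lambda_b) + lambda_init when dual_lambda=True), with lambda_init = 0.8 - 0.6 * exp(-0.3 * depth). The attention output is the difference of two softmax maps over split query/key heads, attn1 - lambda_full * attn2, sharing a value tensor of doubled per-head width, then RMS-normalized and rescaled by (1 - lambda_init) before the fused output projection. The snippet below is a quick sanity check, not part of the commit, assuming a timm checkout that includes this change so DiffParallelScalingBlock is importable.

# Standalone shape / lambda_init sanity check (illustrative).
import math
import torch
from timm.models.vision_transformer import DiffParallelScalingBlock

dim, num_heads, depth_idx = 256, 4, 3          # depth_idx is an illustrative layer index
blk = DiffParallelScalingBlock(dim, num_heads, mlp_ratio=5, depth=depth_idx)
blk.eval()                                     # disable attention / MLP dropout

# lambda_init moves from 0.2 toward 0.8 as depth grows: 0.8 - 0.6 * exp(-0.3 * depth)
assert math.isclose(blk.lambda_init, 0.8 - 0.6 * math.exp(-0.3 * depth_idx))

x = torch.randn(2, 257, dim)                   # (batch, tokens, dim), e.g. 256 patches + 1 reg token
with torch.no_grad():
    out = blk(x)
print(out.shape)                               # torch.Size([2, 257, 256]); residual shape is preserved
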
@@ -2528,6 +2706,9 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
     'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_dpwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        #hf_hub_id='timm/',
+        input_size=(3, 256, 256), crop_pct=0.95),
     'vit_little_patch16_reg1_gap_256.sbb_in12k_ft_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
@@ -4038,6 +4219,16 @@ def vit_pwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionT
     return model


+@register_model
+def vit_dpwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=256, depth=16, num_heads=4, init_values=1e-5, mlp_ratio=5,
+        class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', block_fn=DiffParallelScalingBlock,
+    )
+    model = _create_vision_transformer(
+        'vit_dpwee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
 @register_model
 def vit_little_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(

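With this commit checked out, the newly registered variant builds through the normal timm factory. Below is a minimal usage sketch, not part of the commit; pretrained weights are not published for this variant here (the hf_hub_id in its cfg is commented out), so pretrained stays False.

# Minimal usage sketch for the new model entrypoint (illustrative).
import timm
import torch

model = timm.create_model('vit_dpwee_patch16_reg1_gap_256', pretrained=False)
model.eval()

x = torch.randn(1, 3, 256, 256)    # default cfg input size is (3, 256, 256)
with torch.no_grad():
    logits = model(x)
print(logits.shape)                # torch.Size([1, 1000]) with the default classifier head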