@@ -303,8 +303,9 @@ def __init__(
             mlp_layer: Optional[Type[nn.Module]] = None,  # not used
             attn_layer: Optional[LayerType] = None,  # not used
             depth: int = 0,  # not used
-            device=None,
-            dtype=None,
+            fuse_out_proj: bool = False,
+            device=None,
+            dtype=None,
     ) -> None:
         super().__init__()
         dd = {'device': device, 'dtype': dtype}
@@ -330,11 +331,20 @@ def __init__(
         self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias, **dd)

         self.mlp_drop = nn.Dropout(proj_drop)
         self.mlp_act = act_layer()
-        self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim, bias=proj_bias, **dd)
+
+        if fuse_out_proj:
+            # Fused output projection for both attention and MLP
+            self.out_proj = nn.Linear(dim + mlp_hidden_dim, dim, bias=proj_bias, **dd)
+            self.attn_out_proj = None
+            self.mlp_out_proj = None
+        else:
+            # Separate output projections
+            self.out_proj = None
+            self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias, **dd)
+            self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim, bias=proj_bias, **dd)

         self.ls = LayerScale(dim, init_values=init_values, **dd) if init_values is not None else nn.Identity()
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
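Note on the fuse_out_proj option added above: projecting the concatenated attention and MLP activations with a single nn.Linear(dim + mlp_hidden_dim, dim) is mathematically the same as applying the two separate projections and summing them, since the fused weight is just the two weight matrices concatenated along the input dimension (with one bias standing in for the sum of the two biases). A minimal standalone sketch of that equivalence, using illustrative shapes and names not taken from the patch:

import torch
import torch.nn as nn

dim, mlp_hidden_dim = 8, 32
attn_proj = nn.Linear(dim, dim)
mlp_proj = nn.Linear(mlp_hidden_dim, dim)

# Build the equivalent fused projection: weights concatenated along the input dim, biases summed.
fused = nn.Linear(dim + mlp_hidden_dim, dim)
with torch.no_grad():
    fused.weight.copy_(torch.cat((attn_proj.weight, mlp_proj.weight), dim=1))
    fused.bias.copy_(attn_proj.bias + mlp_proj.bias)

x_attn = torch.randn(2, 4, dim)
x_mlp = torch.randn(2, 4, mlp_hidden_dim)
y_separate = attn_proj(x_attn) + mlp_proj(x_mlp)
y_fused = fused(torch.cat((x_attn, x_mlp), dim=-1))
print(torch.allclose(y_separate, y_fused, atol=1e-6))  # True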
@@ -371,16 +381,184 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) ->
             x_attn = attn @ v

         x_attn = x_attn.transpose(1, 2).reshape(B, N, C)
-        x_attn = self.attn_out_proj(x_attn)

-        # MLP activation, dropout, fc2
+        # MLP activation & dropout
         x_mlp = self.mlp_act(x_mlp)
         x_mlp = self.mlp_drop(x_mlp)
-        x_mlp = self.mlp_out_proj(x_mlp)
+
+        # Output projection (fused or separate)
+        if self.out_proj is not None:
+            y = self.out_proj(torch.cat((x_attn, x_mlp), dim=-1))
+        else:
+            y = self.attn_out_proj(x_attn) + self.mlp_out_proj(x_mlp)

         # Add residual w/ drop path & layer scale applied
-        y = self.drop_path(self.ls(x_attn + x_mlp))
-        x = x + y
+        x = x + self.drop_path(self.ls(y))
+        return x
+
+
+class DiffParallelScalingBlock(nn.Module):
+    """ Parallel ViT block with Differential Attention (MLP & Attention in parallel).
+
+    Combines the parallel MLP+Attention structure from 'Scaling Vision Transformers to
+    22 Billion Parameters' (https://arxiv.org/abs/2302.05442) with differential attention
+    from 'Differential Transformer' (https://arxiv.org/abs/2410.05258).
+    """
+    fused_attn: Final[bool]
+
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = False,
+            qk_norm: bool = False,
+            scale_attn_norm: bool = False,
+            scale_mlp_norm: bool = False,
+            proj_bias: bool = True,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            init_values: Optional[float] = None,
+            drop_path: float = 0.,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = LayerNorm,
+            mlp_layer: Optional[Type[nn.Module]] = None,
+            attn_layer: Optional[LayerType] = None,
+            depth: int = 0,
+            dual_lambda: bool = False,
+            device=None,
+            dtype=None,
+    ) -> None:
+        super().__init__()
+        dd = {'device': device, 'dtype': dtype}
+        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+        assert not scale_attn_norm and not scale_mlp_norm, 'Scale norms not supported'
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads // 2  # Half head_dim for diff attention
+        self.scale = self.head_dim ** -0.5
+        self.fused_attn = use_fused_attn()
+        mlp_hidden_dim = int(mlp_ratio * dim)
+        in_proj_out_dim = mlp_hidden_dim + 3 * dim
+
+        self.in_norm = norm_layer(dim, **dd)
+        self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias, **dd)
+        self.in_split = [mlp_hidden_dim] + [dim] * 3
+        if qkv_bias:
+            self.register_buffer('qkv_bias', None)
+            self.register_parameter('mlp_bias', None)
+        else:
+            self.register_buffer('qkv_bias', torch.zeros(3 * dim, **dd), persistent=False)
+            self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim, **dd))
+
+        self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.attn_drop_p = attn_drop
+
+        # Differential attention specific
+        self.sub_norm = RmsNorm(2 * self.head_dim, eps=1e-5, **dd)
+        self.dual_lambda = dual_lambda
+        if dual_lambda:
+            self.lambda_a = nn.Parameter(torch.empty((), dtype=torch.float32, device=device))
+            self.lambda_b = nn.Parameter(torch.empty((), dtype=torch.float32, device=device))
+            self.lambda_q1 = self.lambda_k1 = self.lambda_q2 = self.lambda_k2 = None
+        else:
+            self.lambda_a = self.lambda_b = None
+            self.lambda_q1 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
+            self.lambda_k1 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
+            self.lambda_q2 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
+            self.lambda_k2 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
+
+        self.mlp_drop = nn.Dropout(proj_drop)
+        self.mlp_act = act_layer()
+
+        # Fused output projection for both attention and MLP
+        self.out_proj = nn.Linear(dim + mlp_hidden_dim, dim, bias=proj_bias, **dd)
+
+        self.ls = LayerScale(dim, init_values=init_values, **dd) if init_values is not None else nn.Identity()
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+        self.lambda_init = 0.8
+        self.set_lambda_init(depth)
+        self.reset_parameters()
+
+    def set_lambda_init(self, depth: int):
+        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)
+
+    def reset_parameters(self):
+        if self.dual_lambda:
+            nn.init.zeros_(self.lambda_a)
+            nn.init.zeros_(self.lambda_b)
+        else:
+            nn.init.normal_(self.lambda_q1, mean=0, std=0.1)
+            nn.init.normal_(self.lambda_k1, mean=0, std=0.1)
+            nn.init.normal_(self.lambda_q2, mean=0, std=0.1)
+            nn.init.normal_(self.lambda_k2, mean=0, std=0.1)
+
+    def _compute_lambda(self) -> torch.Tensor:
+        if self.lambda_a is not None:
+            lambda_1 = torch.exp(self.lambda_a)
+            lambda_2 = torch.exp(self.lambda_b)
+        else:
+            lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float())
+            lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float())
+        return lambda_1 - lambda_2 + self.lambda_init
+
+    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        B, N, C = x.shape
+
+        # Combined MLP fc1 & qkv projections
+        y = self.in_norm(x)
+        if self.mlp_bias is not None:
+            y = F.linear(y, self.in_proj.weight, torch.cat((self.qkv_bias, self.mlp_bias)))
+        else:
+            y = self.in_proj(y)
+        x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1)
+
+        # Reshape for differential attention (2x heads with half head_dim for q/k)
+        q = q.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
+        k = k.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
+        v = v.reshape(B, N, self.num_heads, 2 * self.head_dim).transpose(1, 2)
+
+        q, k = self.q_norm(q), self.k_norm(k)
+
+        lambda_full = self._compute_lambda().type_as(q)
+
+        if self.fused_attn:
+            q = q.reshape(B, self.num_heads, 2, N, self.head_dim)
+            k = k.reshape(B, self.num_heads, 2, N, self.head_dim)
+            q1, q2 = q.unbind(2)
+            k1, k2 = k.unbind(2)
+
+            dropout_p = self.attn_drop_p if self.training else 0.0
+            attn1 = F.scaled_dot_product_attention(q1, k1, v, attn_mask=attn_mask, dropout_p=dropout_p)
+            attn2 = F.scaled_dot_product_attention(q2, k2, v, attn_mask=attn_mask, dropout_p=dropout_p)
+
+            x_attn = attn1 - lambda_full * attn2
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = maybe_add_mask(attn, attn_mask)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+
+            attn = attn.view(B, self.num_heads, 2, N, N)
+            attn = attn[:, :, 0] - lambda_full * attn[:, :, 1]
+            x_attn = attn @ v
+
+        x_attn = self.sub_norm(x_attn)
+        x_attn = x_attn * (1 - self.lambda_init)
+        x_attn = x_attn.transpose(1, 2).reshape(B, N, C)
+
+        # MLP activation & dropout
+        x_mlp = self.mlp_act(x_mlp)
+        x_mlp = self.mlp_drop(x_mlp)
+
+        # Fused output projection
+        y = self.out_proj(torch.cat((x_attn, x_mlp), dim=-1))
+
+        # Add residual w/ drop path & layer scale applied
+        x = x + self.drop_path(self.ls(y))
         return x


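For orientation, the non-fused branch of DiffParallelScalingBlock.forward above boils down to forming two softmax attention maps from the halved-dimension query/key pairs and subtracting the second, scaled by lambda, before applying the result to the shared values. A minimal standalone sketch of that core computation (illustrative shapes and names, no dropout, masking, sub-norm, or (1 - lambda_init) scaling):

import torch

def diff_attn_sketch(q, k, v, lambda_full):
    # q, k: (B, 2 * num_heads, N, head_dim); v: (B, num_heads, N, 2 * head_dim)
    B, H2, N, head_dim = q.shape
    num_heads = H2 // 2
    attn = (q * head_dim ** -0.5) @ k.transpose(-2, -1)   # (B, 2 * num_heads, N, N)
    attn = attn.softmax(dim=-1).view(B, num_heads, 2, N, N)
    attn = attn[:, :, 0] - lambda_full * attn[:, :, 1]    # differential attention map
    return attn @ v                                       # (B, num_heads, N, 2 * head_dim)

q = k = torch.randn(1, 8, 10, 16)   # num_heads=4 -> 8 half-dim heads of size 16
v = torch.randn(1, 4, 10, 32)
out = diff_attn_sketch(q, k, v, lambda_full=0.8)
print(out.shape)  # torch.Size([1, 4, 10, 32])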
@@ -2528,6 +2706,9 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
     'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_dpwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        #hf_hub_id='timm/',
+        input_size=(3, 256, 256), crop_pct=0.95),
     'vit_little_patch16_reg1_gap_256.sbb_in12k_ft_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
@@ -4038,6 +4219,16 @@ def vit_pwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionT
     return model


+@register_model
+def vit_dpwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=256, depth=16, num_heads=4, init_values=1e-5, mlp_ratio=5,
+        class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', block_fn=DiffParallelScalingBlock,
+    )
+    model = _create_vision_transformer(
+        'vit_dpwee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
 @register_model
 def vit_little_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
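Once this patch is applied, the newly registered name should be reachable through the usual timm factory; a quick smoke-test sketch (pretrained weights are not published yet, hence pretrained=False, and the default config above expects 256x256 inputs):

import timm
import torch

model = timm.create_model('vit_dpwee_patch16_reg1_gap_256', pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 1000]) with the default classifier head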