@@ -38,7 +38,7 @@ def FeedForward(dim, hidden_dim):
     )

 class Attention(Module):
-    def __init__(self, dim, heads = 8, dim_head = 64):
+    def __init__(self, dim, heads = 8, dim_head = 64, learned_value_residual_mix = False):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
@@ -50,14 +50,21 @@ def __init__(self, dim, heads = 8, dim_head = 64):
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)

+        self.to_residual_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            nn.Sigmoid(),
+            Rearrange('b n h -> b h n 1')
+        ) if learned_value_residual_mix else (lambda _: 0.5)
+
     def forward(self, x, value_residual = None):
         x = self.norm(x)

         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

         if exists(value_residual):
-            v = 0.5 * (v + value_residual)
+            mix = self.to_residual_mix(x)
+            v = v * mix + value_residual * (1. - mix)

         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

@@ -73,9 +80,10 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim):
         super().__init__()
         self.norm = nn.LayerNorm(dim)
         self.layers = ModuleList([])
-        for _ in range(depth):
+        for i in range(depth):
+            is_first = i == 0
             self.layers.append(ModuleList([
-                Attention(dim, heads = heads, dim_head = dim_head),
+                Attention(dim, heads = heads, dim_head = dim_head, learned_value_residual_mix = not is_first),
                 FeedForward(dim, mlp_dim)
             ]))
     def forward(self, x):
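
For reference, a minimal standalone sketch of the mixing step introduced above (tensor shapes and variable names are illustrative only): the gate projects each token to one sigmoid value per head, so the blend between the current layer's values and the residual values from the first layer is learned per token and per head, rather than being the fixed 0.5 average it replaces.

# standalone sketch of the learned value-residual mix; assumes torch and einops are installed
import torch
from torch import nn
from einops.layers.torch import Rearrange

batch, seq, dim, heads, dim_head = 2, 16, 64, 8, 32

# per-head gate in [0, 1], shaped to broadcast over (batch, heads, seq, dim_head)
to_residual_mix = nn.Sequential(
    nn.Linear(dim, heads),
    nn.Sigmoid(),
    Rearrange('b n h -> b h n 1')
)

x = torch.randn(batch, seq, dim)                           # (normed) input to the attention block
v = torch.randn(batch, heads, seq, dim_head)               # this layer's values
value_residual = torch.randn(batch, heads, seq, dim_head)  # values cached from the first layer

mix = to_residual_mix(x)                                   # (batch, heads, seq, 1)
v = v * mix + value_residual * (1. - mix)                  # learned convex combination of the two value sets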