22# -------------------------------
33from __future__ import annotations
44import math
5- from typing import Optional , Tuple
6-
75import numpy as np
86
97
@@ -110,26 +108,18 @@ def __init__(self, d_model: int, n_head: int, seed: int | None = None) -> None:
110108 self .n_head = n_head
111109 self .d_k = d_model // n_head
112110 self .rng = np .random .default_rng (seed )
113- self .w_q = self .rng .standard_normal ((d_model , d_model )) * math .sqrt (
114- 2.0 / d_model
115- )
116- self .w_k = self .rng .standard_normal ((d_model , d_model )) * math .sqrt (
117- 2.0 / d_model
118- )
119- self .w_v = self .rng .standard_normal ((d_model , d_model )) * math .sqrt (
120- 2.0 / d_model
121- )
122- self .w_o = self .rng .standard_normal ((d_model , d_model )) * math .sqrt (
123- 2.0 / d_model
124- )
111+ self .w_q = self .rng .standard_normal ((d_model , d_model )) * math .sqrt (2.0 / d_model )
112+ self .w_k = self .rng .standard_normal ((d_model , d_model )) * math .sqrt (2.0 / d_model )
113+ self .w_v = self .rng .standard_normal ((d_model , d_model )) * math .sqrt (2.0 / d_model )
114+ self .w_o = self .rng .standard_normal ((d_model , d_model )) * math .sqrt (2.0 / d_model )
125115
126116 def forward (
127117 self ,
128118 query : np .ndarray ,
129119 key : np .ndarray ,
130120 value : np .ndarray ,
131121 mask : np .ndarray | None = None ,
132- ) -> Tuple [np .ndarray , np .ndarray ]:
122+ ) -> tuple [np .ndarray , np .ndarray ]:
133123 """
134124 >>> attn = MultiHeadAttention(4, 2, seed=0)
135125 >>> x = np.ones((1, 3, 4))
@@ -140,17 +130,20 @@ def forward(
140130 (1, 2, 3, 3)
141131 """
142132 batch_size , _seq_len , _ = query .shape
143- Q = np .tensordot (query , self .w_q , axes = ([2 ], [0 ]))
144- K = np .tensordot (key , self .w_k , axes = ([2 ], [0 ]))
145- V = np .tensordot (value , self .w_v , axes = ([2 ], [0 ]))
146- Q = Q .reshape (batch_size , - 1 , self .n_head , self .d_k ).transpose (0 , 2 , 1 , 3 )
147- K = K .reshape (batch_size , - 1 , self .n_head , self .d_k ).transpose (0 , 2 , 1 , 3 )
148- V = V .reshape (batch_size , - 1 , self .n_head , self .d_k ).transpose (0 , 2 , 1 , 3 )
149- scores = np .matmul (Q , K .transpose (0 , 1 , 3 , 2 )) / math .sqrt (self .d_k )
133+ q = np .tensordot (query , self .w_q , axes = ([2 ], [0 ]))
134+ k = np .tensordot (key , self .w_k , axes = ([2 ], [0 ]))
135+ v = np .tensordot (value , self .w_v , axes = ([2 ], [0 ]))
136+
137+ q = q .reshape (batch_size , - 1 , self .n_head , self .d_k ).transpose (0 , 2 , 1 , 3 )
138+ k = k .reshape (batch_size , - 1 , self .n_head , self .d_k ).transpose (0 , 2 , 1 , 3 )
139+ v = v .reshape (batch_size , - 1 , self .n_head , self .d_k ).transpose (0 , 2 , 1 , 3 )
140+
141+ scores = np .matmul (q , k .transpose (0 , 1 , 3 , 2 )) / math .sqrt (self .d_k )
150142 if mask is not None :
151143 scores = np .where (mask [:, None , None , :] == 0 , - 1e9 , scores )
144+
152145 attn_weights = _softmax (scores , axis = - 1 )
153- out = np .matmul (attn_weights , V )
146+ out = np .matmul (attn_weights , v )
154147 out = out .transpose (0 , 2 , 1 , 3 ).reshape (batch_size , - 1 , self .d_model )
155148 out = np .tensordot (out , self .w_o , axes = ([2 ], [0 ]))
156149 return out , attn_weights
0 commit comments