@@ -20,6 +20,7 @@ def _stable_div(numerator: np.ndarray, denominator: np.ndarray) -> np.ndarray:
 # 🔹 Time2Vec
 # -------------------------------
 
+
 class Time2Vec:
     def __init__(self, d_model: int, seed: Optional[int] = None) -> None:
         if d_model < 2:
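
Only the constructor guard is visible in this hunk. For reference, the usual Time2Vec layout reserves dimension 0 for a linear term and fills the remaining `d_model - 1` dimensions with learned sinusoids; a minimal standalone sketch under that assumption (`w` and `b` here are illustrative stand-ins, not necessarily this class's attribute names):

```python
import numpy as np

def time2vec_sketch(t: np.ndarray, w: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Assumed convention: t is (batch, seq_len); output is (batch, seq_len, d_model)."""
    args = t[..., None] * w + b  # affine argument per output dimension
    # Dimension 0 stays linear; the rest pass through sin().
    return np.concatenate([args[..., :1], np.sin(args[..., 1:])], axis=-1)

w, b = np.ones(4), np.zeros(4)  # stand-in frequencies/phases, d_model = 4
print(time2vec_sketch(np.arange(3.0)[None, :], w, b).shape)  # (1, 3, 4)
```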
@@ -63,12 +64,23 @@ def forward(self, time_indices: np.ndarray) -> np.ndarray:
 # 🔹 Positionwise FeedForward
 # -------------------------------
 
+
 class PositionwiseFeedForward:
-    def __init__(self, d_model: int, hidden_dim: int, drop_prob: float = 0.0, seed: Optional[int] = None) -> None:
+    def __init__(
+        self,
+        d_model: int,
+        hidden_dim: int,
+        drop_prob: float = 0.0,
+        seed: Optional[int] = None,
+    ) -> None:
         self.rng = np.random.default_rng(seed)
-        self.w1: np.ndarray = self.rng.standard_normal((d_model, hidden_dim)) * math.sqrt(2.0 / (d_model + hidden_dim))
+        self.w1: np.ndarray = self.rng.standard_normal(
+            (d_model, hidden_dim)
+        ) * math.sqrt(2.0 / (d_model + hidden_dim))
         self.b1: np.ndarray = np.zeros((hidden_dim,))
-        self.w2: np.ndarray = self.rng.standard_normal((hidden_dim, d_model)) * math.sqrt(2.0 / (hidden_dim + d_model))
+        self.w2: np.ndarray = self.rng.standard_normal(
+            (hidden_dim, d_model)
+        ) * math.sqrt(2.0 / (hidden_dim + d_model))
         self.b2: np.ndarray = np.zeros((d_model,))
 
     def forward(self, input_tensor: np.ndarray) -> np.ndarray:
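
The `math.sqrt(2.0 / (fan_in + fan_out))` factor is Glorot/Xavier scaling. A quick standalone check that the empirical standard deviation of a weight drawn this way lands near the Glorot target:

```python
import math
import numpy as np

d_model, hidden_dim = 64, 256
rng = np.random.default_rng(0)
w1 = rng.standard_normal((d_model, hidden_dim)) * math.sqrt(2.0 / (d_model + hidden_dim))
# Empirical std vs. the Glorot target sqrt(2 / (fan_in + fan_out)).
print(round(float(w1.std()), 4), round(math.sqrt(2.0 / (d_model + hidden_dim)), 4))
```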
@@ -82,6 +94,7 @@ def forward(self, input_tensor: np.ndarray) -> np.ndarray:
 # 🔹 Scaled Dot-Product Attention
 # -------------------------------
 
+
 class ScaledDotProductAttention:
     def forward(
         self,
@@ -97,7 +110,11 @@ def forward(
         if mask.ndim == 2:
             mask_reshaped = mask[:, None, None, :]
         elif mask.ndim == 3:
-            mask_reshaped = mask[:, None, :, :] if mask.shape[1] != seq_len else mask[:, None, None, :]
+            mask_reshaped = (
+                mask[:, None, :, :]
+                if mask.shape[1] != seq_len
+                else mask[:, None, None, :]
+            )
         else:
             mask_reshaped = mask
         scores = np.where(mask_reshaped == 0, -1e9, scores)
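
The multi-line conditional reads better, but the branch logic deserves a second look: a 3-D mask whose second axis equals `seq_len` (i.e. a `(batch, seq_len, seq_len)` attention mask) falls into the `else` arm and gains two new axes, producing a 5-D array that cannot broadcast against the 4-D `scores`. If that case is meant to be supported, `mask[:, None, :, :]` may be the intended expression there too. A shape check illustrating the issue:

```python
import numpy as np

batch, seq_len = 2, 5
per_query = np.ones((batch, 1, seq_len))   # shape[1] != seq_len
full = np.ones((batch, seq_len, seq_len))  # shape[1] == seq_len

print(per_query[:, None, :, :].shape)  # (2, 1, 1, 5) -- broadcasts over heads
print(full[:, None, None, :].shape)    # (2, 1, 1, 5, 5) -- 5-D, incompatible with 4-D scores
```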
@@ -111,6 +128,7 @@ def forward(
 # 🔹 Multi-Head Attention
 # -------------------------------
 
+
 class MultiHeadAttention:
     def __init__(self, d_model: int, n_head: int, seed: Optional[int] = None) -> None:
         if d_model % n_head != 0:
@@ -121,27 +139,41 @@ def __init__(self, d_model: int, n_head: int, seed: Optional[int] = None) -> Non
         self.n_head = n_head
         self.d_k = d_model // n_head
 
-        self.w_q = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / (d_model + d_model))
+        self.w_q = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+            2.0 / (d_model + d_model)
+        )
         self.b_q = np.zeros((d_model,))
-        self.w_k = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / (d_model + d_model))
+        self.w_k = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+            2.0 / (d_model + d_model)
+        )
         self.b_k = np.zeros((d_model,))
-        self.w_v = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / (d_model + d_model))
+        self.w_v = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+            2.0 / (d_model + d_model)
+        )
         self.b_v = np.zeros((d_model,))
-        self.w_out = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / (d_model + d_model))
+        self.w_out = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+            2.0 / (d_model + d_model)
+        )
         self.b_out = np.zeros((d_model,))
 
         self.attn = ScaledDotProductAttention()
 
-    def _linear(self, input_tensor: np.ndarray, weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
+    def _linear(
+        self, input_tensor: np.ndarray, weight: np.ndarray, bias: np.ndarray
+    ) -> np.ndarray:
         return np.tensordot(input_tensor, weight, axes=([2], [0])) + bias
 
     def _split_heads(self, input_tensor: np.ndarray) -> np.ndarray:
         batch_size, seq_len, _ = input_tensor.shape
-        return input_tensor.reshape(batch_size, seq_len, self.n_head, self.d_k).transpose(0, 2, 1, 3)
+        return input_tensor.reshape(
+            batch_size, seq_len, self.n_head, self.d_k
+        ).transpose(0, 2, 1, 3)
 
     def _concat_heads(self, input_tensor: np.ndarray) -> np.ndarray:
         batch_size, n_head, seq_len, d_k = input_tensor.shape
-        return input_tensor.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, n_head * d_k)
+        return input_tensor.transpose(0, 2, 1, 3).reshape(
+            batch_size, seq_len, n_head * d_k
+        )
 
     def forward(
         self,
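
`_split_heads` and `_concat_heads` are exact inverses, since `transpose(0, 2, 1, 3)` is its own inverse; a standalone round-trip check:

```python
import numpy as np

batch, seq_len, n_head, d_k = 2, 7, 4, 8
x = np.random.default_rng(0).standard_normal((batch, seq_len, n_head * d_k))

# (batch, seq_len, d_model) -> (batch, n_head, seq_len, d_k) and back.
split = x.reshape(batch, seq_len, n_head, d_k).transpose(0, 2, 1, 3)
merged = split.transpose(0, 2, 1, 3).reshape(batch, seq_len, n_head * d_k)

assert np.array_equal(x, merged)  # round-trip is lossless
```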
@@ -174,6 +206,7 @@ def forward(
 # 🔹 LayerNorm
 # -------------------------------
 
+
 class LayerNorm:
     def __init__(self, d_model: int, eps: float = 1e-12) -> None:
         self.gamma: np.ndarray = np.ones((d_model,))
@@ -185,18 +218,25 @@ def forward(self, input_tensor: np.ndarray) -> np.ndarray:
         var = np.mean((input_tensor - mean) ** 2, axis=-1, keepdims=True)
         normalized_tensor = (input_tensor - mean) / np.sqrt(var + self.eps)
         return self.gamma * normalized_tensor + self.beta
+
+
 # -------------------------------
 # 🔹 Transformer Encoder Layer
 # -------------------------------
 
+
 class TransformerEncoderLayer:
-    def __init__(self, d_model: int, n_head: int, hidden_dim: int, seed: Optional[int] = None) -> None:
+    def __init__(
+        self, d_model: int, n_head: int, hidden_dim: int, seed: Optional[int] = None
+    ) -> None:
         self.self_attn = MultiHeadAttention(d_model, n_head, seed=seed)
         self.ffn = PositionwiseFeedForward(d_model, hidden_dim, seed=seed)
         self.norm1 = LayerNorm(d_model)
         self.norm2 = LayerNorm(d_model)
 
-    def forward(self, encoded_input: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarray:
+    def forward(
+        self, encoded_input: np.ndarray, mask: Optional[np.ndarray] = None
+    ) -> np.ndarray:
         """
         Forward pass for one encoder layer.
 
@@ -220,7 +260,9 @@ def forward(self, encoded_input: np.ndarray, mask: Optional[np.ndarray] = None)
         >>> out.shape
         (1, 3, 4)
         """
-        attn_output, _ = self.self_attn.forward(encoded_input, encoded_input, encoded_input, mask)
+        attn_output, _ = self.self_attn.forward(
+            encoded_input, encoded_input, encoded_input, mask
+        )
         out1 = self.norm1.forward(encoded_input + attn_output)
         ffn_output = self.ffn.forward(out1)
         out2 = self.norm2.forward(out1 + ffn_output)
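
This is the post-LN residual pattern, `norm(x + sublayer(x))`. A standalone check that the LayerNorm arithmetic used above yields zero mean and unit variance per feature vector when `gamma = 1`, `beta = 0`:

```python
import numpy as np

x = np.random.default_rng(0).standard_normal((2, 3, 4))
mean = x.mean(axis=-1, keepdims=True)
var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
normed = (x - mean) / np.sqrt(var + 1e-12)

print(np.allclose(normed.mean(axis=-1), 0.0))  # True
print(np.allclose(normed.var(axis=-1), 1.0))   # True (up to eps)
```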
@@ -231,11 +273,24 @@ def forward(self, encoded_input: np.ndarray, mask: Optional[np.ndarray] = None)
 # 🔹 Transformer Encoder Stack
 # -------------------------------
 
+
 class TransformerEncoder:
-    def __init__(self, d_model: int, n_head: int, hidden_dim: int, num_layers: int, seed: Optional[int] = None) -> None:
-        self.layers = [TransformerEncoderLayer(d_model, n_head, hidden_dim, seed=seed) for _ in range(num_layers)]
+    def __init__(
+        self,
+        d_model: int,
+        n_head: int,
+        hidden_dim: int,
+        num_layers: int,
+        seed: Optional[int] = None,
+    ) -> None:
+        self.layers = [
+            TransformerEncoderLayer(d_model, n_head, hidden_dim, seed=seed)
+            for _ in range(num_layers)
+        ]
 
-    def forward(self, encoded_input: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarray:
+    def forward(
+        self, encoded_input: np.ndarray, mask: Optional[np.ndarray] = None
+    ) -> np.ndarray:
         """
         Forward pass for encoder stack.
 
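
One behavior worth flagging here: every layer in the stack is built with the same `seed`, so a fixed seed makes all `num_layers` layers start from identical weights. That follows directly from how `np.random.default_rng` works:

```python
import numpy as np

# Two generators seeded identically produce identical draws,
# so two layers initialized from the same seed share initial weights.
a = np.random.default_rng(42).standard_normal((4, 4))
b = np.random.default_rng(42).standard_normal((4, 4))
print(np.array_equal(a, b))  # True

# Deriving per-layer seeds (e.g. seed + layer_index) would decorrelate them.
```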
@@ -269,13 +324,18 @@ def forward(self, encoded_input: np.ndarray, mask: Optional[np.ndarray] = None)
 # 🔹 Attention Pooling
 # -------------------------------
 
+
 class AttentionPooling:
     def __init__(self, d_model: int, seed: Optional[int] = None) -> None:
         self.rng = np.random.default_rng(seed)
-        self.w: np.ndarray = self.rng.standard_normal(d_model) * math.sqrt(2.0 / d_model)
+        self.w: np.ndarray = self.rng.standard_normal(d_model) * math.sqrt(
+            2.0 / d_model
+        )
         self.b: float = 0.0
 
-    def forward(self, encoded_features: np.ndarray, mask: Optional[np.ndarray] = None) -> tuple[np.ndarray, np.ndarray]:
+    def forward(
+        self, encoded_features: np.ndarray, mask: Optional[np.ndarray] = None
+    ) -> tuple[np.ndarray, np.ndarray]:
         """
         Attention-based pooling.
 
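
The pooling forward body sits outside this diff, but the parameters (a `(d_model,)` vector `w` plus a scalar `b`) suggest the standard recipe: score each timestep, softmax over time, and take the weighted sum. A sketch under that assumption, not necessarily the exact implementation:

```python
import numpy as np

def attention_pool_sketch(features: np.ndarray, w: np.ndarray, b: float):
    """Assumed pooling: (batch, seq_len, d_model) -> ((batch, d_model), (batch, seq_len))."""
    scores = features @ w + b                     # one score per timestep
    scores -= scores.max(axis=-1, keepdims=True)  # stabilized softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return (features * weights[..., None]).sum(axis=1), weights

feats = np.random.default_rng(0).standard_normal((2, 5, 8))
pooled, attn = attention_pool_sketch(feats, np.full(8, 0.1), 0.0)
print(pooled.shape, attn.shape)  # (2, 8) (2, 5)
```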
@@ -315,6 +375,7 @@ def forward(self, encoded_features: np.ndarray, mask: Optional[np.ndarray] = None) -> tuple[np.ndarray, np.ndarray]:
 # 🔹 EEG Transformer
 # -------------------------------
 
+
 class EEGTransformer:
     def __init__(
         self,
@@ -332,20 +393,28 @@ def __init__(
         self.d_model = d_model
         self.task_type = task_type
 
-        self.w_in: np.ndarray = self.rng.standard_normal((feature_dim, d_model)) * math.sqrt(2.0 / (feature_dim + d_model))
+        self.w_in: np.ndarray = self.rng.standard_normal(
+            (feature_dim, d_model)
+        ) * math.sqrt(2.0 / (feature_dim + d_model))
         self.b_in: np.ndarray = np.zeros((d_model,))
 
         self.time2vec = Time2Vec(d_model, seed=seed)
-        self.encoder = TransformerEncoder(d_model, n_head, hidden_dim, num_layers, seed=seed)
+        self.encoder = TransformerEncoder(
+            d_model, n_head, hidden_dim, num_layers, seed=seed
+        )
         self.pooling = AttentionPooling(d_model, seed=seed)
 
-        self.w_out: np.ndarray = self.rng.standard_normal((d_model, output_dim)) * math.sqrt(2.0 / (d_model + output_dim))
+        self.w_out: np.ndarray = self.rng.standard_normal(
+            (d_model, output_dim)
+        ) * math.sqrt(2.0 / (d_model + output_dim))
         self.b_out: np.ndarray = np.zeros((output_dim,))
 
     def _input_projection(self, input_tensor: np.ndarray) -> np.ndarray:
         return np.tensordot(input_tensor, self.w_in, axes=([2], [0])) + self.b_in
 
-    def forward(self, input_tensor: np.ndarray, mask: Optional[np.ndarray] = None) -> tuple[np.ndarray, np.ndarray]:
+    def forward(
+        self, input_tensor: np.ndarray, mask: Optional[np.ndarray] = None
+    ) -> tuple[np.ndarray, np.ndarray]:
         """
         Forward pass for EEG Transformer.
 
@@ -383,7 +452,9 @@ def forward(self, input_tensor: np.ndarray, mask: Optional[np.ndarray] = None) -> tuple[np.ndarray, np.ndarray]:
         encoded_features = self.encoder.forward(projected_input, mask)
         pooled_output, attention_weights = self.pooling.forward(encoded_features, mask)
 
-        output_tensor = np.tensordot(pooled_output, self.w_out, axes=([1], [0])) + self.b_out
+        output_tensor = (
+            np.tensordot(pooled_output, self.w_out, axes=([1], [0])) + self.b_out
+        )
         if self.task_type == "classification":
             output_tensor = _softmax(output_tensor, axis=-1)
 
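
For a quick smoke test of the whole model: assuming this file is importable (the module name `eeg_transformer` below is a placeholder) and using the keyword names visible in this diff, a classification-mode forward pass should return rows that sum to 1 after the softmax:

```python
import numpy as np
from eeg_transformer import EEGTransformer  # placeholder module name

model = EEGTransformer(
    feature_dim=16,  # input features per timestep
    d_model=32,
    n_head=4,
    hidden_dim=64,
    num_layers=2,
    output_dim=3,
    task_type="classification",
    seed=0,
)
x = np.random.default_rng(0).standard_normal((2, 10, 16))  # (batch, seq_len, feature_dim)
output, attn_weights = model.forward(x)
print(output.shape)                           # expected: (2, 3)
print(np.allclose(output.sum(axis=-1), 1.0))  # softmax rows sum to 1
```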