from __future__ import annotations
import math
from typing import Optional, Tuple

import numpy as np
import pandas as pd


def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # subtract the per-axis max before exponentiating for numerical stability
    x_max = np.max(x, axis=axis, keepdims=True)
    e = np.exp(x - x_max)
@@ -18,6 +18,7 @@ def _stable_div(x: np.ndarray, denom: np.ndarray) -> np.ndarray:

# Time2Vec

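# Background note: in the usual Time2Vec formulation (Kazemi et al., 2019) a
# scalar time step t is embedded with one linear and k periodic components,
#     t2v(t)[0] = w_0 * t + b_0
#     t2v(t)[i] = sin(w_i * t + b_i)   for 1 <= i < k,
# where the w_i and b_i are learned. The class below is described as a
# simplified variant; its full body is not shown in this excerpt, so treat the
# formula as a reference point rather than its exact implementation.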
class Time2Vec:
    """
    Time2Vec positional encoding (simple) for real-valued time steps.
@@ -51,8 +52,15 @@ def forward(self, time_steps: np.ndarray) -> np.ndarray:

# PositionwiseFeedForward


class PositionwiseFeedForward:
    def __init__(
        self,
        d_model: int,
        hidden: int,
        drop_prob: float = 0.0,
        seed: Optional[int] = None,
    ):
        if seed is not None:
            np.random.seed(seed)
        # simple linear layers; drop_prob is kept for interface compatibility,
        # but no dropout is applied in this forward-only implementation
@@ -70,11 +78,17 @@ def forward(self, x: np.ndarray) -> np.ndarray:
        return out


# Scaled Dot-Product Attention


class ScaledDotProductAttention:
    def forward(
        self,
        q: np.ndarray,
        k: np.ndarray,
        v: np.ndarray,
        mask: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        q, k, v: shapes (b, n_head, seq_len, d_k)
        mask: optional boolean or 0/1 mask of shape (b, seq_len) or (b, 1, 1, seq_len)
@@ -90,7 +104,11 @@ def forward(self, q: np.ndarray, k: np.ndarray, v: np.ndarray, mask: Optional[np
                mask2 = mask[:, None, None, :]  # (b, 1, 1, seq_len)
            elif mask.ndim == 3:
                # provided as (b, n_head, seq_len) / (b, 1, seq_len), or as a
                # full (b, seq_len, seq_len) mask; insert the missing axis so it
                # broadcasts against scores of shape (b, n_head, seq_len, seq_len)
                mask2 = (
                    mask[:, :, None, :]
                    if mask.shape[1] != seq_len
                    else mask[:, None, :, :]
                )
            else:
                mask2 = mask
            # mask2 == 0 => masked position
@@ -103,6 +121,7 @@ def forward(self, q: np.ndarray, k: np.ndarray, v: np.ndarray, mask: Optional[np

# MultiHeadAttention


class MultiHeadAttention:
    def __init__(self, d_model: int, n_head: int, seed: Optional[int] = None):
        if d_model % n_head != 0:
@@ -114,13 +133,21 @@ def __init__(self, d_model: int, n_head: int, seed: Optional[int] = None):
        self.d_k = d_model // n_head

        # weight matrices for q, k, v and output (Xavier/Glorot-style scaling)
        self.w_q = np.random.randn(d_model, d_model) * math.sqrt(
            2.0 / (d_model + d_model)
        )
        self.b_q = np.zeros((d_model,))
        self.w_k = np.random.randn(d_model, d_model) * math.sqrt(
            2.0 / (d_model + d_model)
        )
        self.b_k = np.zeros((d_model,))
        self.w_v = np.random.randn(d_model, d_model) * math.sqrt(
            2.0 / (d_model + d_model)
        )
        self.b_v = np.zeros((d_model,))
        self.w_out = np.random.randn(d_model, d_model) * math.sqrt(
            2.0 / (d_model + d_model)
        )
        self.b_out = np.zeros((d_model,))

        self.attn = ScaledDotProductAttention()
@@ -139,7 +166,13 @@ def _concat_heads(self, x: np.ndarray) -> np.ndarray:
        b, n_head, seq_len, d_k = x.shape
        return x.transpose(0, 2, 1, 3).reshape(b, seq_len, n_head * d_k)

    def forward(
        self,
        query: np.ndarray,
        key: np.ndarray,
        value: np.ndarray,
        mask: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        query/key/value: (b, seq_len, d_model)
        returns: out (b, seq_len, d_model), attn_weights (b, n_head, seq_len, seq_len)
@@ -157,9 +190,9 @@ def forward(self, query: np.ndarray, key: np.ndarray, value: np.ndarray, mask: O
        return out, attn
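
# Illustrative shape walk-through for MultiHeadAttention.forward (added note,
# using hypothetical sizes): with b=2, seq_len=10, d_model=32 and n_head=4
# (so d_k=8), the projected q/k/v are (2, 10, 32), splitting into heads gives
# (2, 4, 10, 8), scaled dot-product attention returns values of shape
# (2, 4, 10, 8) and weights of (2, 4, 10, 10), and concatenating heads plus
# the output projection restores (2, 10, 32).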


# LayerNorm


class LayerNorm:
    def __init__(self, d_model: int, eps: float = 1e-12):
        self.gamma = np.ones((d_model,))
@@ -173,10 +206,14 @@ def forward(self, x: np.ndarray) -> np.ndarray:
        x_norm = (x - mean) / np.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta


# TransformerEncoderLayer


class TransformerEncoderLayer:
    def __init__(
        self, d_model: int, n_head: int, hidden_dim: int, seed: Optional[int] = None
    ):
        self.self_attn = MultiHeadAttention(d_model, n_head, seed=seed)
        self.ffn = PositionwiseFeedForward(d_model, hidden_dim, seed=seed)
        self.norm1 = LayerNorm(d_model)
@@ -193,26 +230,41 @@ def forward(self, x: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarra

# TransformerEncoder (stack)


class TransformerEncoder:
    def __init__(
        self,
        d_model: int,
        n_head: int,
        hidden_dim: int,
        num_layers: int,
        seed: Optional[int] = None,
    ):
        self.layers = [
            TransformerEncoderLayer(d_model, n_head, hidden_dim, seed=seed)
            for _ in range(num_layers)
        ]

    def forward(self, x: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarray:
        out = x
        for layer in self.layers:
            out = layer.forward(out, mask)
        return out


# AttentionPooling


class AttentionPooling:
    def __init__(self, d_model: int, seed: Optional[int] = None):
        if seed is not None:
            np.random.seed(seed)
        self.w = np.random.randn(d_model) * math.sqrt(2.0 / d_model)
        self.b = 0.0

    def forward(
        self, x: np.ndarray, mask: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        x: (b, seq_len, d_model)
        mask: (b, seq_len) where 1 = valid, 0 = pad
@@ -228,8 +280,10 @@ def forward(self, x: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.
        pooled = np.matmul(weights[:, None, :], x).squeeze(1)  # (b, d_model)
        return pooled, weights
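
# Note on the pooling above: the score computation falls in lines omitted from
# this excerpt, so the following is a sketch of the usual recipe rather than a
# quote of the exact code. Attention pooling typically scores each time step as
#     scores = x @ w + b                      # (b, seq_len)
# masks out padded positions, normalizes with a softmax to obtain `weights`,
# and then forms the weighted average shown above,
#     pooled = sum_t weights[:, t] * x[:, t]  # (b, d_model)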


# EEGTransformer (forward-only)

class EEGTransformer:
    def __init__(
        self,
@@ -248,21 +302,29 @@ def __init__(
        self.d_model = d_model
        self.task_type = task_type
        # input projection
        self.w_in = np.random.randn(feature_dim, d_model) * math.sqrt(
            2.0 / (feature_dim + d_model)
        )
        self.b_in = np.zeros((d_model,))
        # time embedding
        self.time2vec = Time2Vec(d_model, seed=seed)
        self.encoder = TransformerEncoder(
            d_model, n_head, hidden_dim, num_layers, seed=seed
        )
        self.pooling = AttentionPooling(d_model, seed=seed)
        # output
        self.w_out = np.random.randn(d_model, output_dim) * math.sqrt(
            2.0 / (d_model + output_dim)
        )
        self.b_out = np.zeros((output_dim,))

    def _input_proj(self, x: np.ndarray) -> np.ndarray:
        # x: (b, seq_len, feature_dim) -> (b, seq_len, d_model)
        return np.tensordot(x, self.w_in, axes=([2], [0])) + self.b_in

    def forward(
        self, x: np.ndarray, mask: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        x: (b, seq_len, feature_dim)
        mask: optional (b, seq_len), 1 = valid, 0 = pad
@@ -276,7 +338,9 @@ def forward(self, x: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.
        x_proj = self._input_proj(x) + time_emb  # broadcast add -> (b, t, d_model)
        enc = self.encoder.forward(x_proj, mask)
        pooled, attn_weights = self.pooling.forward(enc, mask)
        out = (
            np.tensordot(pooled, self.w_out, axes=([1], [0])) + self.b_out
        )  # (b, output_dim)
        if self.task_type == "classification":
            out = _softmax(out, axis=-1)
        return out, attn_weights
@@ -292,7 +356,15 @@ def forward(self, x: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.
    rng = np.random.RandomState(42)
    X = rng.randn(batch, seq_len, feature_dim).astype(float)

    model = EEGTransformer(
        feature_dim=feature_dim,
        d_model=32,
        n_head=4,
        hidden_dim=64,
        num_layers=2,
        output_dim=1,
        seed=0,
    )
    out, attn_weights = model.forward(X)
    print("Output shape:", out.shape)
    print("Output:", out)
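
    # A second, hypothetical demo added to this excerpt: the same model family
    # in classification mode with a padding mask. It assumes the constructor
    # accepts a `task_type` keyword, as the `self.task_type = task_type`
    # assignment in EEGTransformer.__init__ suggests; treat it as a sketch,
    # not as part of the original script.
    mask = np.ones((batch, seq_len), dtype=float)
    mask[:, seq_len // 2 :] = 0.0  # mark the second half of every sequence as padding
    clf_model = EEGTransformer(
        feature_dim=feature_dim,
        d_model=32,
        n_head=4,
        hidden_dim=64,
        num_layers=2,
        output_dim=3,  # three arbitrary output classes for illustration
        seed=0,
        task_type="classification",
    )
    probs, pool_weights = clf_model.forward(X, mask)
    print("Class probabilities shape:", probs.shape)  # (batch, 3); rows sum to ~1
    print("Pooling weights shape:", pool_weights.shape)  # (batch, seq_len)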