import math

import torch
from torch import Tensor, nn
class Time2Vec(nn.Module):
    """
    Time2Vec layer for positional encoding of real-time data like EEG.

    Maps scalar time steps of shape ``(..., 1)`` to a ``d_model``-dimensional
    embedding: one learned linear (trend) feature plus ``d_model - 1`` learned
    periodic (sine) features, concatenated on the last dimension.

    >>> import torch
    >>> layer = Time2Vec(4)
    >>> t = torch.ones(1, 3, 1)
    >>> layer(t).shape
    torch.Size([1, 3, 4])
    """

    def __init__(self, d_model: int) -> None:
        super().__init__()
        # Linear term: w0 * t + b0 -> a single non-periodic feature.
        self.w0 = nn.Parameter(torch.randn(1, 1))
        self.b0 = nn.Parameter(torch.randn(1, 1))
        # Periodic terms: sin(w * t + b) -> the remaining d_model - 1 features.
        self.w = nn.Parameter(torch.randn(1, d_model - 1))
        self.b = nn.Parameter(torch.randn(1, d_model - 1))

    def forward(self, time_steps: Tensor) -> Tensor:
        """Embed ``time_steps`` of shape (..., 1) into shape (..., d_model)."""
        linear = self.w0 * time_steps + self.b0
        periodic = torch.sin(self.w * time_steps + self.b)
        return torch.cat([linear, periodic], dim=-1)
2129
2230
class PositionwiseFeedForward(nn.Module):
    """
    Positionwise feedforward network.

    Applies the same two-layer MLP (Linear -> ReLU -> Dropout -> Linear) to
    every position independently, expanding to ``hidden`` and projecting back
    to ``d_model``.

    >>> import torch
    >>> layer = PositionwiseFeedForward(8, 16)
    >>> x = torch.rand(4, 10, 8)
    >>> layer(x).shape
    torch.Size([4, 10, 8])
    """

    def __init__(self, d_model: int, hidden: int, drop_prob: float = 0.1) -> None:
        super().__init__()
        self.fc1 = nn.Linear(d_model, hidden)
        self.fc2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, input_tensor: Tensor) -> Tensor:
        """Transform ``(..., d_model)`` input; output shape is unchanged."""
        x = self.fc1(input_tensor)
        x = self.relu(x)
        x = self.dropout(x)
        return self.fc2(x)
3754
3855
39- # scaled dot product attention
4056class ScaleDotProductAttention (nn .Module ):
41- def __init__ (self ):
57+ """
58+ Scaled dot product attention.
59+
60+ >>> import torch
61+ >>> attn = ScaleDotProductAttention()
62+ >>> q = torch.rand(2, 8, 10, 16)
63+ >>> k = torch.rand(2, 8, 10, 16)
64+ >>> v = torch.rand(2, 8, 10, 16)
65+ >>> ctx, attn_w = attn.forward(q, k, v)
66+ >>> ctx.shape
67+ torch.Size([2, 8, 10, 16])
68+ """
69+ def __init__ (self ) -> None :
4270 super ().__init__ ()
4371 self .softmax = nn .Softmax (dim = - 1 )
4472
45- def forward (self , q , k , v , mask = None ):
73+ def forward (self , q : Tensor , k : Tensor , v : Tensor , mask : Tensor = None ) -> tuple [ Tensor , Tensor ] :
4674 _ , _ , _ , d_k = k .size ()
4775 scores = (q @ k .transpose (2 , 3 )) / math .sqrt (d_k )
4876
@@ -54,9 +82,18 @@ def forward(self, q, k, v, mask=None):
5482 return context , attn
5583
5684
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention.

    Projects q/k/v with per-model linear layers, splits the model dimension
    into ``n_head`` heads, applies scaled dot-product attention per head, and
    recombines with an output projection. ``d_model`` must be divisible by
    ``n_head``.

    >>> import torch
    >>> attn = MultiHeadAttention(16, 4)
    >>> q = torch.rand(2, 10, 16)
    >>> attn(q, q, q).shape
    torch.Size([2, 10, 16])
    """

    def __init__(self, d_model: int, n_head: int) -> None:
        super().__init__()
        self.n_head = n_head
        self.attn = ScaleDotProductAttention()
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_out = nn.Linear(d_model, d_model)

    def forward(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None
    ) -> Tensor:
        """Attend ``(batch, seq, d_model)`` q/k/v; returns the same shape."""
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)
        q, k, v = self.split_heads(q), self.split_heads(k), self.split_heads(v)

        context, _ = self.attn(q, k, v, mask)
        out = self.w_out(self.concat_heads(context))
        return out

    def split_heads(self, x: Tensor) -> Tensor:
        """Reshape (batch, seq, d_model) -> (batch, n_head, seq, d_k)."""
        batch, seq_len, d_model = x.size()
        d_k = d_model // self.n_head
        return x.view(batch, seq_len, self.n_head, d_k).transpose(1, 2)

    def concat_heads(self, x: Tensor) -> Tensor:
        """Reshape (batch, n_head, seq, d_k) -> (batch, seq, n_head * d_k)."""
        batch, n_head, seq_len, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch, seq_len, n_head * d_k)
84121
85122
class LayerNorm(nn.Module):
    """
    Layer normalization.

    Normalizes the last dimension to zero mean and unit variance (biased
    variance estimate), then applies a learned elementwise affine transform
    (``gamma``, ``beta``).

    >>> import torch
    >>> ln = LayerNorm(8)
    >>> x = torch.rand(4, 10, 8)
    >>> ln(x).shape
    torch.Size([4, 10, 8])
    """

    def __init__(self, d_model: int, eps: float = 1e-12) -> None:
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        # eps keeps the denominator nonzero for constant inputs.
        self.eps = eps

    def forward(self, input_tensor: Tensor) -> Tensor:
        """Normalize the last dim of ``input_tensor``; shape is unchanged."""
        mean = input_tensor.mean(-1, keepdim=True)
        var = input_tensor.var(-1, unbiased=False, keepdim=True)
        return self.gamma * (input_tensor - mean) / torch.sqrt(var + self.eps) + self.beta
98144
99145
class TransformerEncoderLayer(nn.Module):
    """
    Transformer encoder layer.

    Post-norm layout: self-attention with a residual connection and
    LayerNorm, then a positionwise feedforward block with a second residual
    connection and LayerNorm. Dropout is applied to each sublayer output
    before the residual add.

    >>> import torch
    >>> layer = TransformerEncoderLayer(8, 2, 16)
    >>> x = torch.rand(4, 10, 8)
    >>> layer(x).shape
    torch.Size([4, 10, 8])
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        hidden_dim: int,
        drop_prob: float = 0.1,
    ) -> None:
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_head)
        self.ffn = PositionwiseFeedForward(d_model, hidden_dim, drop_prob)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, input_tensor: Tensor, mask: Tensor | None = None) -> Tensor:
        """Encode ``(batch, seq, d_model)`` input; returns the same shape."""
        attn_out = self.self_attn(input_tensor, input_tensor, input_tensor, mask)
        x = self.norm1(input_tensor + self.dropout(attn_out))
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))
        return x
117177
118178
class TransformerEncoder(nn.Module):
    """
    Encoder stack.

    Runs the input through ``num_layers`` identical
    ``TransformerEncoderLayer`` instances in sequence.

    >>> import torch
    >>> enc = TransformerEncoder(8, 2, 16, 2)
    >>> x = torch.rand(4, 10, 8)
    >>> enc(x).shape
    torch.Size([4, 10, 8])
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        hidden_dim: int,
        num_layers: int,
        drop_prob: float = 0.1,
    ) -> None:
        super().__init__()
        # NOTE(review): the comprehension body was elided in the visible hunk;
        # reconstructed from the constructor arguments — confirm against the
        # full file.
        self.layers = nn.ModuleList(
            [
                TransformerEncoderLayer(d_model, n_head, hidden_dim, drop_prob)
                for _ in range(num_layers)
            ]
        )

    def forward(self, input_tensor: Tensor, mask: Tensor | None = None) -> Tensor:
        """Apply every layer in order; output shape equals input shape."""
        x = input_tensor
        for layer in self.layers:
            x = layer(x, mask)
        return x
134211
135212
136- # attention pooling layer
137213class AttentionPooling (nn .Module ):
138- def __init__ (self , d_model ):
214+ """
215+ Attention pooling layer.
216+
217+ >>> import torch
218+ >>> pooling = AttentionPooling(8)
219+ >>> x = torch.rand(4, 10, 8)
220+ >>> pooled, weights = pooling.forward(x)
221+ >>> pooled.shape
222+ torch.Size([4, 8])
223+ >>> weights.shape
224+ torch.Size([4, 10])
225+ """
226+ def __init__ (self , d_model : int ) -> None :
139227 super ().__init__ ()
140228 self .attn_score = nn .Linear (d_model , 1 )
141229
142- def forward (self , x , mask = None ):
143- attn_weights = torch .softmax (self .attn_score (x ).squeeze (- 1 ), dim = - 1 )
230+ def forward (self , input_tensor : Tensor , mask : Tensor = None ) -> tuple [ Tensor , Tensor ] :
231+ attn_weights = torch .softmax (self .attn_score (input_tensor ).squeeze (- 1 ), dim = - 1 )
144232
145233 if mask is not None :
146234 attn_weights = attn_weights .masked_fill (mask == 0 , 0 )
147235 attn_weights = attn_weights / (attn_weights .sum (dim = 1 , keepdim = True ) + 1e-8 )
148236
149- pooled = torch .bmm (attn_weights .unsqueeze (1 ), x ).squeeze (1 )
237+ pooled = torch .bmm (attn_weights .unsqueeze (1 ), input_tensor ).squeeze (1 )
150238 return pooled , attn_weights
151239
152240
class EEGTransformer(nn.Module):
    """
    EEG Transformer model.

    Projects per-timestep features to ``d_model``, adds a Time2Vec temporal
    embedding, encodes the sequence with a Transformer encoder stack,
    summarizes the time dimension with attention pooling, and maps the
    pooled vector to ``output_dim`` outputs. When ``task_type`` is
    ``"classification"`` a softmax is applied to the output; otherwise the
    raw linear output is returned (regression).

    >>> import torch
    >>> model = EEGTransformer(feature_dim=8)
    >>> x = torch.rand(2, 10, 8)
    >>> out, attn_w = model(x)
    >>> out.shape
    torch.Size([2, 1])
    """

    def __init__(
        self,
        feature_dim: int,
        d_model: int = 128,
        n_head: int = 8,
        hidden_dim: int = 512,
        num_layers: int = 4,
        drop_prob: float = 0.1,
        output_dim: int = 1,
        task_type: str = "regression",
    ) -> None:
        super().__init__()
        self.task_type = task_type
        self.input_proj = nn.Linear(feature_dim, d_model)
        # Time encoding for temporal understanding of the EEG sequence.
        self.time2vec = Time2Vec(d_model)
        self.encoder = TransformerEncoder(
            d_model, n_head, hidden_dim, num_layers, drop_prob
        )
        # Attention pooling summarizes the time dimension to one vector.
        self.pooling = AttentionPooling(d_model)
        self.output_layer = nn.Linear(d_model, output_dim)

    def forward(
        self, input_tensor: Tensor, mask: Tensor | None = None
    ) -> tuple[Tensor, Tensor]:
        """
        Run the model on ``(batch, time, feature_dim)`` input.

        Returns ``(output, attention_weights)`` where ``output`` has shape
        ``(batch, output_dim)`` and ``attention_weights`` has shape
        ``(batch, time)``.
        """
        b, t, _ = input_tensor.size()
        # Integer time indices 0..t-1, shaped (b, t, 1) for Time2Vec.
        t_idx = (
            torch.arange(t, device=input_tensor.device)
            .view(1, t, 1)
            .expand(b, t, 1)
            .float()
        )
        time_emb = self.time2vec(t_idx)
        x = self.input_proj(input_tensor) + time_emb
        x = self.encoder(x, mask)
        pooled, attn_weights = self.pooling(x, mask)
        out = self.output_layer(pooled)
        if self.task_type == "classification":
            out = torch.softmax(out, dim=-1)
        return out, attn_weights
0 commit comments