@@ -1,10 +1,12 @@
-#imports
+# imports
 import torch
 import torch.nn as nn
 import math
-#Time2Vec layer for positional encoding of real-time data like EEG
+
+
+# Time2Vec layer for positional encoding of real-time data like EEG
 class Time2Vec(nn.Module):
-    #Encodes time steps into a continuous embedding space to help the transformer learn temporal dependencies.
+    # Encodes time steps into a continuous embedding space to help the transformer learn temporal dependencies.
     def __init__(self, d_model):
         super().__init__()
         self.w0 = nn.Parameter(torch.randn(1, 1))
@@ -13,11 +15,12 @@ def __init__(self, d_model):
         self.b = nn.Parameter(torch.randn(1, d_model - 1))

     def forward(self, t):
-        linear = self.w0 * t + self.b0
-        periodic = torch.sin(self.w * t + self.b)
-        return torch.cat([linear, periodic], dim=-1)
-
-#positionwise feedforward network
+        linear = self.w0 * t + self.b0
+        periodic = torch.sin(self.w * t + self.b)
+        return torch.cat([linear, periodic], dim=-1)
+
+
+# positionwise feedforward network
 class PositionwiseFeedForward(nn.Module):
     def __init__(self, d_model, hidden, drop_prob=0.1):
         super().__init__()
@@ -31,7 +34,9 @@ def forward(self, x):
         x = self.relu(x)
         x = self.dropout(x)
         return self.fc2(x)
-#scaled dot product attention
+
+
+# scaled dot product attention
 class ScaleDotProductAttention(nn.Module):
     def __init__(self):
         super().__init__()
@@ -47,7 +52,9 @@ def forward(self, q, k, v, mask=None):
         attn = self.softmax(scores)
         context = attn @ v
         return context, attn
-#multi head attention
+
+
+# multi head attention
 class MultiHeadAttention(nn.Module):
     def __init__(self, d_model, n_head):
         super().__init__()
@@ -75,7 +82,8 @@ def concat_heads(self, x):
         batch, n_head, seq_len, d_k = x.size()
         return x.transpose(1, 2).contiguous().view(batch, seq_len, n_head * d_k)

-#Layer normalization
+
+# Layer normalization
 class LayerNorm(nn.Module):
     def __init__(self, d_model, eps=1e-12):
         super().__init__()
@@ -88,7 +96,8 @@ def forward(self, x):
         var = x.var(-1, unbiased=False, keepdim=True)
         return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta

-#transformer encoder layer
+
+# transformer encoder layer
 class TransformerEncoderLayer(nn.Module):
     def __init__(self, d_model, n_head, hidden_dim, drop_prob=0.1):
         super().__init__()
@@ -106,22 +115,25 @@ def forward(self, x, mask=None):

         return x

-#encoder stack
+
+# encoder stack
 class TransformerEncoder(nn.Module):
     def __init__(self, d_model, n_head, hidden_dim, num_layers, drop_prob=0.1):
         super().__init__()
-        self.layers = nn.ModuleList([
-            TransformerEncoderLayer(d_model, n_head, hidden_dim, drop_prob)
-            for _ in range(num_layers)
-        ])
+        self.layers = nn.ModuleList(
+            [
+                TransformerEncoderLayer(d_model, n_head, hidden_dim, drop_prob)
+                for _ in range(num_layers)
+            ]
+        )

     def forward(self, x, mask=None):
         for layer in self.layers:
             x = layer(x, mask)
         return x


-#attention pooling layer
+# attention pooling layer
 class AttentionPooling(nn.Module):
     def __init__(self, d_model):
         super().__init__()
@@ -137,12 +149,22 @@ def forward(self, x, mask=None):
         pooled = torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1)
         return pooled, attn_weights

+
 # transformer model

-class EEGTransformer(nn.Module):

-    def __init__(self, feature_dim, d_model=128, n_head=8, hidden_dim=512,
-                 num_layers=4, drop_prob=0.1, output_dim=1, task_type='regression'):
+class EEGTransformer(nn.Module):
+    def __init__(
+        self,
+        feature_dim,
+        d_model=128,
+        n_head=8,
+        hidden_dim=512,
+        num_layers=4,
+        drop_prob=0.1,
+        output_dim=1,
+        task_type="regression",
+    ):
         super().__init__()
         self.task_type = task_type
         self.input_proj = nn.Linear(feature_dim, d_model)
@@ -151,7 +173,9 @@ def __init__(self, feature_dim, d_model=128, n_head=8, hidden_dim=512,
         self.time2vec = Time2Vec(d_model)

         # Transformer encoder for sequence modeling
-        self.encoder = TransformerEncoder(d_model, n_head, hidden_dim, num_layers, drop_prob)
+        self.encoder = TransformerEncoder(
+            d_model, n_head, hidden_dim, num_layers, drop_prob
+        )

         # Attention pooling to summarize time dimension
         self.pooling = AttentionPooling(d_model)
@@ -160,7 +184,6 @@ def __init__(self, feature_dim, d_model=128, n_head=8, hidden_dim=512,
         self.output_layer = nn.Linear(d_model, output_dim)

     def forward(self, x, mask=None):
-
         b, t, _ = x.size()

         # Create time indices and embed them
@@ -179,7 +202,7 @@ def forward(self, x, mask=None):
         # Final output (regression or classification)
         out = self.output_layer(pooled)

-        if self.task_type == 'classification':
+        if self.task_type == "classification":
             out = torch.softmax(out, dim=-1)

         return out, attn_weights
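
As a quick sanity check of the model this commit reformats, here is a minimal usage sketch (not part of the diff). It assumes the elided hunks contain the usual projection/attention plumbing, and the batch size, sequence length, and feature_dim values below are illustrative only.

# Hypothetical smoke test (not in the repository): instantiate EEGTransformer
# on a dummy EEG batch and inspect the output shapes.
import torch

model = EEGTransformer(feature_dim=32, d_model=128, n_head=8,
                       hidden_dim=512, num_layers=4, output_dim=1,
                       task_type="regression")
x = torch.randn(8, 250, 32)       # (batch, time steps, EEG feature channels)
out, attn_weights = model(x)      # mask=None: attend over every time step
print(out.shape)                  # expected (8, 1): one regression value per window
print(attn_weights.shape)         # expected (8, 250): one pooling weight per time step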