@@ -15,6 +15,7 @@ class Time2Vec(nn.Module):
     >>> output.shape
     torch.Size([1, 3, 4])
     """
+
     def __init__(self, d_model: int) -> None:
         super().__init__()
         self.w0 = nn.Parameter(torch.randn(1, 1))
@@ -39,6 +40,7 @@ class PositionwiseFeedForward(nn.Module):
     >>> out.shape
     torch.Size([4, 10, 8])
     """
+
     def __init__(self, d_model: int, hidden: int, drop_prob: float = 0.1) -> None:
         super().__init__()
         self.fc1 = nn.Linear(d_model, hidden)
@@ -66,11 +68,14 @@ class ScaleDotProductAttention(nn.Module):
     >>> ctx.shape
     torch.Size([2, 8, 10, 16])
     """
+
     def __init__(self) -> None:
         super().__init__()
         self.softmax = nn.Softmax(dim=-1)
 
-    def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor = None) -> tuple[Tensor, Tensor]:
+    def forward(
+        self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor = None
+    ) -> tuple[Tensor, Tensor]:
         _, _, _, d_k = k.size()
         scores = (q @ k.transpose(2, 3)) / math.sqrt(d_k)
 
@@ -93,6 +98,7 @@ class MultiHeadAttention(nn.Module):
     >>> out.shape
     torch.Size([2, 10, 16])
     """
+
    def __init__(self, d_model: int, n_head: int) -> None:
         super().__init__()
         self.n_head = n_head
@@ -131,6 +137,7 @@ class LayerNorm(nn.Module):
     >>> out.shape
     torch.Size([4, 10, 8])
     """
+
     def __init__(self, d_model: int, eps: float = 1e-12) -> None:
         super().__init__()
         self.gamma = nn.Parameter(torch.ones(d_model))
@@ -140,7 +147,9 @@ def __init__(self, d_model: int, eps: float = 1e-12) -> None:
     def forward(self, input_tensor: Tensor) -> Tensor:
         mean = input_tensor.mean(-1, keepdim=True)
         var = input_tensor.var(-1, unbiased=False, keepdim=True)
-        return self.gamma * (input_tensor - mean) / torch.sqrt(var + self.eps) + self.beta
+        return (
+            self.gamma * (input_tensor - mean) / torch.sqrt(var + self.eps) + self.beta
+        )
 
 
 class TransformerEncoderLayer(nn.Module):
@@ -154,6 +163,7 @@ class TransformerEncoderLayer(nn.Module):
     >>> out.shape
     torch.Size([4, 10, 8])
     """
+
     def __init__(
         self,
         d_model: int,
@@ -187,6 +197,7 @@ class TransformerEncoder(nn.Module):
     >>> out.shape
     torch.Size([4, 10, 8])
     """
+
     def __init__(
         self,
         d_model: int,
@@ -223,11 +234,14 @@ class AttentionPooling(nn.Module):
     >>> weights.shape
     torch.Size([4, 10])
     """
+
     def __init__(self, d_model: int) -> None:
         super().__init__()
         self.attn_score = nn.Linear(d_model, 1)
 
-    def forward(self, input_tensor: Tensor, mask: Tensor = None) -> tuple[Tensor, Tensor]:
+    def forward(
+        self, input_tensor: Tensor, mask: Tensor = None
+    ) -> tuple[Tensor, Tensor]:
         attn_weights = torch.softmax(self.attn_score(input_tensor).squeeze(-1), dim=-1)
 
         if mask is not None:
@@ -249,6 +263,7 @@ class EEGTransformer(nn.Module):
     >>> out.shape
     torch.Size([2, 1])
     """
+
     def __init__(
         self,
         feature_dim: int,
@@ -270,9 +285,16 @@ def __init__(
         self.pooling = AttentionPooling(d_model)
         self.output_layer = nn.Linear(d_model, output_dim)
 
-    def forward(self, input_tensor: Tensor, mask: Tensor = None) -> tuple[Tensor, Tensor]:
+    def forward(
+        self, input_tensor: Tensor, mask: Tensor = None
+    ) -> tuple[Tensor, Tensor]:
         b, t, _ = input_tensor.size()
-        t_idx = torch.arange(t, device=input_tensor.device).view(1, t, 1).expand(b, t, 1).float()
+        t_idx = (
+            torch.arange(t, device=input_tensor.device)
+            .view(1, t, 1)
+            .expand(b, t, 1)
+            .float()
+        )
         time_emb = self.time2vec(t_idx)
         x = self.input_proj(input_tensor) + time_emb
         x = self.encoder(x, mask)