@@ -15,6 +15,7 @@ class Time2Vec(nn.Module):
     >>> output.shape
     torch.Size([1, 3, 4])
     """
+
     def __init__(self, d_model: int) -> None:
         super().__init__()
         self.w0 = nn.Parameter(torch.randn(1, 1))
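
> Side note: only `w0` is visible in this hunk. Assuming the module follows the original Time2Vec formulation, a scalar time index τ is mapped to a `d_model`-dimensional embedding with one linear component and `d_model - 1` periodic ones:
>
> ```math
> \mathrm{t2v}(\tau)[i] =
> \begin{cases}
>   \omega_0 \tau + \varphi_0, & i = 0 \\
>   \sin(\omega_i \tau + \varphi_i), & 1 \le i < d_{\mathrm{model}}
> \end{cases}
> ```
>
> which is consistent with the `torch.Size([1, 3, 4])` output in the docstring for `d_model = 4`. The periodic parameters are outside this hunk, so this is an inference, not a quote of the code.
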
@@ -39,6 +40,7 @@ class PositionwiseFeedForward(nn.Module):
     >>> out.shape
     torch.Size([4, 10, 8])
     """
+
     def __init__(self, d_model: int, hidden: int, drop_prob: float = 0.1) -> None:
         super().__init__()
         self.fc1 = nn.Linear(d_model, hidden)
@@ -66,6 +68,7 @@ class ScaleDotProductAttention(nn.Module):
     >>> ctx.shape
     torch.Size([2, 8, 10, 16])
     """
+
     def __init__(self) -> None:
         super().__init__()
         self.softmax = nn.Softmax(dim=-1)
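
> For reference: the class name and the `dim=-1` softmax point to the standard scaled dot-product attention,
>
> ```math
> \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V,
> ```
>
> applied per head, which matches the `torch.Size([2, 8, 10, 16])` context shape (batch, heads, seq, d_k) in the docstring. The body of `forward` is outside this hunk, so this is the assumed standard form rather than a quote of the code.
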
@@ -99,6 +102,7 @@ class MultiHeadAttention(nn.Module):
     >>> out.shape
     torch.Size([2, 10, 16])
     """
+
     def __init__(self, d_model: int, n_head: int) -> None:
         super().__init__()
         self.n_head = n_head
@@ -135,7 +139,9 @@ def split_heads(self, input_tensor: Tensor) -> Tensor:
 
     def concat_heads(self, input_tensor: Tensor) -> Tensor:
         batch, n_head, seq_len, d_k = input_tensor.size()
-        return input_tensor.transpose(1, 2).contiguous().view(batch, seq_len, n_head * d_k)
+        return (
+            input_tensor.transpose(1, 2).contiguous().view(batch, seq_len, n_head * d_k)
+        )
 
 
 class LayerNorm(nn.Module):
@@ -149,6 +155,7 @@ class LayerNorm(nn.Module):
     >>> out.shape
     torch.Size([4, 10, 8])
     """
+
     def __init__(self, d_model: int, eps: float = 1e-12) -> None:
         super().__init__()
         self.gamma = nn.Parameter(torch.ones(d_model))
@@ -158,7 +165,9 @@ def __init__(self, d_model: int, eps: float = 1e-12) -> None:
     def forward(self, input_tensor: Tensor) -> Tensor:
         mean = input_tensor.mean(-1, keepdim=True)
         var = input_tensor.var(-1, unbiased=False, keepdim=True)
-        return self.gamma * (input_tensor - mean) / torch.sqrt(var + self.eps) + self.beta
+        return (
+            self.gamma * (input_tensor - mean) / torch.sqrt(var + self.eps) + self.beta
+        )
 
 
 class TransformerEncoderLayer(nn.Module):
@@ -172,6 +181,7 @@ class TransformerEncoderLayer(nn.Module):
     >>> out.shape
     torch.Size([4, 10, 8])
     """
+
     def __init__(
         self,
         d_model: int,
@@ -205,6 +215,7 @@ class TransformerEncoder(nn.Module):
     >>> out.shape
     torch.Size([4, 10, 8])
     """
+
     def __init__(
         self,
         d_model: int,
@@ -241,11 +252,14 @@ class AttentionPooling(nn.Module):
     >>> weights.shape
     torch.Size([4, 10])
     """
+
     def __init__(self, d_model: int) -> None:
         super().__init__()
         self.attn_score = nn.Linear(d_model, 1)
 
-    def forward(self, input_tensor: Tensor, mask: Tensor = None) -> tuple[Tensor, Tensor]:
+    def forward(
+        self, input_tensor: Tensor, mask: Tensor = None
+    ) -> tuple[Tensor, Tensor]:
         attn_weights = torch.softmax(self.attn_score(input_tensor).squeeze(-1), dim=-1)
 
         if mask is not None:
@@ -267,6 +281,7 @@ class EEGTransformer(nn.Module):
     >>> out.shape
     torch.Size([2, 1])
     """
+
     def __init__(
         self,
         feature_dim: int,
@@ -288,9 +303,16 @@ def __init__(
         self.pooling = AttentionPooling(d_model)
         self.output_layer = nn.Linear(d_model, output_dim)
 
-    def forward(self, input_tensor: Tensor, mask: Tensor = None) -> tuple[Tensor, Tensor]:
+    def forward(
+        self, input_tensor: Tensor, mask: Tensor = None
+    ) -> tuple[Tensor, Tensor]:
         b, t, _ = input_tensor.size()
-        t_idx = torch.arange(t, device=input_tensor.device).view(1, t, 1).expand(b, t, 1).float()
+        t_idx = (
+            torch.arange(t, device=input_tensor.device)
+            .view(1, t, 1)
+            .expand(b, t, 1)
+            .float()
+        )
         time_emb = self.time2vec(t_idx)
         x = self.input_proj(input_tensor) + time_emb
         x = self.encoder(x, mask)
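
A minimal usage sketch of `EEGTransformer`, for reference. The diff truncates the `__init__` signature after `feature_dim`, so every other keyword below (`d_model`, `n_head`, `output_dim`) is an assumed name inferred from the submodules built in `__init__`, and the assumption that `forward` returns `(output, pooling_weights)` comes from its `tuple[Tensor, Tensor]` annotation plus the `AttentionPooling` return value:

```python
import torch

# Hypothetical arguments: only feature_dim is confirmed by the diff; the
# remaining keyword names are guesses, and the classes above are assumed
# to be in scope.
model = EEGTransformer(feature_dim=32, d_model=16, n_head=4, output_dim=1)

x = torch.randn(2, 10, 32)   # (batch, seq_len, feature_dim)
out, weights = model(x)      # assumed (prediction, pooling weights) tuple
print(out.shape)             # torch.Size([2, 1]), matching the docstring
print(weights.shape)         # torch.Size([2, 10]) if weights come from AttentionPooling
```
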