-# imports
import math

import torch
@@ -16,7 +15,6 @@ class Time2Vec(nn.Module):
    >>> output.shape
    torch.Size([1, 3, 4])
    """
-
    def __init__(self, d_model: int) -> None:
        super().__init__()
        self.w0 = nn.Parameter(torch.randn(1, 1))
@@ -41,7 +39,6 @@ class PositionwiseFeedForward(nn.Module):
    >>> out.shape
    torch.Size([4, 10, 8])
    """
-
    def __init__(self, d_model: int, hidden: int, drop_prob: float = 0.1) -> None:
        super().__init__()
        self.fc1 = nn.Linear(d_model, hidden)
@@ -62,29 +59,32 @@ class ScaleDotProductAttention(nn.Module):

    >>> import torch
    >>> attn = ScaleDotProductAttention()
-    >>> q = torch.rand(2, 8, 10, 16)
-    >>> k = torch.rand(2, 8, 10, 16)
-    >>> v = torch.rand(2, 8, 10, 16)
-    >>> ctx, attn_w = attn.forward(q, k, v)
+    >>> query_tensor = torch.rand(2, 8, 10, 16)
+    >>> key_tensor = torch.rand(2, 8, 10, 16)
+    >>> value_tensor = torch.rand(2, 8, 10, 16)
+    >>> ctx, attn_w = attn.forward(query_tensor, key_tensor, value_tensor)
    >>> ctx.shape
    torch.Size([2, 8, 10, 16])
    """
-
    def __init__(self) -> None:
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(
-        self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor = None
+        self,
+        query_tensor: Tensor,
+        key_tensor: Tensor,
+        value_tensor: Tensor,
+        mask: Tensor = None,
    ) -> tuple[Tensor, Tensor]:
-        _, _, _, d_k = k.size()
-        scores = (q @ k.transpose(2, 3)) / math.sqrt(d_k)
+        _, _, _, d_k = key_tensor.size()
+        scores = (query_tensor @ key_tensor.transpose(2, 3)) / math.sqrt(d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn = self.softmax(scores)
-        context = attn @ v
+        context = attn @ value_tensor
        return context, attn


@@ -94,12 +94,11 @@ class MultiHeadAttention(nn.Module):

    >>> import torch
    >>> attn = MultiHeadAttention(16, 4)
-    >>> q = torch.rand(2, 10, 16)
-    >>> out = attn.forward(q, q, q)
+    >>> query_tensor = torch.rand(2, 10, 16)
+    >>> out = attn.forward(query_tensor, query_tensor, query_tensor)
    >>> out.shape
    torch.Size([2, 10, 16])
    """
-
    def __init__(self, d_model: int, n_head: int) -> None:
        super().__init__()
        self.n_head = n_head
@@ -109,22 +108,34 @@ def __init__(self, d_model: int, n_head: int) -> None:
        self.w_v = nn.Linear(d_model, d_model)
        self.w_out = nn.Linear(d_model, d_model)

-    def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor = None) -> Tensor:
-        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)
-        q, k, v = self.split_heads(q), self.split_heads(k), self.split_heads(v)
+    def forward(
+        self,
+        query_tensor: Tensor,
+        key_tensor: Tensor,
+        value_tensor: Tensor,
+        mask: Tensor = None,
+    ) -> Tensor:
+        query_tensor, key_tensor, value_tensor = (
+            self.w_q(query_tensor),
+            self.w_k(key_tensor),
+            self.w_v(value_tensor),
+        )
+        query_tensor = self.split_heads(query_tensor)
+        key_tensor = self.split_heads(key_tensor)
+        value_tensor = self.split_heads(value_tensor)

-        context, _ = self.attn(q, k, v, mask)
+        context, _ = self.attn(query_tensor, key_tensor, value_tensor, mask)
        out = self.w_out(self.concat_heads(context))
        return out

-    def split_heads(self, x: Tensor) -> Tensor:
-        batch, seq_len, d_model = x.size()
+    def split_heads(self, input_tensor: Tensor) -> Tensor:
+        batch, seq_len, d_model = input_tensor.size()
        d_k = d_model // self.n_head
-        return x.view(batch, seq_len, self.n_head, d_k).transpose(1, 2)
+        return input_tensor.view(batch, seq_len, self.n_head, d_k).transpose(1, 2)

-    def concat_heads(self, x: Tensor) -> Tensor:
-        batch, n_head, seq_len, d_k = x.size()
-        return x.transpose(1, 2).contiguous().view(batch, seq_len, n_head * d_k)
+    def concat_heads(self, input_tensor: Tensor) -> Tensor:
+        batch, n_head, seq_len, d_k = input_tensor.size()
+        return input_tensor.transpose(1, 2).contiguous().view(batch, seq_len, n_head * d_k)


class LayerNorm(nn.Module):
@@ -138,7 +149,6 @@ class LayerNorm(nn.Module):
    >>> out.shape
    torch.Size([4, 10, 8])
    """
-
    def __init__(self, d_model: int, eps: float = 1e-12) -> None:
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
@@ -148,9 +158,7 @@ def __init__(self, d_model: int, eps: float = 1e-12) -> None:
    def forward(self, input_tensor: Tensor) -> Tensor:
        mean = input_tensor.mean(-1, keepdim=True)
        var = input_tensor.var(-1, unbiased=False, keepdim=True)
-        return (
-            self.gamma * (input_tensor - mean) / torch.sqrt(var + self.eps) + self.beta
-        )
+        return self.gamma * (input_tensor - mean) / torch.sqrt(var + self.eps) + self.beta


class TransformerEncoderLayer(nn.Module):
@@ -164,7 +172,6 @@ class TransformerEncoderLayer(nn.Module):
    >>> out.shape
    torch.Size([4, 10, 8])
    """
-
    def __init__(
        self,
        d_model: int,
@@ -198,7 +205,6 @@ class TransformerEncoder(nn.Module):
    >>> out.shape
    torch.Size([4, 10, 8])
    """
-
    def __init__(
        self,
        d_model: int,
@@ -235,14 +241,11 @@ class AttentionPooling(nn.Module):
    >>> weights.shape
    torch.Size([4, 10])
    """
-
    def __init__(self, d_model: int) -> None:
        super().__init__()
        self.attn_score = nn.Linear(d_model, 1)

-    def forward(
-        self, input_tensor: Tensor, mask: Tensor = None
-    ) -> tuple[Tensor, Tensor]:
+    def forward(self, input_tensor: Tensor, mask: Tensor = None) -> tuple[Tensor, Tensor]:
        attn_weights = torch.softmax(self.attn_score(input_tensor).squeeze(-1), dim=-1)

        if mask is not None:
@@ -264,7 +267,6 @@ class EEGTransformer(nn.Module):
    >>> out.shape
    torch.Size([2, 1])
    """
-
    def __init__(
        self,
        feature_dim: int,
@@ -286,16 +288,9 @@ def __init__(
        self.pooling = AttentionPooling(d_model)
        self.output_layer = nn.Linear(d_model, output_dim)

-    def forward(
-        self, input_tensor: Tensor, mask: Tensor = None
-    ) -> tuple[Tensor, Tensor]:
+    def forward(self, input_tensor: Tensor, mask: Tensor = None) -> tuple[Tensor, Tensor]:
        b, t, _ = input_tensor.size()
-        t_idx = (
-            torch.arange(t, device=input_tensor.device)
-            .view(1, t, 1)
-            .expand(b, t, 1)
-            .float()
-        )
+        t_idx = torch.arange(t, device=input_tensor.device).view(1, t, 1).expand(b, t, 1).float()
        time_emb = self.time2vec(t_idx)
        x = self.input_proj(input_tensor) + time_emb
        x = self.encoder(x, mask)
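
Quick sanity check of the renamed attention path: the snippet below is a minimal standalone sketch (not part of the diffed file) that mirrors the computation in ScaleDotProductAttention.forward with the new tensor names; the shapes and the math.sqrt(d_k) scaling are taken directly from the code above.

# Standalone sketch mirroring ScaleDotProductAttention.forward (no module wrapper).
import math

import torch

# Toy inputs matching the docstring example: (batch, heads, seq_len, d_k).
query_tensor = torch.rand(2, 8, 10, 16)
key_tensor = torch.rand(2, 8, 10, 16)
value_tensor = torch.rand(2, 8, 10, 16)

d_k = key_tensor.size(-1)
scores = (query_tensor @ key_tensor.transpose(2, 3)) / math.sqrt(d_k)
attn = torch.softmax(scores, dim=-1)  # attention weights sum to 1 over the key axis
context = attn @ value_tensor

print(context.shape)  # torch.Size([2, 8, 10, 16])
print(attn.shape)     # torch.Size([2, 8, 10, 10])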