Skip to content

Commit 47ba945

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1eca445 commit 47ba945

File tree

1 file changed

+27
-5
lines changed

1 file changed

+27
-5
lines changed

neural_network/real_time_encoder_transformer.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class Time2Vec(nn.Module):
1515
>>> output.shape
1616
torch.Size([1, 3, 4])
1717
"""
18+
1819
def __init__(self, d_model: int) -> None:
1920
super().__init__()
2021
self.w0 = nn.Parameter(torch.randn(1, 1))
@@ -39,6 +40,7 @@ class PositionwiseFeedForward(nn.Module):
3940
>>> out.shape
4041
torch.Size([4, 10, 8])
4142
"""
43+
4244
def __init__(self, d_model: int, hidden: int, drop_prob: float = 0.1) -> None:
4345
super().__init__()
4446
self.fc1 = nn.Linear(d_model, hidden)
@@ -66,11 +68,14 @@ class ScaleDotProductAttention(nn.Module):
6668
>>> ctx.shape
6769
torch.Size([2, 8, 10, 16])
6870
"""
71+
6972
def __init__(self) -> None:
7073
super().__init__()
7174
self.softmax = nn.Softmax(dim=-1)
7275

73-
def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor = None) -> tuple[Tensor, Tensor]:
76+
def forward(
77+
self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor = None
78+
) -> tuple[Tensor, Tensor]:
7479
_, _, _, d_k = k.size()
7580
scores = (q @ k.transpose(2, 3)) / math.sqrt(d_k)
7681

@@ -93,6 +98,7 @@ class MultiHeadAttention(nn.Module):
9398
>>> out.shape
9499
torch.Size([2, 10, 16])
95100
"""
101+
96102
def __init__(self, d_model: int, n_head: int) -> None:
97103
super().__init__()
98104
self.n_head = n_head
@@ -131,6 +137,7 @@ class LayerNorm(nn.Module):
131137
>>> out.shape
132138
torch.Size([4, 10, 8])
133139
"""
140+
134141
def __init__(self, d_model: int, eps: float = 1e-12) -> None:
135142
super().__init__()
136143
self.gamma = nn.Parameter(torch.ones(d_model))
@@ -140,7 +147,9 @@ def __init__(self, d_model: int, eps: float = 1e-12) -> None:
140147
def forward(self, input_tensor: Tensor) -> Tensor:
141148
mean = input_tensor.mean(-1, keepdim=True)
142149
var = input_tensor.var(-1, unbiased=False, keepdim=True)
143-
return self.gamma * (input_tensor - mean) / torch.sqrt(var + self.eps) + self.beta
150+
return (
151+
self.gamma * (input_tensor - mean) / torch.sqrt(var + self.eps) + self.beta
152+
)
144153

145154

146155
class TransformerEncoderLayer(nn.Module):
@@ -154,6 +163,7 @@ class TransformerEncoderLayer(nn.Module):
154163
>>> out.shape
155164
torch.Size([4, 10, 8])
156165
"""
166+
157167
def __init__(
158168
self,
159169
d_model: int,
@@ -187,6 +197,7 @@ class TransformerEncoder(nn.Module):
187197
>>> out.shape
188198
torch.Size([4, 10, 8])
189199
"""
200+
190201
def __init__(
191202
self,
192203
d_model: int,
@@ -223,11 +234,14 @@ class AttentionPooling(nn.Module):
223234
>>> weights.shape
224235
torch.Size([4, 10])
225236
"""
237+
226238
def __init__(self, d_model: int) -> None:
227239
super().__init__()
228240
self.attn_score = nn.Linear(d_model, 1)
229241

230-
def forward(self, input_tensor: Tensor, mask: Tensor = None) -> tuple[Tensor, Tensor]:
242+
def forward(
243+
self, input_tensor: Tensor, mask: Tensor = None
244+
) -> tuple[Tensor, Tensor]:
231245
attn_weights = torch.softmax(self.attn_score(input_tensor).squeeze(-1), dim=-1)
232246

233247
if mask is not None:
@@ -249,6 +263,7 @@ class EEGTransformer(nn.Module):
249263
>>> out.shape
250264
torch.Size([2, 1])
251265
"""
266+
252267
def __init__(
253268
self,
254269
feature_dim: int,
@@ -270,9 +285,16 @@ def __init__(
270285
self.pooling = AttentionPooling(d_model)
271286
self.output_layer = nn.Linear(d_model, output_dim)
272287

273-
def forward(self, input_tensor: Tensor, mask: Tensor = None) -> tuple[Tensor, Tensor]:
288+
def forward(
289+
self, input_tensor: Tensor, mask: Tensor = None
290+
) -> tuple[Tensor, Tensor]:
274291
b, t, _ = input_tensor.size()
275-
t_idx = torch.arange(t, device=input_tensor.device).view(1, t, 1).expand(b, t, 1).float()
292+
t_idx = (
293+
torch.arange(t, device=input_tensor.device)
294+
.view(1, t, 1)
295+
.expand(b, t, 1)
296+
.float()
297+
)
276298
time_emb = self.time2vec(t_idx)
277299
x = self.input_proj(input_tensor) + time_emb
278300
x = self.encoder(x, mask)

0 commit comments

Comments
 (0)