
Commit d30966c

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent d3a8f47 commit d30966c

1 file changed (+27, -5)


neural_network/real_time_encoder_transformer.py

Lines changed: 27 additions & 5 deletions
@@ -15,6 +15,7 @@ class Time2Vec(nn.Module):
     >>> output.shape
     torch.Size([1, 3, 4])
     """
+
     def __init__(self, d_model: int) -> None:
         super().__init__()
         self.w0 = nn.Parameter(torch.randn(1, 1))
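For context on the class this hunk touches: Time2Vec (Kazemi et al., 2019) embeds a scalar time index as one learned linear channel plus d_model - 1 periodic sine channels. A minimal sketch of that standard formulation, consistent with the w0 parameter and the torch.Size([1, 3, 4]) doctest above; the rest of the class body is outside this hunk, so the details beyond w0 are assumptions:

    # Standard Time2Vec sketch: one linear "trend" channel concatenated with
    # d_model - 1 sine channels. Parameter names beyond w0 are illustrative.
    import torch
    from torch import Tensor, nn

    class Time2VecSketch(nn.Module):
        def __init__(self, d_model: int) -> None:
            super().__init__()
            self.w0 = nn.Parameter(torch.randn(1, 1))
            self.b0 = nn.Parameter(torch.randn(1, 1))
            self.w = nn.Parameter(torch.randn(1, d_model - 1))
            self.b = nn.Parameter(torch.randn(1, d_model - 1))

        def forward(self, t: Tensor) -> Tensor:
            linear = self.w0 * t + self.b0             # (batch, seq, 1)
            periodic = torch.sin(self.w * t + self.b)  # (batch, seq, d_model - 1)
            return torch.cat([linear, periodic], dim=-1)

    print(Time2VecSketch(4)(torch.zeros(1, 3, 1)).shape)  # torch.Size([1, 3, 4])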
@@ -39,6 +40,7 @@ class PositionwiseFeedForward(nn.Module):
     >>> out.shape
     torch.Size([4, 10, 8])
     """
+
     def __init__(self, d_model: int, hidden: int, drop_prob: float = 0.1) -> None:
         super().__init__()
         self.fc1 = nn.Linear(d_model, hidden)
@@ -66,6 +68,7 @@ class ScaleDotProductAttention(nn.Module):
     >>> ctx.shape
     torch.Size([2, 8, 10, 16])
     """
+
     def __init__(self) -> None:
         super().__init__()
         self.softmax = nn.Softmax(dim=-1)
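This class implements scaled dot-product attention; the nn.Softmax(dim=-1) registered here is applied over key positions. A self-contained sketch of the standard softmax(QK^T / sqrt(d_k)) V computation at the doctest's shapes, (batch, n_head, seq_len, d_k) = (2, 8, 10, 16); the file's actual forward is not shown in this hunk:

    # Standard scaled dot-product attention at the doctest's shapes.
    import math
    import torch

    q = k = v = torch.randn(2, 8, 10, 16)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # (2, 8, 10, 10)
    ctx = torch.softmax(scores, dim=-1) @ v  # weights sum to 1 over key positions
    print(ctx.shape)  # torch.Size([2, 8, 10, 16])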
@@ -99,6 +102,7 @@ class MultiHeadAttention(nn.Module):
     >>> out.shape
     torch.Size([2, 10, 16])
     """
+
     def __init__(self, d_model: int, n_head: int) -> None:
         super().__init__()
         self.n_head = n_head
@@ -135,7 +139,9 @@ def split_heads(self, input_tensor: Tensor) -> Tensor:
 
     def concat_heads(self, input_tensor: Tensor) -> Tensor:
         batch, n_head, seq_len, d_k = input_tensor.size()
-        return input_tensor.transpose(1, 2).contiguous().view(batch, seq_len, n_head * d_k)
+        return (
+            input_tensor.transpose(1, 2).contiguous().view(batch, seq_len, n_head * d_k)
+        )
 
 
 class LayerNorm(nn.Module):
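The rewrapped return in concat_heads is purely cosmetic: the transpose/view chain still merges (batch, n_head, seq_len, d_k) back into (batch, seq_len, n_head * d_k), inverting split_heads. A quick standalone shape check (sizes chosen for illustration):

    import torch

    x = torch.randn(2, 8, 10, 16)  # (batch, n_head, seq_len, d_k)
    batch, n_head, seq_len, d_k = x.size()
    merged = x.transpose(1, 2).contiguous().view(batch, seq_len, n_head * d_k)
    print(merged.shape)  # torch.Size([2, 10, 128])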
@@ -149,6 +155,7 @@ class LayerNorm(nn.Module):
     >>> out.shape
     torch.Size([4, 10, 8])
     """
+
     def __init__(self, d_model: int, eps: float = 1e-12) -> None:
         super().__init__()
         self.gamma = nn.Parameter(torch.ones(d_model))
@@ -158,7 +165,9 @@ def __init__(self, d_model: int, eps: float = 1e-12) -> None:
     def forward(self, input_tensor: Tensor) -> Tensor:
         mean = input_tensor.mean(-1, keepdim=True)
         var = input_tensor.var(-1, unbiased=False, keepdim=True)
-        return self.gamma * (input_tensor - mean) / torch.sqrt(var + self.eps) + self.beta
+        return (
+            self.gamma * (input_tensor - mean) / torch.sqrt(var + self.eps) + self.beta
+        )
 
 
 class TransformerEncoderLayer(nn.Module):
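Likewise, the LayerNorm return is only rewrapped; it still computes gamma * (x - mean) / sqrt(var + eps) + beta over the last dimension, using the biased variance. A standalone check that this agrees with torch.nn.functional.layer_norm when gamma is all ones and beta all zeros (input shape taken from the doctest):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 10, 8)
    eps = 1e-12
    mean = x.mean(-1, keepdim=True)
    var = x.var(-1, unbiased=False, keepdim=True)
    manual = (x - mean) / torch.sqrt(var + eps)
    print(torch.allclose(manual, F.layer_norm(x, (8,), eps=eps), atol=1e-6))  # True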
@@ -172,6 +181,7 @@ class TransformerEncoderLayer(nn.Module):
     >>> out.shape
     torch.Size([4, 10, 8])
     """
+
     def __init__(
         self,
         d_model: int,
@@ -205,6 +215,7 @@ class TransformerEncoder(nn.Module):
     >>> out.shape
     torch.Size([4, 10, 8])
     """
+
     def __init__(
         self,
         d_model: int,
@@ -241,11 +252,14 @@ class AttentionPooling(nn.Module):
     >>> weights.shape
     torch.Size([4, 10])
     """
+
     def __init__(self, d_model: int) -> None:
         super().__init__()
         self.attn_score = nn.Linear(d_model, 1)
 
-    def forward(self, input_tensor: Tensor, mask: Tensor = None) -> tuple[Tensor, Tensor]:
+    def forward(
+        self, input_tensor: Tensor, mask: Tensor = None
+    ) -> tuple[Tensor, Tensor]:
         attn_weights = torch.softmax(self.attn_score(input_tensor).squeeze(-1), dim=-1)
 
         if mask is not None:
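The hunk ends just inside the mask branch, so the masking logic itself is not visible here. A common pattern for masked attention pooling, offered only as a sketch of the general technique (this file's renormalization details may differ): zero the weights at padded positions, renormalize, then take the weighted sum over time:

    # Hedged sketch of masked attention pooling; shapes match the doctest above.
    import torch

    x = torch.randn(4, 10, 8)                   # (batch, seq_len, d_model)
    scores = torch.randn(4, 10)                 # stand-in for attn_score(x).squeeze(-1)
    mask = torch.ones(4, 10, dtype=torch.bool)  # True = real timestep
    mask[:, 7:] = False                         # pretend the last 3 steps are padding

    weights = torch.softmax(scores, dim=-1) * mask
    weights = weights / weights.sum(dim=-1, keepdim=True)
    pooled = (x * weights.unsqueeze(-1)).sum(dim=1)
    print(pooled.shape, weights.shape)  # torch.Size([4, 8]) torch.Size([4, 10])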
@@ -267,6 +281,7 @@ class EEGTransformer(nn.Module):
     >>> out.shape
     torch.Size([2, 1])
     """
+
     def __init__(
         self,
         feature_dim: int,
@@ -288,9 +303,16 @@ def __init__(
         self.pooling = AttentionPooling(d_model)
         self.output_layer = nn.Linear(d_model, output_dim)
 
-    def forward(self, input_tensor: Tensor, mask: Tensor = None) -> tuple[Tensor, Tensor]:
+    def forward(
+        self, input_tensor: Tensor, mask: Tensor = None
+    ) -> tuple[Tensor, Tensor]:
         b, t, _ = input_tensor.size()
-        t_idx = torch.arange(t, device=input_tensor.device).view(1, t, 1).expand(b, t, 1).float()
+        t_idx = (
+            torch.arange(t, device=input_tensor.device)
+            .view(1, t, 1)
+            .expand(b, t, 1)
+            .float()
+        )
         time_emb = self.time2vec(t_idx)
         x = self.input_proj(input_tensor) + time_emb
         x = self.encoder(x, mask)
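The t_idx rewrite is also behavior-preserving: the chain still builds a column of time indices 0..t-1, shaped (b, t, 1) and broadcast across the batch, as input to Time2Vec. A quick check with illustrative sizes:

    import torch

    b, t = 2, 50
    t_idx = torch.arange(t).view(1, t, 1).expand(b, t, 1).float()
    print(t_idx.shape)      # torch.Size([2, 50, 1])
    print(t_idx[0, :3, 0])  # tensor([0., 1., 2.])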
