Skip to content

Commit 23c5117

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 220270e commit 23c5117

File tree

1 file changed

+47
-24
lines changed

1 file changed

+47
-24
lines changed

neural_network/real_time_encoder_transformer.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
#imports
1+
# imports
22
import torch
33
import torch.nn as nn
44
import math
5-
#Time2Vec layer for positional encoding of real-time data like EEG
5+
6+
7+
# Time2Vec layer for positional encoding of real-time data like EEG
68
class Time2Vec(nn.Module):
7-
#Encodes time steps into a continuous embedding space so to help the transformer learn temporal dependencies.
9+
# Encodes time steps into a continuous embedding space so to help the transformer learn temporal dependencies.
810
def __init__(self, d_model):
911
super().__init__()
1012
self.w0 = nn.Parameter(torch.randn(1, 1))
@@ -13,11 +15,12 @@ def __init__(self, d_model):
1315
self.b = nn.Parameter(torch.randn(1, d_model - 1))
1416

1517
def forward(self, t):
16-
linear = self.w0 * t + self.b0
17-
periodic = torch.sin(self.w * t + self.b)
18-
return torch.cat([linear, periodic], dim=-1)
19-
20-
#positionwise feedforward network
18+
linear = self.w0 * t + self.b0
19+
periodic = torch.sin(self.w * t + self.b)
20+
return torch.cat([linear, periodic], dim=-1)
21+
22+
23+
# positionwise feedforward network
2124
class PositionwiseFeedForward(nn.Module):
2225
def __init__(self, d_model, hidden, drop_prob=0.1):
2326
super().__init__()
@@ -31,7 +34,9 @@ def forward(self, x):
3134
x = self.relu(x)
3235
x = self.dropout(x)
3336
return self.fc2(x)
34-
#scaled dot product attention
37+
38+
39+
# scaled dot product attention
3540
class ScaleDotProductAttention(nn.Module):
3641
def __init__(self):
3742
super().__init__()
@@ -47,7 +52,9 @@ def forward(self, q, k, v, mask=None):
4752
attn = self.softmax(scores)
4853
context = attn @ v
4954
return context, attn
50-
#multi head attention
55+
56+
57+
# multi head attention
5158
class MultiHeadAttention(nn.Module):
5259
def __init__(self, d_model, n_head):
5360
super().__init__()
@@ -75,7 +82,8 @@ def concat_heads(self, x):
7582
batch, n_head, seq_len, d_k = x.size()
7683
return x.transpose(1, 2).contiguous().view(batch, seq_len, n_head * d_k)
7784

78-
#Layer normalization
85+
86+
# Layer normalization
7987
class LayerNorm(nn.Module):
8088
def __init__(self, d_model, eps=1e-12):
8189
super().__init__()
@@ -88,7 +96,8 @@ def forward(self, x):
8896
var = x.var(-1, unbiased=False, keepdim=True)
8997
return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta
9098

91-
#transformer encoder layer
99+
100+
# transformer encoder layer
92101
class TransformerEncoderLayer(nn.Module):
93102
def __init__(self, d_model, n_head, hidden_dim, drop_prob=0.1):
94103
super().__init__()
@@ -106,22 +115,25 @@ def forward(self, x, mask=None):
106115

107116
return x
108117

109-
#encoder stack
118+
119+
# encoder stack
110120
class TransformerEncoder(nn.Module):
111121
def __init__(self, d_model, n_head, hidden_dim, num_layers, drop_prob=0.1):
112122
super().__init__()
113-
self.layers = nn.ModuleList([
114-
TransformerEncoderLayer(d_model, n_head, hidden_dim, drop_prob)
115-
for _ in range(num_layers)
116-
])
123+
self.layers = nn.ModuleList(
124+
[
125+
TransformerEncoderLayer(d_model, n_head, hidden_dim, drop_prob)
126+
for _ in range(num_layers)
127+
]
128+
)
117129

118130
def forward(self, x, mask=None):
119131
for layer in self.layers:
120132
x = layer(x, mask)
121133
return x
122134

123135

124-
#attention pooling layer
136+
# attention pooling layer
125137
class AttentionPooling(nn.Module):
126138
def __init__(self, d_model):
127139
super().__init__()
@@ -137,12 +149,22 @@ def forward(self, x, mask=None):
137149
pooled = torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1)
138150
return pooled, attn_weights
139151

152+
140153
# transformer model
141154

142-
class EEGTransformer(nn.Module):
143155

144-
def __init__(self, feature_dim, d_model=128, n_head=8, hidden_dim=512,
145-
num_layers=4, drop_prob=0.1, output_dim=1, task_type='regression'):
156+
class EEGTransformer(nn.Module):
157+
def __init__(
158+
self,
159+
feature_dim,
160+
d_model=128,
161+
n_head=8,
162+
hidden_dim=512,
163+
num_layers=4,
164+
drop_prob=0.1,
165+
output_dim=1,
166+
task_type="regression",
167+
):
146168
super().__init__()
147169
self.task_type = task_type
148170
self.input_proj = nn.Linear(feature_dim, d_model)
@@ -151,7 +173,9 @@ def __init__(self, feature_dim, d_model=128, n_head=8, hidden_dim=512,
151173
self.time2vec = Time2Vec(d_model)
152174

153175
# Transformer encoder for sequence modeling
154-
self.encoder = TransformerEncoder(d_model, n_head, hidden_dim, num_layers, drop_prob)
176+
self.encoder = TransformerEncoder(
177+
d_model, n_head, hidden_dim, num_layers, drop_prob
178+
)
155179

156180
# Attention pooling to summarize time dimension
157181
self.pooling = AttentionPooling(d_model)
@@ -160,7 +184,6 @@ def __init__(self, feature_dim, d_model=128, n_head=8, hidden_dim=512,
160184
self.output_layer = nn.Linear(d_model, output_dim)
161185

162186
def forward(self, x, mask=None):
163-
164187
b, t, _ = x.size()
165188

166189
# Create time indices and embed them
@@ -179,7 +202,7 @@ def forward(self, x, mask=None):
179202
# Final output (regression or classification)
180203
out = self.output_layer(pooled)
181204

182-
if self.task_type == 'classification':
205+
if self.task_type == "classification":
183206
out = torch.softmax(out, dim=-1)
184207

185208
return out, attn_weights

0 commit comments

Comments
 (0)