Skip to content

Commit 1eca445

Browse files
authored
Update real_time_encoder_transformer.py
1 parent 4a62b57 commit 1eca445

File tree

1 file changed

+149
-74
lines changed

1 file changed

+149
-74
lines changed

neural_network/real_time_encoder_transformer.py

Lines changed: 149 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -4,45 +4,73 @@
44
from torch import nn
55

66

7-
class Time2Vec(nn.Module):
    """
    Time2Vec layer for positional encoding of real-time data like EEG.

    Maps scalar time steps to a ``d_model``-dimensional embedding made of
    one learned linear component and ``d_model - 1`` periodic (sine)
    components, concatenated along the last dimension.

    >>> import torch
    >>> layer = Time2Vec(4)
    >>> t = torch.ones(1, 3, 1)
    >>> output = layer.forward(t)
    >>> output.shape
    torch.Size([1, 3, 4])
    """

    def __init__(self, d_model: int) -> None:
        super().__init__()
        # Need at least one linear and one periodic component; with
        # d_model < 2 the periodic parameters would silently be empty
        # and the output would not have d_model features.
        if d_model < 2:
            raise ValueError(f"d_model must be >= 2, got {d_model}")
        self.w0 = nn.Parameter(torch.randn(1, 1))
        self.b0 = nn.Parameter(torch.randn(1, 1))
        self.w = nn.Parameter(torch.randn(1, d_model - 1))
        self.b = nn.Parameter(torch.randn(1, d_model - 1))

    def forward(self, time_steps: Tensor) -> Tensor:
        # Linear trend component plus sinusoidal periodic components.
        linear = self.w0 * time_steps + self.b0
        periodic = torch.sin(self.w * time_steps + self.b)
        return torch.cat([linear, periodic], dim=-1)
2129

2230

class PositionwiseFeedForward(nn.Module):
    """
    Positionwise feedforward network.

    Expands each position's features to ``hidden`` with a ReLU and
    dropout, then projects back to ``d_model``.

    >>> import torch
    >>> layer = PositionwiseFeedForward(8, 16)
    >>> x = torch.rand(4, 10, 8)
    >>> out = layer.forward(x)
    >>> out.shape
    torch.Size([4, 10, 8])
    """

    def __init__(self, d_model: int, hidden: int, drop_prob: float = 0.1) -> None:
        super().__init__()
        self.fc1 = nn.Linear(d_model, hidden)
        self.fc2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, input_tensor: Tensor) -> Tensor:
        # fc1 -> relu -> dropout -> fc2, written as a single pipeline.
        hidden_state = self.dropout(self.relu(self.fc1(input_tensor)))
        return self.fc2(hidden_state)
3754

3855

39-
# scaled dot product attention
4056
class ScaleDotProductAttention(nn.Module):
41-
def __init__(self):
57+
"""
58+
Scaled dot product attention.
59+
60+
>>> import torch
61+
>>> attn = ScaleDotProductAttention()
62+
>>> q = torch.rand(2, 8, 10, 16)
63+
>>> k = torch.rand(2, 8, 10, 16)
64+
>>> v = torch.rand(2, 8, 10, 16)
65+
>>> ctx, attn_w = attn.forward(q, k, v)
66+
>>> ctx.shape
67+
torch.Size([2, 8, 10, 16])
68+
"""
69+
def __init__(self) -> None:
4270
super().__init__()
4371
self.softmax = nn.Softmax(dim=-1)
4472

45-
def forward(self, q, k, v, mask=None):
73+
def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor = None) -> tuple[Tensor, Tensor]:
4674
_, _, _, d_k = k.size()
4775
scores = (q @ k.transpose(2, 3)) / math.sqrt(d_k)
4876

@@ -54,9 +82,18 @@ def forward(self, q, k, v, mask=None):
5482
return context, attn
5583

5684

57-
# multi head attention
5885
class MultiHeadAttention(nn.Module):
59-
def __init__(self, d_model, n_head):
86+
"""
87+
Multi-head attention.
88+
89+
>>> import torch
90+
>>> attn = MultiHeadAttention(16, 4)
91+
>>> q = torch.rand(2, 10, 16)
92+
>>> out = attn.forward(q, q, q)
93+
>>> out.shape
94+
torch.Size([2, 10, 16])
95+
"""
96+
def __init__(self, d_model: int, n_head: int) -> None:
6097
super().__init__()
6198
self.n_head = n_head
6299
self.attn = ScaleDotProductAttention()
@@ -65,60 +102,99 @@ def __init__(self, d_model, n_head):
65102
self.w_v = nn.Linear(d_model, d_model)
66103
self.w_out = nn.Linear(d_model, d_model)
67104

68-
def forward(self, q, k, v, mask=None):
105+
def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor = None) -> Tensor:
69106
q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)
70107
q, k, v = self.split_heads(q), self.split_heads(k), self.split_heads(v)
71108

72109
context, _ = self.attn(q, k, v, mask)
73110
out = self.w_out(self.concat_heads(context))
74111
return out
75112

76-
def split_heads(self, x):
113+
def split_heads(self, x: Tensor) -> Tensor:
77114
batch, seq_len, d_model = x.size()
78115
d_k = d_model // self.n_head
79116
return x.view(batch, seq_len, self.n_head, d_k).transpose(1, 2)
80117

81-
def concat_heads(self, x):
118+
def concat_heads(self, x: Tensor) -> Tensor:
82119
batch, n_head, seq_len, d_k = x.size()
83120
return x.transpose(1, 2).contiguous().view(batch, seq_len, n_head * d_k)
84121

85122

class LayerNorm(nn.Module):
    """
    Layer normalization over the last dimension.

    Normalizes each feature vector to zero mean and unit variance, then
    applies a learnable per-feature scale (gamma) and shift (beta).

    >>> import torch
    >>> ln = LayerNorm(8)
    >>> x = torch.rand(4, 10, 8)
    >>> out = ln.forward(x)
    >>> out.shape
    torch.Size([4, 10, 8])
    """

    def __init__(self, d_model: int, eps: float = 1e-12) -> None:
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        # eps guards the division when a feature vector is constant.
        self.eps = eps

    def forward(self, input_tensor: Tensor) -> Tensor:
        mean = input_tensor.mean(-1, keepdim=True)
        variance = input_tensor.var(-1, unbiased=False, keepdim=True)
        normalized = (input_tensor - mean) / torch.sqrt(variance + self.eps)
        return self.gamma * normalized + self.beta
98144

99145

class TransformerEncoderLayer(nn.Module):
    """
    Transformer encoder layer.

    Multi-head self-attention followed by a positionwise feedforward
    network, each wrapped in a residual connection with dropout and
    layer normalization (post-norm arrangement).

    >>> import torch
    >>> layer = TransformerEncoderLayer(8, 2, 16)
    >>> x = torch.rand(4, 10, 8)
    >>> out = layer.forward(x)
    >>> out.shape
    torch.Size([4, 10, 8])
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        hidden_dim: int,
        drop_prob: float = 0.1,
    ) -> None:
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_head)
        self.ffn = PositionwiseFeedForward(d_model, hidden_dim, drop_prob)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(drop_prob)

    # ``mask`` is optional; annotated Tensor | None instead of the
    # implicit-Optional ``Tensor = None`` (PEP 484).
    def forward(self, input_tensor: Tensor, mask: Tensor | None = None) -> Tensor:
        # Residual self-attention sub-layer.
        attn_out = self.self_attn(input_tensor, input_tensor, input_tensor, mask)
        x = self.norm1(input_tensor + self.dropout(attn_out))
        # Residual feedforward sub-layer.
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))
        return x
117177

118178

119-
# encoder stack
120179
class TransformerEncoder(nn.Module):
121-
def __init__(self, d_model, n_head, hidden_dim, num_layers, drop_prob=0.1):
180+
"""
181+
Encoder stack.
182+
183+
>>> import torch
184+
>>> enc = TransformerEncoder(8, 2, 16, 2)
185+
>>> x = torch.rand(4, 10, 8)
186+
>>> out = enc.forward(x)
187+
>>> out.shape
188+
torch.Size([4, 10, 8])
189+
"""
190+
def __init__(
191+
self,
192+
d_model: int,
193+
n_head: int,
194+
hidden_dim: int,
195+
num_layers: int,
196+
drop_prob: float = 0.1,
197+
) -> None:
122198
super().__init__()
123199
self.layers = nn.ModuleList(
124200
[
@@ -127,82 +203,81 @@ def __init__(self, d_model, n_head, hidden_dim, num_layers, drop_prob=0.1):
127203
]
128204
)
129205

130-
def forward(self, x, mask=None):
206+
def forward(self, input_tensor: Tensor, mask: Tensor = None) -> Tensor:
207+
x = input_tensor
131208
for layer in self.layers:
132209
x = layer(x, mask)
133210
return x
134211

135212

136-
# attention pooling layer
137213
class AttentionPooling(nn.Module):
138-
def __init__(self, d_model):
214+
"""
215+
Attention pooling layer.
216+
217+
>>> import torch
218+
>>> pooling = AttentionPooling(8)
219+
>>> x = torch.rand(4, 10, 8)
220+
>>> pooled, weights = pooling.forward(x)
221+
>>> pooled.shape
222+
torch.Size([4, 8])
223+
>>> weights.shape
224+
torch.Size([4, 10])
225+
"""
226+
def __init__(self, d_model: int) -> None:
139227
super().__init__()
140228
self.attn_score = nn.Linear(d_model, 1)
141229

142-
def forward(self, x, mask=None):
143-
attn_weights = torch.softmax(self.attn_score(x).squeeze(-1), dim=-1)
230+
def forward(self, input_tensor: Tensor, mask: Tensor = None) -> tuple[Tensor, Tensor]:
231+
attn_weights = torch.softmax(self.attn_score(input_tensor).squeeze(-1), dim=-1)
144232

145233
if mask is not None:
146234
attn_weights = attn_weights.masked_fill(mask == 0, 0)
147235
attn_weights = attn_weights / (attn_weights.sum(dim=1, keepdim=True) + 1e-8)
148236

149-
pooled = torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1)
237+
pooled = torch.bmm(attn_weights.unsqueeze(1), input_tensor).squeeze(1)
150238
return pooled, attn_weights
151239

152240

class EEGTransformer(nn.Module):
    """
    EEG Transformer model.

    Projects input features to ``d_model``, adds a Time2Vec temporal
    embedding, encodes the sequence with a Transformer encoder stack,
    pools over time with attention, and maps the pooled vector to
    ``output_dim``. When ``task_type == "classification"`` a softmax is
    applied to the output; otherwise the raw regression output is
    returned. Returns ``(output, attention_weights)``.

    >>> import torch
    >>> model = EEGTransformer(feature_dim=8)
    >>> x = torch.rand(2, 10, 8)
    >>> out, attn_w = model.forward(x)
    >>> out.shape
    torch.Size([2, 1])
    """

    def __init__(
        self,
        feature_dim: int,
        d_model: int = 128,
        n_head: int = 8,
        hidden_dim: int = 512,
        num_layers: int = 4,
        drop_prob: float = 0.1,
        output_dim: int = 1,
        task_type: str = "regression",
    ) -> None:
        super().__init__()
        self.task_type = task_type
        self.input_proj = nn.Linear(feature_dim, d_model)
        # Time encoding for temporal understanding.
        self.time2vec = Time2Vec(d_model)
        # Transformer encoder for sequence modeling.
        self.encoder = TransformerEncoder(
            d_model, n_head, hidden_dim, num_layers, drop_prob
        )
        # Attention pooling summarizes the time dimension.
        self.pooling = AttentionPooling(d_model)
        # Final output projection.
        self.output_layer = nn.Linear(d_model, output_dim)

    # ``mask`` annotated Tensor | None instead of the implicit-Optional
    # ``Tensor = None`` (PEP 484).
    def forward(
        self, input_tensor: Tensor, mask: Tensor | None = None
    ) -> tuple[Tensor, Tensor]:
        batch, seq_len, _ = input_tensor.size()
        # Continuous time indices of shape (batch, seq_len, 1) for Time2Vec.
        time_idx = (
            torch.arange(seq_len, device=input_tensor.device)
            .view(1, seq_len, 1)
            .expand(batch, seq_len, 1)
            .float()
        )
        time_emb = self.time2vec(time_idx)
        # Add the time embedding to the feature projection.
        x = self.input_proj(input_tensor) + time_emb
        x = self.encoder(x, mask)
        # Aggregate features across time with attention.
        pooled, attn_weights = self.pooling(x, mask)
        out = self.output_layer(pooled)
        if self.task_type == "classification":
            out = torch.softmax(out, dim=-1)
        return out, attn_weights

0 commit comments

Comments
 (0)