
Commit 220270e

Create real_time_encoder_transformer.py
Created a real-time, encoder-only transformer model with Time2Vec as the positional encoding, along with a generalised classifier layer, for modelling real-time data like EEG.
1 parent e2a78d4 commit 220270e

File tree

1 file changed: +185 −0 lines changed

Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
# imports
import torch
import torch.nn as nn
import math

# Time2Vec layer for positional encoding of real-time data like EEG
class Time2Vec(nn.Module):
    # Encodes time steps into a continuous embedding space to help the
    # transformer learn temporal dependencies.
    def __init__(self, d_model):
        super().__init__()
        self.w0 = nn.Parameter(torch.randn(1, 1))
        self.b0 = nn.Parameter(torch.randn(1, 1))
        self.w = nn.Parameter(torch.randn(1, d_model - 1))
        self.b = nn.Parameter(torch.randn(1, d_model - 1))

    def forward(self, t):
        # t: (batch, seq_len, 1) -> (batch, seq_len, d_model)
        linear = self.w0 * t + self.b0
        periodic = torch.sin(self.w * t + self.b)
        return torch.cat([linear, periodic], dim=-1)
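
# Note: following the Time2Vec formulation (Kazemi et al., 2019), the first
# output dimension is the linear trend w0 * t + b0 and the remaining
# d_model - 1 dimensions are periodic terms sin(w * t + b), so the layer can
# represent both drift and rhythmic structure in a signal such as EEG.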

# position-wise feed-forward network
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, hidden, drop_prob=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, hidden)
        self.fc2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        return self.fc2(x)

# scaled dot-product attention
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None):
        # q, k, v: (batch, n_head, seq_len, d_k)
        _, _, _, d_k = k.size()
        scores = (q @ k.transpose(2, 3)) / math.sqrt(d_k)

        if mask is not None:
            # positions where mask == 0 are excluded from attention
            scores = scores.masked_fill(mask == 0, -1e9)

        attn = self.softmax(scores)
        context = attn @ v
        return context, attn
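
# Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V; the sqrt(d_k)
# scaling keeps the dot products in a range where softmax gradients stay
# usable as the head dimension grows.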

# multi-head attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_head):
        super().__init__()
        assert d_model % n_head == 0, "d_model must be divisible by n_head"
        self.n_head = n_head
        self.attn = ScaledDotProductAttention()
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)
        q, k, v = self.split_heads(q), self.split_heads(k), self.split_heads(v)

        context, _ = self.attn(q, k, v, mask)
        out = self.w_out(self.concat_heads(context))
        return out

    def split_heads(self, x):
        # (batch, seq_len, d_model) -> (batch, n_head, seq_len, d_k)
        batch, seq_len, d_model = x.size()
        d_k = d_model // self.n_head
        return x.view(batch, seq_len, self.n_head, d_k).transpose(1, 2)

    def concat_heads(self, x):
        # (batch, n_head, seq_len, d_k) -> (batch, seq_len, d_model)
        batch, n_head, seq_len, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch, seq_len, n_head * d_k)
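
# Note: any attention mask passed in must broadcast against the per-head
# score tensor of shape (batch, n_head, seq_len, seq_len); for example, a
# padding mask of shape (batch, 1, 1, seq_len) with 0 at padded positions.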

# layer normalization
class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, unbiased=False, keepdim=True)
        return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta
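
# Normalizes over the feature dimension:
#   y = gamma * (x - mean) / sqrt(var + eps) + beta
# with a learnable per-feature scale (gamma) and shift (beta).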

# transformer encoder layer
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, n_head, hidden_dim, drop_prob=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_head)
        self.ffn = PositionwiseFeedForward(d_model, hidden_dim, drop_prob)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x, mask=None):
        attn_out = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_out))
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))
        return x
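
# This is the post-norm arrangement from the original Transformer
# (Vaswani et al., 2017): each sublayer output is added to its input
# (residual connection) and then layer-normalized.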

# encoder stack
class TransformerEncoder(nn.Module):
    def __init__(self, d_model, n_head, hidden_dim, num_layers, drop_prob=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, n_head, hidden_dim, drop_prob)
            for _ in range(num_layers)
        ])

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return x

# attention pooling layer
class AttentionPooling(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.attn_score = nn.Linear(d_model, 1)

    def forward(self, x, mask=None):
        # x: (batch, seq_len, d_model); mask: (batch, seq_len) with 0 at padding
        attn_weights = torch.softmax(self.attn_score(x).squeeze(-1), dim=-1)

        if mask is not None:
            attn_weights = attn_weights.masked_fill(mask == 0, 0)
            attn_weights = attn_weights / (attn_weights.sum(dim=1, keepdim=True) + 1e-8)

        pooled = torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1)
        return pooled, attn_weights
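
# Note: zeroing masked weights after the softmax and renormalizing is
# equivalent to taking the softmax over the unmasked positions only, so
# padded time steps contribute nothing to the pooled vector.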

# transformer model
class EEGTransformer(nn.Module):
    def __init__(self, feature_dim, d_model=128, n_head=8, hidden_dim=512,
                 num_layers=4, drop_prob=0.1, output_dim=1, task_type='regression'):
        super().__init__()
        self.task_type = task_type
        self.input_proj = nn.Linear(feature_dim, d_model)

        # Time encoding for temporal understanding
        self.time2vec = Time2Vec(d_model)

        # Transformer encoder for sequence modeling
        self.encoder = TransformerEncoder(d_model, n_head, hidden_dim, num_layers, drop_prob)

        # Attention pooling to summarize the time dimension
        self.pooling = AttentionPooling(d_model)

        # Final output layer
        self.output_layer = nn.Linear(d_model, output_dim)

    def forward(self, x, mask=None):
        # x: (batch, time, feature_dim)
        # mask: (batch, time) padding mask with 0 at padded steps, or None
        b, t, _ = x.size()

        # Create time indices and embed them
        t_idx = torch.arange(t, device=x.device).view(1, t, 1).expand(b, t, 1).float()
        time_emb = self.time2vec(t_idx)

        # Add time embedding to the feature projection
        x = self.input_proj(x) + time_emb

        # Expand the padding mask to (batch, 1, 1, time) so it broadcasts
        # against the encoder's (batch, n_head, time, time) attention scores
        enc_mask = mask.unsqueeze(1).unsqueeze(2) if mask is not None else None
        x = self.encoder(x, enc_mask)

        # Aggregate features across time with attention
        pooled, attn_weights = self.pooling(x, mask)

        # Final output (regression or classification)
        out = self.output_layer(pooled)

        if self.task_type == 'classification':
            # Softmax over classes; if training with nn.CrossEntropyLoss,
            # feed it the raw logits instead (it applies log-softmax itself)
            out = torch.softmax(out, dim=-1)

        return out, attn_weights
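
# Minimal usage sketch: the shapes below are illustrative assumptions,
# e.g. a batch of 4 windows of 256 time steps from 32 EEG channels.
if __name__ == "__main__":
    model = EEGTransformer(feature_dim=32, d_model=128, n_head=8,
                           hidden_dim=512, num_layers=4, drop_prob=0.1,
                           output_dim=1, task_type='regression')
    x = torch.randn(4, 256, 32)  # (batch, time, channels)
    out, attn = model(x)
    print(out.shape, attn.shape)  # torch.Size([4, 1]) torch.Size([4, 256])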
