
Commit 53eff3c

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 101e305 commit 53eff3c

File tree

1 file changed: +93 -21 lines


neural_network/real_time_encoder_transformer.py

Lines changed: 93 additions & 21 deletions
@@ -1,11 +1,11 @@
-
 from __future__ import annotations
 import math
 from typing import Optional, Tuple
 
 import numpy as np
 import pandas as pd
 
+
 def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
     x_max = np.max(x, axis=axis, keepdims=True)
     e = np.exp(x - x_max)
@@ -18,6 +18,7 @@ def _stable_div(x: np.ndarray, denom: np.ndarray) -> np.ndarray:
 
 # Time2Vec
 
+
 class Time2Vec:
     """
     Time2Vec positional encoding (simple) for real-valued time steps.
@@ -51,8 +52,15 @@ def forward(self, time_steps: np.ndarray) -> np.ndarray:
 
 # PositionwiseFeedForward
 
+
 class PositionwiseFeedForward:
-    def __init__(self, d_model: int, hidden: int, drop_prob: float = 0.0, seed: Optional[int] = None):
+    def __init__(
+        self,
+        d_model: int,
+        hidden: int,
+        drop_prob: float = 0.0,
+        seed: Optional[int] = None,
+    ):
         if seed is not None:
             np.random.seed(seed)
         # simple linear layers (no dropout during forward-only inference, but kept shape)
@@ -70,11 +78,17 @@ def forward(self, x: np.ndarray) -> np.ndarray:
         return out
 
 
-
 # Scaled Dot-Product Attention
 
+
 class ScaledDotProductAttention:
-    def forward(self, q: np.ndarray, k: np.ndarray, v: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
+    def forward(
+        self,
+        q: np.ndarray,
+        k: np.ndarray,
+        v: np.ndarray,
+        mask: Optional[np.ndarray] = None,
+    ) -> Tuple[np.ndarray, np.ndarray]:
         """
         q,k,v: shapes (b, n_head, seq_len, d_k)
         mask: optional boolean or 0/1 mask of shape (b, seq_len) or (b, 1, 1, seq_len)
@@ -90,7 +104,11 @@ def forward(self, q: np.ndarray, k: np.ndarray, v: np.ndarray, mask: Optional[np
                 mask2 = mask[:, None, None, :]  # (b,1,1,seq_len)
             elif mask.ndim == 3:
                 # if provided as (b, n_head, seq_len) or (b, 1, seq_len)
-                mask2 = mask[:, None, :, :] if mask.shape[1] != seq_len else mask[:, None, None, :]
+                mask2 = (
+                    mask[:, None, :, :]
+                    if mask.shape[1] != seq_len
+                    else mask[:, None, None, :]
+                )
             else:
                 mask2 = mask
             # mask2==0 => masked
@@ -103,6 +121,7 @@ def forward(self, q: np.ndarray, k: np.ndarray, v: np.ndarray, mask: Optional[np
 
 # MultiHeadAttention
 
+
 class MultiHeadAttention:
     def __init__(self, d_model: int, n_head: int, seed: Optional[int] = None):
         if d_model % n_head != 0:
@@ -114,13 +133,21 @@ def __init__(self, d_model: int, n_head: int, seed: Optional[int] = None):
         self.d_k = d_model // n_head
 
         # weight matrices for q,k,v and output
-        self.w_q = np.random.randn(d_model, d_model) * math.sqrt(2.0 / (d_model + d_model))
+        self.w_q = np.random.randn(d_model, d_model) * math.sqrt(
+            2.0 / (d_model + d_model)
+        )
         self.b_q = np.zeros((d_model,))
-        self.w_k = np.random.randn(d_model, d_model) * math.sqrt(2.0 / (d_model + d_model))
+        self.w_k = np.random.randn(d_model, d_model) * math.sqrt(
+            2.0 / (d_model + d_model)
+        )
         self.b_k = np.zeros((d_model,))
-        self.w_v = np.random.randn(d_model, d_model) * math.sqrt(2.0 / (d_model + d_model))
+        self.w_v = np.random.randn(d_model, d_model) * math.sqrt(
+            2.0 / (d_model + d_model)
+        )
         self.b_v = np.zeros((d_model,))
-        self.w_out = np.random.randn(d_model, d_model) * math.sqrt(2.0 / (d_model + d_model))
+        self.w_out = np.random.randn(d_model, d_model) * math.sqrt(
+            2.0 / (d_model + d_model)
+        )
         self.b_out = np.zeros((d_model,))
 
         self.attn = ScaledDotProductAttention()
@@ -139,7 +166,13 @@ def _concat_heads(self, x: np.ndarray) -> np.ndarray:
         b, n_head, seq_len, d_k = x.shape
         return x.transpose(0, 2, 1, 3).reshape(b, seq_len, n_head * d_k)
 
-    def forward(self, query: np.ndarray, key: np.ndarray, value: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
+    def forward(
+        self,
+        query: np.ndarray,
+        key: np.ndarray,
+        value: np.ndarray,
+        mask: Optional[np.ndarray] = None,
+    ) -> Tuple[np.ndarray, np.ndarray]:
         """
         query/key/value: (b, seq_len, d_model)
         returns: out (b, seq_len, d_model), attn_weights (b, n_head, seq_len, seq_len)
@@ -157,9 +190,9 @@ def forward(self, query: np.ndarray, key: np.ndarray, value: np.ndarray, mask: O
         return out, attn
 
 
-
 # LayerNorm
 
+
 class LayerNorm:
     def __init__(self, d_model: int, eps: float = 1e-12):
         self.gamma = np.ones((d_model,))
@@ -173,10 +206,14 @@ def forward(self, x: np.ndarray) -> np.ndarray:
         x_norm = (x - mean) / np.sqrt(var + self.eps)
         return self.gamma * x_norm + self.beta
 
+
 # TransformerEncoderLayer
 
+
 class TransformerEncoderLayer:
-    def __init__(self, d_model: int, n_head: int, hidden_dim: int, seed: Optional[int] = None):
+    def __init__(
+        self, d_model: int, n_head: int, hidden_dim: int, seed: Optional[int] = None
+    ):
         self.self_attn = MultiHeadAttention(d_model, n_head, seed=seed)
         self.ffn = PositionwiseFeedForward(d_model, hidden_dim, seed=seed)
         self.norm1 = LayerNorm(d_model)
@@ -193,26 +230,41 @@ def forward(self, x: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarra
 
 # TransformerEncoder (stack)
 
+
 class TransformerEncoder:
-    def __init__(self, d_model: int, n_head: int, hidden_dim: int, num_layers: int, seed: Optional[int] = None):
-        self.layers = [TransformerEncoderLayer(d_model, n_head, hidden_dim, seed=seed) for _ in range(num_layers)]
+    def __init__(
+        self,
+        d_model: int,
+        n_head: int,
+        hidden_dim: int,
+        num_layers: int,
+        seed: Optional[int] = None,
+    ):
+        self.layers = [
+            TransformerEncoderLayer(d_model, n_head, hidden_dim, seed=seed)
+            for _ in range(num_layers)
+        ]
 
     def forward(self, x: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarray:
         out = x
         for layer in self.layers:
            out = layer.forward(out, mask)
        return out
 
+
 # AttentionPooling
 
+
 class AttentionPooling:
     def __init__(self, d_model: int, seed: Optional[int] = None):
         if seed is not None:
             np.random.seed(seed)
         self.w = np.random.randn(d_model) * math.sqrt(2.0 / d_model)
         self.b = 0.0
 
-    def forward(self, x: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
+    def forward(
+        self, x: np.ndarray, mask: Optional[np.ndarray] = None
+    ) -> Tuple[np.ndarray, np.ndarray]:
         """
         x: (b, seq_len, d_model)
         mask: (b, seq_len) where 1 = valid, 0 = pad
@@ -228,8 +280,10 @@ def forward(self, x: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.
         pooled = np.matmul(weights[:, None, :], x).squeeze(1)  # (b, d_model)
         return pooled, weights
 
+
 # EEGTransformer (forward-only)
 
+
 class EEGTransformer:
     def __init__(
         self,
@@ -248,21 +302,29 @@ def __init__(
         self.d_model = d_model
         self.task_type = task_type
         # input projection
-        self.w_in = np.random.randn(feature_dim, d_model) * math.sqrt(2.0 / (feature_dim + d_model))
+        self.w_in = np.random.randn(feature_dim, d_model) * math.sqrt(
+            2.0 / (feature_dim + d_model)
+        )
         self.b_in = np.zeros((d_model,))
         # time embedding
         self.time2vec = Time2Vec(d_model, seed=seed)
-        self.encoder = TransformerEncoder(d_model, n_head, hidden_dim, num_layers, seed=seed)
+        self.encoder = TransformerEncoder(
+            d_model, n_head, hidden_dim, num_layers, seed=seed
+        )
         self.pooling = AttentionPooling(d_model, seed=seed)
         # output
-        self.w_out = np.random.randn(d_model, output_dim) * math.sqrt(2.0 / (d_model + output_dim))
+        self.w_out = np.random.randn(d_model, output_dim) * math.sqrt(
+            2.0 / (d_model + output_dim)
+        )
         self.b_out = np.zeros((output_dim,))
 
     def _input_proj(self, x: np.ndarray) -> np.ndarray:
         # x: (b, seq_len, feature_dim) -> (b, seq_len, d_model)
         return np.tensordot(x, self.w_in, axes=([2], [0])) + self.b_in
 
-    def forward(self, x: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
+    def forward(
+        self, x: np.ndarray, mask: Optional[np.ndarray] = None
+    ) -> Tuple[np.ndarray, np.ndarray]:
         """
         x: (b, seq_len, feature_dim)
         mask: optional (b, seq_len) 1=valid,0=pad
@@ -276,7 +338,9 @@ def forward(self, x: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.
         x_proj = self._input_proj(x) + time_emb  # broadcast add -> (b,t,d_model)
         enc = self.encoder.forward(x_proj, mask)
         pooled, attn_weights = self.pooling.forward(enc, mask)
-        out = np.tensordot(pooled, self.w_out, axes=([1], [0])) + self.b_out  # (b,output_dim)
+        out = (
+            np.tensordot(pooled, self.w_out, axes=([1], [0])) + self.b_out
+        )  # (b,output_dim)
         if self.task_type == "classification":
             out = _softmax(out, axis=-1)
         return out, attn_weights
@@ -292,7 +356,15 @@ def forward(self, x: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.
     rng = np.random.RandomState(42)
     X = rng.randn(batch, seq_len, feature_dim).astype(float)
 
-    model = EEGTransformer(feature_dim=feature_dim, d_model=32, n_head=4, hidden_dim=64, num_layers=2, output_dim=1, seed=0)
+    model = EEGTransformer(
+        feature_dim=feature_dim,
+        d_model=32,
+        n_head=4,
+        hidden_dim=64,
+        num_layers=2,
+        output_dim=1,
+        seed=0,
+    )
     out, attn_weights = model.forward(X)
     print("Output shape:", out.shape)
     print("Output:", out)
