
Commit f10a2ea

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 0fc2b8e commit f10a2ea


neural_network/real_time_encoder_transformer.py

Lines changed: 95 additions & 24 deletions
@@ -20,6 +20,7 @@ def _stable_div(numerator: np.ndarray, denominator: np.ndarray) -> np.ndarray:
 # 🔹 Time2Vec
 # -------------------------------
 
+
 class Time2Vec:
     def __init__(self, d_model: int, seed: Optional[int] = None) -> None:
         if d_model < 2:
@@ -63,12 +64,23 @@ def forward(self, time_indices: np.ndarray) -> np.ndarray:
 # 🔹 Positionwise FeedForward
 # -------------------------------
 
+
 class PositionwiseFeedForward:
-    def __init__(self, d_model: int, hidden_dim: int, drop_prob: float = 0.0, seed: Optional[int] = None) -> None:
+    def __init__(
+        self,
+        d_model: int,
+        hidden_dim: int,
+        drop_prob: float = 0.0,
+        seed: Optional[int] = None,
+    ) -> None:
         self.rng = np.random.default_rng(seed)
-        self.w1: np.ndarray = self.rng.standard_normal((d_model, hidden_dim)) * math.sqrt(2.0 / (d_model + hidden_dim))
+        self.w1: np.ndarray = self.rng.standard_normal(
+            (d_model, hidden_dim)
+        ) * math.sqrt(2.0 / (d_model + hidden_dim))
         self.b1: np.ndarray = np.zeros((hidden_dim,))
-        self.w2: np.ndarray = self.rng.standard_normal((hidden_dim, d_model)) * math.sqrt(2.0 / (hidden_dim + d_model))
+        self.w2: np.ndarray = self.rng.standard_normal(
+            (hidden_dim, d_model)
+        ) * math.sqrt(2.0 / (hidden_dim + d_model))
         self.b2: np.ndarray = np.zeros((d_model,))
 
     def forward(self, input_tensor: np.ndarray) -> np.ndarray:
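
The `math.sqrt(2.0 / (d_model + hidden_dim))` factor reformatted here is Xavier/Glorot-style scaling (std = sqrt(2 / (fan_in + fan_out))). A minimal standalone sketch of the same initialization, with illustrative sizes not taken from the file:

import math
import numpy as np

rng = np.random.default_rng(0)
d_model, hidden_dim = 8, 32  # illustrative sizes
# Glorot scaling keeps the activation variance roughly constant across layers.
w1 = rng.standard_normal((d_model, hidden_dim)) * math.sqrt(2.0 / (d_model + hidden_dim))
print(round(float(w1.std()), 3))  # roughly sqrt(2 / 40) ≈ 0.224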
@@ -82,6 +94,7 @@ def forward(self, input_tensor: np.ndarray) -> np.ndarray:
 # 🔹 Scaled Dot-Product Attention
 # -------------------------------
 
+
 class ScaledDotProductAttention:
     def forward(
         self,
@@ -97,7 +110,11 @@ def forward(
             if mask.ndim == 2:
                 mask_reshaped = mask[:, None, None, :]
             elif mask.ndim == 3:
-                mask_reshaped = mask[:, None, :, :] if mask.shape[1] != seq_len else mask[:, None, None, :]
+                mask_reshaped = (
+                    mask[:, None, :, :]
+                    if mask.shape[1] != seq_len
+                    else mask[:, None, None, :]
+                )
             else:
                 mask_reshaped = mask
             scores = np.where(mask_reshaped == 0, -1e9, scores)
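
For reference, a minimal sketch of the broadcast this branch performs for a 2-D padding mask; shapes and names are illustrative, and the rest of the `forward` body is outside this hunk:

import numpy as np

batch, n_head, seq_len = 2, 4, 5
scores = np.zeros((batch, n_head, seq_len, seq_len))  # raw attention scores
mask = np.ones((batch, seq_len), dtype=int)
mask[:, -2:] = 0  # treat the last two time steps as padding

mask_reshaped = mask[:, None, None, :]  # broadcasts to (batch, 1, 1, seq_len)
scores = np.where(mask_reshaped == 0, -1e9, scores)
# After softmax over the last axis, padded positions receive ~0 attention weight.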
@@ -111,6 +128,7 @@ def forward(
 # 🔹 Multi-Head Attention
 # -------------------------------
 
+
 class MultiHeadAttention:
     def __init__(self, d_model: int, n_head: int, seed: Optional[int] = None) -> None:
         if d_model % n_head != 0:
@@ -121,27 +139,41 @@ def __init__(self, d_model: int, n_head: int, seed: Optional[int] = None) -> Non
         self.n_head = n_head
         self.d_k = d_model // n_head
 
-        self.w_q = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / (d_model + d_model))
+        self.w_q = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+            2.0 / (d_model + d_model)
+        )
         self.b_q = np.zeros((d_model,))
-        self.w_k = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / (d_model + d_model))
+        self.w_k = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+            2.0 / (d_model + d_model)
+        )
         self.b_k = np.zeros((d_model,))
-        self.w_v = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / (d_model + d_model))
+        self.w_v = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+            2.0 / (d_model + d_model)
+        )
         self.b_v = np.zeros((d_model,))
-        self.w_out = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / (d_model + d_model))
+        self.w_out = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+            2.0 / (d_model + d_model)
+        )
         self.b_out = np.zeros((d_model,))
 
         self.attn = ScaledDotProductAttention()
 
-    def _linear(self, input_tensor: np.ndarray, weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
+    def _linear(
+        self, input_tensor: np.ndarray, weight: np.ndarray, bias: np.ndarray
+    ) -> np.ndarray:
         return np.tensordot(input_tensor, weight, axes=([2], [0])) + bias
 
     def _split_heads(self, input_tensor: np.ndarray) -> np.ndarray:
         batch_size, seq_len, _ = input_tensor.shape
-        return input_tensor.reshape(batch_size, seq_len, self.n_head, self.d_k).transpose(0, 2, 1, 3)
+        return input_tensor.reshape(
+            batch_size, seq_len, self.n_head, self.d_k
+        ).transpose(0, 2, 1, 3)
 
     def _concat_heads(self, input_tensor: np.ndarray) -> np.ndarray:
         batch_size, n_head, seq_len, d_k = input_tensor.shape
-        return input_tensor.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, n_head * d_k)
+        return input_tensor.transpose(0, 2, 1, 3).reshape(
+            batch_size, seq_len, n_head * d_k
+        )
 
     def forward(
         self,
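
A standalone sketch of the split/concat round trip that `_split_heads` and `_concat_heads` implement, with illustrative sizes:

import numpy as np

batch_size, seq_len, d_model, n_head = 2, 5, 8, 4
d_k = d_model // n_head
x = np.random.default_rng(0).standard_normal((batch_size, seq_len, d_model))

# split: (batch, seq, d_model) -> (batch, n_head, seq, d_k)
heads = x.reshape(batch_size, seq_len, n_head, d_k).transpose(0, 2, 1, 3)
# concat: the inverse transform restores (batch, seq, d_model)
merged = heads.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, n_head * d_k)
assert np.array_equal(x, merged)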
@@ -174,6 +206,7 @@ def forward(
 # 🔹 LayerNorm
 # -------------------------------
 
+
 class LayerNorm:
     def __init__(self, d_model: int, eps: float = 1e-12) -> None:
         self.gamma: np.ndarray = np.ones((d_model,))
@@ -185,18 +218,25 @@ def forward(self, input_tensor: np.ndarray) -> np.ndarray:
         var = np.mean((input_tensor - mean) ** 2, axis=-1, keepdims=True)
         normalized_tensor = (input_tensor - mean) / np.sqrt(var + self.eps)
         return self.gamma * normalized_tensor + self.beta
+
+
 # -------------------------------
 # 🔹 Transformer Encoder Layer
 # -------------------------------
 
+
 class TransformerEncoderLayer:
-    def __init__(self, d_model: int, n_head: int, hidden_dim: int, seed: Optional[int] = None) -> None:
+    def __init__(
+        self, d_model: int, n_head: int, hidden_dim: int, seed: Optional[int] = None
+    ) -> None:
         self.self_attn = MultiHeadAttention(d_model, n_head, seed=seed)
         self.ffn = PositionwiseFeedForward(d_model, hidden_dim, seed=seed)
         self.norm1 = LayerNorm(d_model)
         self.norm2 = LayerNorm(d_model)
 
-    def forward(self, encoded_input: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarray:
+    def forward(
+        self, encoded_input: np.ndarray, mask: Optional[np.ndarray] = None
+    ) -> np.ndarray:
         """
         Forward pass for one encoder layer.
 
@@ -220,7 +260,9 @@ def forward(self, encoded_input: np.ndarray, mask: Optional[np.ndarray] = None)
         >>> out.shape
         (1, 3, 4)
         """
-        attn_output, _ = self.self_attn.forward(encoded_input, encoded_input, encoded_input, mask)
+        attn_output, _ = self.self_attn.forward(
+            encoded_input, encoded_input, encoded_input, mask
+        )
         out1 = self.norm1.forward(encoded_input + attn_output)
         ffn_output = self.ffn.forward(out1)
         out2 = self.norm2.forward(out1 + ffn_output)
@@ -231,11 +273,24 @@ def forward(self, encoded_input: np.ndarray, mask: Optional[np.ndarray] = None)
 # 🔹 Transformer Encoder Stack
 # -------------------------------
 
+
 class TransformerEncoder:
-    def __init__(self, d_model: int, n_head: int, hidden_dim: int, num_layers: int, seed: Optional[int] = None) -> None:
-        self.layers = [TransformerEncoderLayer(d_model, n_head, hidden_dim, seed=seed) for _ in range(num_layers)]
+    def __init__(
+        self,
+        d_model: int,
+        n_head: int,
+        hidden_dim: int,
+        num_layers: int,
+        seed: Optional[int] = None,
+    ) -> None:
+        self.layers = [
+            TransformerEncoderLayer(d_model, n_head, hidden_dim, seed=seed)
+            for _ in range(num_layers)
+        ]
 
-    def forward(self, encoded_input: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarray:
+    def forward(
+        self, encoded_input: np.ndarray, mask: Optional[np.ndarray] = None
+    ) -> np.ndarray:
         """
         Forward pass for encoder stack.
 
@@ -269,13 +324,18 @@ def forward(self, encoded_input: np.ndarray, mask: Optional[np.ndarray] = None)
 # 🔹 Attention Pooling
 # -------------------------------
 
+
 class AttentionPooling:
     def __init__(self, d_model: int, seed: Optional[int] = None) -> None:
         self.rng = np.random.default_rng(seed)
-        self.w: np.ndarray = self.rng.standard_normal(d_model) * math.sqrt(2.0 / d_model)
+        self.w: np.ndarray = self.rng.standard_normal(d_model) * math.sqrt(
+            2.0 / d_model
+        )
         self.b: float = 0.0
 
-    def forward(self, encoded_features: np.ndarray, mask: Optional[np.ndarray] = None) -> tuple[np.ndarray, np.ndarray]:
+    def forward(
+        self, encoded_features: np.ndarray, mask: Optional[np.ndarray] = None
+    ) -> tuple[np.ndarray, np.ndarray]:
         """
         Attention-based pooling.
 
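
The pooling computation itself is outside this hunk; purely as context, a generic attention-pooling step consistent with the declared parameters `w` (shape `(d_model,)`) and scalar `b` could look like the sketch below. This is an illustration, not the file's implementation:

import numpy as np

batch, seq_len, d_model = 2, 5, 8
rng = np.random.default_rng(0)
features = rng.standard_normal((batch, seq_len, d_model))
w = rng.standard_normal(d_model)
b = 0.0

scores = features @ w + b  # (batch, seq_len)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)  # softmax over time steps
pooled = (weights[..., None] * features).sum(axis=1)  # (batch, d_model)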
@@ -315,6 +375,7 @@ def forward(self, encoded_features: np.ndarray, mask: Optional[np.ndarray] = Non
 # 🔹 EEG Transformer
 # -------------------------------
 
+
 class EEGTransformer:
     def __init__(
         self,
@@ -332,20 +393,28 @@ def __init__(
         self.d_model = d_model
         self.task_type = task_type
 
-        self.w_in: np.ndarray = self.rng.standard_normal((feature_dim, d_model)) * math.sqrt(2.0 / (feature_dim + d_model))
+        self.w_in: np.ndarray = self.rng.standard_normal(
+            (feature_dim, d_model)
+        ) * math.sqrt(2.0 / (feature_dim + d_model))
         self.b_in: np.ndarray = np.zeros((d_model,))
 
         self.time2vec = Time2Vec(d_model, seed=seed)
-        self.encoder = TransformerEncoder(d_model, n_head, hidden_dim, num_layers, seed=seed)
+        self.encoder = TransformerEncoder(
+            d_model, n_head, hidden_dim, num_layers, seed=seed
+        )
         self.pooling = AttentionPooling(d_model, seed=seed)
 
-        self.w_out: np.ndarray = self.rng.standard_normal((d_model, output_dim)) * math.sqrt(2.0 / (d_model + output_dim))
+        self.w_out: np.ndarray = self.rng.standard_normal(
+            (d_model, output_dim)
+        ) * math.sqrt(2.0 / (d_model + output_dim))
         self.b_out: np.ndarray = np.zeros((output_dim,))
 
     def _input_projection(self, input_tensor: np.ndarray) -> np.ndarray:
         return np.tensordot(input_tensor, self.w_in, axes=([2], [0])) + self.b_in
 
-    def forward(self, input_tensor: np.ndarray, mask: Optional[np.ndarray] = None) -> tuple[np.ndarray, np.ndarray]:
+    def forward(
+        self, input_tensor: np.ndarray, mask: Optional[np.ndarray] = None
+    ) -> tuple[np.ndarray, np.ndarray]:
         """
         Forward pass for EEG Transformer.
 
@@ -383,7 +452,9 @@ def forward(self, input_tensor: np.ndarray, mask: Optional[np.ndarray] = None) -
         encoded_features = self.encoder.forward(projected_input, mask)
         pooled_output, attention_weights = self.pooling.forward(encoded_features, mask)
 
-        output_tensor = np.tensordot(pooled_output, self.w_out, axes=([1], [0])) + self.b_out
+        output_tensor = (
+            np.tensordot(pooled_output, self.w_out, axes=([1], [0])) + self.b_out
+        )
         if self.task_type == "classification":
             output_tensor = _softmax(output_tensor, axis=-1)
 
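
For orientation, a hedged usage sketch of the model touched by this commit. The constructor's exact parameter order is not visible in the diff, so keyword arguments use the names that appear in the hunks above (feature_dim, d_model, n_head, hidden_dim, num_layers, output_dim, task_type, seed); the input shape (batch, seq_len, feature_dim) and the (output, attention_weights) return pair are inferred from `_input_projection` and `forward`, and the import path assumes the repository root is on PYTHONPATH:

import numpy as np

from neural_network.real_time_encoder_transformer import EEGTransformer

model = EEGTransformer(
    feature_dim=16,
    d_model=32,
    n_head=4,
    hidden_dim=64,
    num_layers=2,
    output_dim=3,
    task_type="classification",
    seed=0,
)
eeg_windows = np.random.default_rng(0).standard_normal((2, 50, 16))
probs, attention_weights = model.forward(eeg_windows)
print(probs.shape)  # expected: (2, 3); rows sum to 1 after the softmax head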
