Commit 74714aa

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent: e6e2092

1 file changed: neural_network/real_time_encoder_transformer.py (35 additions, 17 deletions)
@@ -74,11 +74,13 @@ def forward(self, input_tensor: np.ndarray) -> np.ndarray:
 class PositionwiseFeedForward:
     def __init__(self, d_model: int, hidden: int, seed: int | None = None) -> None:
         self.rng = np.random.default_rng(seed)
-        self.linear1_w = self.rng.standard_normal((d_model, hidden)) * \
-            math.sqrt(2.0 / (d_model + hidden))
+        self.linear1_w = self.rng.standard_normal((d_model, hidden)) * math.sqrt(
+            2.0 / (d_model + hidden)
+        )
         self.linear1_b = np.zeros((hidden,))
-        self.linear2_w = self.rng.standard_normal((hidden, d_model)) * \
-            math.sqrt(2.0 / (hidden + d_model))
+        self.linear2_w = self.rng.standard_normal((hidden, d_model)) * math.sqrt(
+            2.0 / (hidden + d_model)
+        )
         self.linear2_b = np.zeros((d_model,))

     def forward(self, x_tensor: np.ndarray) -> np.ndarray:
@@ -89,7 +91,9 @@ def forward(self, x_tensor: np.ndarray) -> np.ndarray:
         >>> out.shape
         (1, 3, 4)
         """
-        hidden = np.tensordot(x_tensor, self.linear1_w, axes=([2], [0])) + self.linear1_b
+        hidden = (
+            np.tensordot(x_tensor, self.linear1_w, axes=([2], [0])) + self.linear1_b
+        )
         hidden = np.maximum(0, hidden)  # ReLU
         out = np.tensordot(hidden, self.linear2_w, axes=([2], [0])) + self.linear2_b
         return out
@@ -106,10 +110,18 @@ def __init__(self, d_model: int, n_head: int, seed: int | None = None) -> None:
         self.n_head = n_head
         self.d_k = d_model // n_head
         self.rng = np.random.default_rng(seed)
-        self.w_q = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / d_model)
-        self.w_k = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / d_model)
-        self.w_v = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / d_model)
-        self.w_o = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / d_model)
+        self.w_q = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+            2.0 / d_model
+        )
+        self.w_k = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+            2.0 / d_model
+        )
+        self.w_v = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+            2.0 / d_model
+        )
+        self.w_o = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+            2.0 / d_model
+        )

     def forward(
         self,
@@ -148,13 +160,17 @@ def forward(
 # 🔹 TransformerEncoderLayer
 # -------------------------------
 class TransformerEncoderLayer:
-    def __init__(self, d_model: int, n_head: int, hidden_dim: int, seed: int | None = None) -> None:
+    def __init__(
+        self, d_model: int, n_head: int, hidden_dim: int, seed: int | None = None
+    ) -> None:
         self.self_attn = MultiHeadAttention(d_model, n_head, seed)
         self.norm1 = LayerNorm(d_model)
         self.ff = PositionwiseFeedForward(d_model, hidden_dim, seed)
         self.norm2 = LayerNorm(d_model)

-    def forward(self, x_tensor: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
+    def forward(
+        self, x_tensor: np.ndarray, mask: np.ndarray | None = None
+    ) -> np.ndarray:
         """
         >>> layer = TransformerEncoderLayer(4, 2, 8, seed=0)
         >>> x = np.ones((1, 3, 4))
@@ -179,14 +195,16 @@ def __init__(
         n_head: int,
         hidden_dim: int,
         num_layers: int,
-        seed: int | None = None
+        seed: int | None = None,
     ) -> None:
         self.layers = [
             TransformerEncoderLayer(d_model, n_head, hidden_dim, seed)
             for _ in range(num_layers)
         ]

-    def forward(self, x_tensor: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
+    def forward(
+        self, x_tensor: np.ndarray, mask: np.ndarray | None = None
+    ) -> np.ndarray:
         """
         >>> encoder = TransformerEncoder(4, 2, 8, 2, seed=0)
         >>> x = np.ones((1, 3, 4))
@@ -231,17 +249,17 @@ def __init__(
         num_layers: int,
         output_dim: int = 1,
         task_type: str = "regression",
-        seed: int | None = None
+        seed: int | None = None,
     ) -> None:
         self.time2vec = Time2Vec(d_model, seed)
         self.encoder = TransformerEncoder(d_model, n_head, hidden_dim, num_layers, seed)
         self.pooling = AttentionPooling(d_model, seed)
         self.output_dim = output_dim
         self.task_type = task_type
         self.rng = np.random.default_rng(seed)
-        self.w_out = self.rng.standard_normal(
-            (d_model, output_dim)
-        ) * math.sqrt(2.0 / (d_model + output_dim))
+        self.w_out = self.rng.standard_normal((d_model, output_dim)) * math.sqrt(
+            2.0 / (d_model + output_dim)
+        )
         self.b_out = np.zeros((output_dim,))

     def forward(self, eeg_data: np.ndarray) -> np.ndarray:
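The rewrap is purely cosmetic: black turns the backslash continuations into parenthesized calls without touching the arithmetic, so every weight keeps its Glorot-style scale sqrt(2 / (fan_in + fan_out)). A minimal standalone sketch (toy dimensions 4 and 8 are illustrative, not from the commit) confirming the old and new initializer forms produce identical arrays:

import math

import numpy as np

# Old form: backslash line continuation, as deleted in this commit.
rng_old = np.random.default_rng(0)
old = rng_old.standard_normal((4, 8)) * \
    math.sqrt(2.0 / (4 + 8))

# New form: parenthesized wrap, as added in this commit.
rng_new = np.random.default_rng(0)
new = rng_new.standard_normal((4, 8)) * math.sqrt(
    2.0 / (4 + 8)
)

# Same seed, same draws, same scale: the reformat changes nothing numerically.
assert np.array_equal(old, new)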

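The doctests visible in the diff already exercise the reformatted classes; the sketch below just collects them into a runnable script, assuming it is executed from the repository root so that neural_network.real_time_encoder_transformer is importable:

import numpy as np

from neural_network.real_time_encoder_transformer import (
    PositionwiseFeedForward,
    TransformerEncoder,
)

# Shapes follow the doctests: (batch=1, seq_len=3, d_model=4).
x = np.ones((1, 3, 4))

ff = PositionwiseFeedForward(d_model=4, hidden=8, seed=0)
print(ff.forward(x).shape)  # (1, 3, 4), per the docstring

# Positional args as in the doctest: d_model, n_head, hidden_dim, num_layers.
encoder = TransformerEncoder(4, 2, 8, 2, seed=0)
print(encoder.forward(x).shape)  # expected (1, 3, 4): each layer is shape-preserving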