
Commit 986cd98

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 5f20061 commit 986cd98

File tree

1 file changed, +51 -15 lines changed

neural_network/real_time_encoder_transformer.py

Lines changed: 51 additions & 15 deletions
@@ -46,11 +46,17 @@ def forward(self, time_steps: np.ndarray) -> np.ndarray:
 # Positionwise FeedForward
 # --------------------------------------------------
 class PositionwiseFeedForward:
-    def __init__(self, d_model: int, hidden: int, drop_prob: float = 0.0, seed: int | None = None):
+    def __init__(
+        self, d_model: int, hidden: int, drop_prob: float = 0.0, seed: int | None = None
+    ):
         self.rng = np.random.default_rng(seed)
-        self.w1 = self.rng.standard_normal((d_model, hidden)) * math.sqrt(2.0 / (d_model + hidden))
+        self.w1 = self.rng.standard_normal((d_model, hidden)) * math.sqrt(
+            2.0 / (d_model + hidden)
+        )
         self.b1 = np.zeros(hidden)
-        self.w2 = self.rng.standard_normal((hidden, d_model)) * math.sqrt(2.0 / (hidden + d_model))
+        self.w2 = self.rng.standard_normal((hidden, d_model)) * math.sqrt(
+            2.0 / (hidden + d_model)
+        )
         self.b2 = np.zeros(d_model)
 
     def forward(self, x: np.ndarray) -> np.ndarray:
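
All of the rewrapped initializers in this file apply the same Glorot-style scaling, std = sqrt(2 / (fan_in + fan_out)); the commit only changes where the lines break. A minimal standalone sketch of that pattern follows; the glorot_normal helper name and the 64/256 dimensions are illustrative, not part of the module.

import math

import numpy as np


def glorot_normal(rng: np.random.Generator, fan_in: int, fan_out: int) -> np.ndarray:
    # Same scaling used for w1/w2 above: std = sqrt(2 / (fan_in + fan_out)).
    return rng.standard_normal((fan_in, fan_out)) * math.sqrt(2.0 / (fan_in + fan_out))


rng = np.random.default_rng(42)
w1 = glorot_normal(rng, 64, 256)  # analogue of self.w1 (d_model -> hidden)
w2 = glorot_normal(rng, 256, 64)  # analogue of self.w2 (hidden -> d_model)
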
@@ -95,13 +101,21 @@ def __init__(self, d_model: int, n_head: int, seed: int | None = None):
         self.d_k = d_model // n_head
         self.rng = np.random.default_rng(seed)
 
-        self.w_q = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / (2 * d_model))
+        self.w_q = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+            2.0 / (2 * d_model)
+        )
         self.b_q = np.zeros(d_model)
-        self.w_k = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / (2 * d_model))
+        self.w_k = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+            2.0 / (2 * d_model)
+        )
         self.b_k = np.zeros(d_model)
-        self.w_v = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / (2 * d_model))
+        self.w_v = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+            2.0 / (2 * d_model)
+        )
         self.b_v = np.zeros(d_model)
-        self.w_out = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / (2 * d_model))
+        self.w_out = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+            2.0 / (2 * d_model)
+        )
         self.b_out = np.zeros(d_model)
 
         self.attn = ScaledDotProductAttention()
@@ -154,7 +168,9 @@ def forward(self, x: np.ndarray) -> np.ndarray:
 # Transformer Encoder Layer
 # --------------------------------------------------
 class TransformerEncoderLayer:
-    def __init__(self, d_model: int, n_head: int, hidden_dim: int, seed: int | None = None):
+    def __init__(
+        self, d_model: int, n_head: int, hidden_dim: int, seed: int | None = None
+    ):
         self.self_attn = MultiHeadAttention(d_model, n_head, seed=seed)
         self.ffn = PositionwiseFeedForward(d_model, hidden_dim, seed=seed)
         self.norm1 = LayerNorm(d_model)
@@ -171,8 +187,18 @@ def forward(self, x: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
 # Transformer Encoder Stack
 # --------------------------------------------------
 class TransformerEncoder:
-    def __init__(self, d_model: int, n_head: int, hidden_dim: int, num_layers: int, seed: int | None = None):
-        self.layers = [TransformerEncoderLayer(d_model, n_head, hidden_dim, seed=seed) for _ in range(num_layers)]
+    def __init__(
+        self,
+        d_model: int,
+        n_head: int,
+        hidden_dim: int,
+        num_layers: int,
+        seed: int | None = None,
+    ):
+        self.layers = [
+            TransformerEncoderLayer(d_model, n_head, hidden_dim, seed=seed)
+            for _ in range(num_layers)
+        ]
 
     def forward(self, x: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
         out = x
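
The hunk above only shows the stack construction and the first line of forward; the loop over the layers is not part of the diff. A hedged sketch of the pattern the visible lines imply, where DummyLayer is a stand-in for the real TransformerEncoderLayer defined elsewhere in the file and the loop is an assumed continuation of forward:

import numpy as np


class DummyLayer:
    # Stand-in for TransformerEncoderLayer: identity transform with the same call shape.
    def forward(self, x: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
        return x


num_layers = 4
layers = [DummyLayer() for _ in range(num_layers)]  # mirrors self.layers above

x = np.zeros((2, 10, 64))  # (batch, time, d_model), illustrative sizes
out = x
for layer in layers:  # assumed continuation of the forward loop shown above
    out = layer.forward(out, mask=None)
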
@@ -190,7 +216,9 @@ def __init__(self, d_model: int, seed: int | None = None):
         self.w = self.rng.standard_normal(d_model) * math.sqrt(2.0 / d_model)
         self.b = 0.0
 
-    def forward(self, x: np.ndarray, mask: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]:
+    def forward(
+        self, x: np.ndarray, mask: np.ndarray | None = None
+    ) -> tuple[np.ndarray, np.ndarray]:
         scores = np.tensordot(x, self.w, axes=([2], [0])) + self.b
         if mask is not None:
             scores = np.where(mask == 0, -1e9, scores)
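
The pooling hunk above ends at the masking step; the softmax over time and the weighted sum that presumably follow are not part of the diff. A small sketch of masked attention pooling as the visible lines suggest, with the softmax/weighted-sum step marked as an assumption and all dimensions illustrative:

import numpy as np

rng = np.random.default_rng(0)
b, t, d_model = 2, 5, 8
x = rng.standard_normal((b, t, d_model))
w = rng.standard_normal(d_model) * np.sqrt(2.0 / d_model)
mask = np.ones((b, t))
mask[:, -2:] = 0  # pretend the last two steps are padding

scores = np.tensordot(x, w, axes=([2], [0]))  # (b, t), as in the diff (bias omitted)
scores = np.where(mask == 0, -1e9, scores)    # masked positions get near-zero weight

# Assumed continuation: softmax over the time axis, then a weighted sum of the states.
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights = weights / weights.sum(axis=1, keepdims=True)
pooled = np.einsum("bt,btd->bd", weights, x)  # (b, d_model)
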
@@ -219,18 +247,26 @@ def __init__(
         self.d_model = d_model
         self.task_type = task_type
 
-        self.w_in = self.rng.standard_normal((feature_dim, d_model)) * math.sqrt(2.0 / (feature_dim + d_model))
+        self.w_in = self.rng.standard_normal((feature_dim, d_model)) * math.sqrt(
+            2.0 / (feature_dim + d_model)
+        )
         self.b_in = np.zeros(d_model)
         self.time2vec = Time2Vec(d_model, seed=seed)
-        self.encoder = TransformerEncoder(d_model, n_head, hidden_dim, num_layers, seed=seed)
+        self.encoder = TransformerEncoder(
+            d_model, n_head, hidden_dim, num_layers, seed=seed
+        )
         self.pooling = AttentionPooling(d_model, seed=seed)
-        self.w_out = self.rng.standard_normal((d_model, output_dim)) * math.sqrt(2.0 / (d_model + output_dim))
+        self.w_out = self.rng.standard_normal((d_model, output_dim)) * math.sqrt(
+            2.0 / (d_model + output_dim)
+        )
         self.b_out = np.zeros(output_dim)
 
     def _input_proj(self, x: np.ndarray) -> np.ndarray:
         return np.tensordot(x, self.w_in, axes=([2], [0])) + self.b_in
 
-    def forward(self, x: np.ndarray, mask: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]:
+    def forward(
+        self, x: np.ndarray, mask: np.ndarray | None = None
+    ) -> tuple[np.ndarray, np.ndarray]:
         b, t, _ = x.shape
         t_idx = np.arange(t, dtype=float)[None, :, None]
         t_idx = np.tile(t_idx, (b, 1, 1))
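
The last hunk rewraps the model's projections and the forward signature; the visible forward body only builds the per-step time index that presumably feeds Time2Vec. A shape-only sketch of that index construction and the _input_proj tensordot, using illustrative dimensions:

import numpy as np

rng = np.random.default_rng(0)
b, t, feature_dim, d_model = 2, 6, 3, 8
x = rng.standard_normal((b, t, feature_dim))

# Input projection, as in _input_proj: contract the feature axis against w_in.
w_in = rng.standard_normal((feature_dim, d_model)) * np.sqrt(2.0 / (feature_dim + d_model))
b_in = np.zeros(d_model)
proj = np.tensordot(x, w_in, axes=([2], [0])) + b_in  # (b, t, d_model)

# Time index built as in forward(): one float step index per position, tiled per batch item.
t_idx = np.arange(t, dtype=float)[None, :, None]  # (1, t, 1)
t_idx = np.tile(t_idx, (b, 1, 1))                 # (b, t, 1)

print(proj.shape, t_idx.shape)  # (2, 6, 8) (2, 6, 1)
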
