Skip to content

Commit e33202b

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 18c156e commit e33202b

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

neural_network/real_time_encoder_transformer.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,18 @@ def __init__(self, d_model: int, n_head: int, seed: int | None = None) -> None:
108108
self.n_head = n_head
109109
self.d_k = d_model // n_head
110110
self.rng = np.random.default_rng(seed)
111-
self.w_q = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / d_model)
112-
self.w_k = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / d_model)
113-
self.w_v = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / d_model)
114-
self.w_o = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / d_model)
111+
self.w_q = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
112+
2.0 / d_model
113+
)
114+
self.w_k = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
115+
2.0 / d_model
116+
)
117+
self.w_v = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
118+
2.0 / d_model
119+
)
120+
self.w_o = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
121+
2.0 / d_model
122+
)
115123

116124
def forward(
117125
self,

0 commit comments

Comments
 (0)