Skip to content

Commit 33cf40a

Browse files
authored
Update real_time_encoder_transformer.py
1 parent e33202b commit 33cf40a

File tree

1 file changed

+4
-12
lines changed

1 file changed

+4
-12
lines changed

neural_network/real_time_encoder_transformer.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -108,18 +108,10 @@ def __init__(self, d_model: int, n_head: int, seed: int | None = None) -> None:
108108
self.n_head = n_head
109109
self.d_k = d_model // n_head
110110
self.rng = np.random.default_rng(seed)
111-
self.w_q = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
112-
2.0 / d_model
113-
)
114-
self.w_k = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
115-
2.0 / d_model
116-
)
117-
self.w_v = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
118-
2.0 / d_model
119-
)
120-
self.w_o = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
121-
2.0 / d_model
122-
)
111+
self.w_q = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / d_model)
112+
self.w_k = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / d_model)
113+
self.w_v = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / d_model)
114+
self.w_o = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / d_model)
123115

124116
def forward(
125117
self,

0 commit comments

Comments
 (0)