Skip to content

Commit 18c156e

Browse files
authored
Update real_time_encoder_transformer.py
1 parent 74714aa commit 18c156e

File tree

1 file changed

+16
-23
lines changed

1 file changed

+16
-23
lines changed

neural_network/real_time_encoder_transformer.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
# -------------------------------
33
from __future__ import annotations
44
import math
5-
from typing import Optional, Tuple
6-
75
import numpy as np
86

97

@@ -110,26 +108,18 @@ def __init__(self, d_model: int, n_head: int, seed: int | None = None) -> None:
110108
self.n_head = n_head
111109
self.d_k = d_model // n_head
112110
self.rng = np.random.default_rng(seed)
113-
self.w_q = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
114-
2.0 / d_model
115-
)
116-
self.w_k = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
117-
2.0 / d_model
118-
)
119-
self.w_v = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
120-
2.0 / d_model
121-
)
122-
self.w_o = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
123-
2.0 / d_model
124-
)
111+
self.w_q = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / d_model)
112+
self.w_k = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / d_model)
113+
self.w_v = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / d_model)
114+
self.w_o = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / d_model)
125115

126116
def forward(
127117
self,
128118
query: np.ndarray,
129119
key: np.ndarray,
130120
value: np.ndarray,
131121
mask: np.ndarray | None = None,
132-
) -> Tuple[np.ndarray, np.ndarray]:
122+
) -> tuple[np.ndarray, np.ndarray]:
133123
"""
134124
>>> attn = MultiHeadAttention(4, 2, seed=0)
135125
>>> x = np.ones((1, 3, 4))
@@ -140,17 +130,20 @@ def forward(
140130
(1, 2, 3, 3)
141131
"""
142132
batch_size, _seq_len, _ = query.shape
143-
Q = np.tensordot(query, self.w_q, axes=([2], [0]))
144-
K = np.tensordot(key, self.w_k, axes=([2], [0]))
145-
V = np.tensordot(value, self.w_v, axes=([2], [0]))
146-
Q = Q.reshape(batch_size, -1, self.n_head, self.d_k).transpose(0, 2, 1, 3)
147-
K = K.reshape(batch_size, -1, self.n_head, self.d_k).transpose(0, 2, 1, 3)
148-
V = V.reshape(batch_size, -1, self.n_head, self.d_k).transpose(0, 2, 1, 3)
149-
scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) / math.sqrt(self.d_k)
133+
q = np.tensordot(query, self.w_q, axes=([2], [0]))
134+
k = np.tensordot(key, self.w_k, axes=([2], [0]))
135+
v = np.tensordot(value, self.w_v, axes=([2], [0]))
136+
137+
q = q.reshape(batch_size, -1, self.n_head, self.d_k).transpose(0, 2, 1, 3)
138+
k = k.reshape(batch_size, -1, self.n_head, self.d_k).transpose(0, 2, 1, 3)
139+
v = v.reshape(batch_size, -1, self.n_head, self.d_k).transpose(0, 2, 1, 3)
140+
141+
scores = np.matmul(q, k.transpose(0, 1, 3, 2)) / math.sqrt(self.d_k)
150142
if mask is not None:
151143
scores = np.where(mask[:, None, None, :] == 0, -1e9, scores)
144+
152145
attn_weights = _softmax(scores, axis=-1)
153-
out = np.matmul(attn_weights, V)
146+
out = np.matmul(attn_weights, v)
154147
out = out.transpose(0, 2, 1, 3).reshape(batch_size, -1, self.d_model)
155148
out = np.tensordot(out, self.w_o, axes=([2], [0]))
156149
return out, attn_weights

0 commit comments

Comments
 (0)