
Commit d3a8f47

Update real_time_encoder_transformer.py
1 parent 0974fee commit d3a8f47

1 file changed: +40 -45 lines changed

neural_network/real_time_encoder_transformer.py

Lines changed: 40 additions & 45 deletions
@@ -1,4 +1,3 @@
-# imports
 import math

 import torch
@@ -16,7 +15,6 @@ class Time2Vec(nn.Module):
     >>> output.shape
     torch.Size([1, 3, 4])
     """
-
     def __init__(self, d_model: int) -> None:
         super().__init__()
         self.w0 = nn.Parameter(torch.randn(1, 1))
@@ -41,7 +39,6 @@ class PositionwiseFeedForward(nn.Module):
     >>> out.shape
     torch.Size([4, 10, 8])
     """
-
     def __init__(self, d_model: int, hidden: int, drop_prob: float = 0.1) -> None:
         super().__init__()
         self.fc1 = nn.Linear(d_model, hidden)
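A quick sanity sketch of the feed-forward block; the hidden width of 32 and the plain tensor-in, tensor-out call are assumptions, but the shape-preserving behaviour matches the doctest above:

# Illustrative only: assumes the classes from this file are in scope; hidden=32 is arbitrary.
import torch
ffn = PositionwiseFeedForward(d_model=8, hidden=32)
out = ffn(torch.rand(4, 10, 8))
assert out.shape == (4, 10, 8)  # output keeps [batch, seq_len, d_model]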
@@ -62,29 +59,32 @@ class ScaleDotProductAttention(nn.Module):

     >>> import torch
     >>> attn = ScaleDotProductAttention()
-    >>> q = torch.rand(2, 8, 10, 16)
-    >>> k = torch.rand(2, 8, 10, 16)
-    >>> v = torch.rand(2, 8, 10, 16)
-    >>> ctx, attn_w = attn.forward(q, k, v)
+    >>> query_tensor = torch.rand(2, 8, 10, 16)
+    >>> key_tensor = torch.rand(2, 8, 10, 16)
+    >>> value_tensor = torch.rand(2, 8, 10, 16)
+    >>> ctx, attn_w = attn.forward(query_tensor, key_tensor, value_tensor)
     >>> ctx.shape
     torch.Size([2, 8, 10, 16])
     """
-
     def __init__(self) -> None:
         super().__init__()
         self.softmax = nn.Softmax(dim=-1)

     def forward(
-        self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor = None
+        self,
+        query_tensor: Tensor,
+        key_tensor: Tensor,
+        value_tensor: Tensor,
+        mask: Tensor = None,
     ) -> tuple[Tensor, Tensor]:
-        _, _, _, d_k = k.size()
-        scores = (q @ k.transpose(2, 3)) / math.sqrt(d_k)
+        _, _, _, d_k = key_tensor.size()
+        scores = (query_tensor @ key_tensor.transpose(2, 3)) / math.sqrt(d_k)

         if mask is not None:
             scores = scores.masked_fill(mask == 0, -1e9)

         attn = self.softmax(scores)
-        context = attn @ v
+        context = attn @ value_tensor
         return context, attn


@@ -94,12 +94,11 @@ class MultiHeadAttention(nn.Module):

     >>> import torch
     >>> attn = MultiHeadAttention(16, 4)
-    >>> q = torch.rand(2, 10, 16)
-    >>> out = attn.forward(q, q, q)
+    >>> query_tensor = torch.rand(2, 10, 16)
+    >>> out = attn.forward(query_tensor, query_tensor, query_tensor)
     >>> out.shape
     torch.Size([2, 10, 16])
     """
-
     def __init__(self, d_model: int, n_head: int) -> None:
         super().__init__()
         self.n_head = n_head
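A minimal sketch of the renamed call with the optional mask argument; the [batch, 1, 1, seq_len] padding-mask shape is an assumption, since the doctest above never passes a mask:

# Illustrative only: positions where mask == 0 are filled with -1e9 before the softmax.
import torch
attn = MultiHeadAttention(16, 4)
query_tensor = torch.rand(2, 10, 16)
mask = torch.ones(2, 1, 1, 10)
mask[:, :, :, -2:] = 0  # treat the last two time steps as padding
out = attn.forward(query_tensor, query_tensor, query_tensor, mask)
assert out.shape == (2, 10, 16)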
@@ -109,22 +108,34 @@ def __init__(self, d_model: int, n_head: int) -> None:
         self.w_v = nn.Linear(d_model, d_model)
         self.w_out = nn.Linear(d_model, d_model)

-    def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor = None) -> Tensor:
-        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)
-        q, k, v = self.split_heads(q), self.split_heads(k), self.split_heads(v)
+    def forward(
+        self,
+        query_tensor: Tensor,
+        key_tensor: Tensor,
+        value_tensor: Tensor,
+        mask: Tensor = None,
+    ) -> Tensor:
+        query_tensor, key_tensor, value_tensor = (
+            self.w_q(query_tensor),
+            self.w_k(key_tensor),
+            self.w_v(value_tensor),
+        )
+        query_tensor = self.split_heads(query_tensor)
+        key_tensor = self.split_heads(key_tensor)
+        value_tensor = self.split_heads(value_tensor)

-        context, _ = self.attn(q, k, v, mask)
+        context, _ = self.attn(query_tensor, key_tensor, value_tensor, mask)
         out = self.w_out(self.concat_heads(context))
         return out

-    def split_heads(self, x: Tensor) -> Tensor:
-        batch, seq_len, d_model = x.size()
+    def split_heads(self, input_tensor: Tensor) -> Tensor:
+        batch, seq_len, d_model = input_tensor.size()
         d_k = d_model // self.n_head
-        return x.view(batch, seq_len, self.n_head, d_k).transpose(1, 2)
+        return input_tensor.view(batch, seq_len, self.n_head, d_k).transpose(1, 2)

-    def concat_heads(self, x: Tensor) -> Tensor:
-        batch, n_head, seq_len, d_k = x.size()
-        return x.transpose(1, 2).contiguous().view(batch, seq_len, n_head * d_k)
+    def concat_heads(self, input_tensor: Tensor) -> Tensor:
+        batch, n_head, seq_len, d_k = input_tensor.size()
+        return input_tensor.transpose(1, 2).contiguous().view(batch, seq_len, n_head * d_k)


 class LayerNorm(nn.Module):
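The renamed split_heads and concat_heads are exact inverses; a quick shape sketch using the doctest sizes (d_model 16, 4 heads):

# Illustrative only: round trip through the per-head layout consumed by ScaleDotProductAttention.
import torch
attn = MultiHeadAttention(16, 4)
input_tensor = torch.rand(2, 10, 16)
heads = attn.split_heads(input_tensor)   # [2, 10, 16] -> [2, 4, 10, 4]
assert heads.shape == (2, 4, 10, 4)
restored = attn.concat_heads(heads)      # back to [2, 10, 16]
assert torch.equal(restored, input_tensor)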
@@ -138,7 +149,6 @@ class LayerNorm(nn.Module):
     >>> out.shape
     torch.Size([4, 10, 8])
     """
-
     def __init__(self, d_model: int, eps: float = 1e-12) -> None:
         super().__init__()
         self.gamma = nn.Parameter(torch.ones(d_model))
@@ -148,9 +158,7 @@ def __init__(self, d_model: int, eps: float = 1e-12) -> None:
     def forward(self, input_tensor: Tensor) -> Tensor:
         mean = input_tensor.mean(-1, keepdim=True)
         var = input_tensor.var(-1, unbiased=False, keepdim=True)
-        return (
-            self.gamma * (input_tensor - mean) / torch.sqrt(var + self.eps) + self.beta
-        )
+        return self.gamma * (input_tensor - mean) / torch.sqrt(var + self.eps) + self.beta


 class TransformerEncoderLayer(nn.Module):
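A sketch of what the single-line return computes: standardise each feature vector (biased variance), then apply the gamma/beta affine. This assumes beta is initialised to zeros, which is conventional but not visible in this hunk:

# Illustrative only: with gamma=1 and (assumed) beta=0, outputs are standardised along the last dim.
import torch
layer = LayerNorm(8)
out = layer(torch.rand(4, 10, 8))
assert torch.allclose(out.mean(-1), torch.zeros(4, 10), atol=1e-5)
assert torch.allclose(out.var(-1, unbiased=False), torch.ones(4, 10), atol=1e-3)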
@@ -164,7 +172,6 @@ class TransformerEncoderLayer(nn.Module):
     >>> out.shape
     torch.Size([4, 10, 8])
     """
-
     def __init__(
         self,
         d_model: int,
@@ -198,7 +205,6 @@ class TransformerEncoder(nn.Module):
     >>> out.shape
     torch.Size([4, 10, 8])
     """
-
     def __init__(
         self,
         d_model: int,
@@ -235,14 +241,11 @@ class AttentionPooling(nn.Module):
     >>> weights.shape
     torch.Size([4, 10])
     """
-
     def __init__(self, d_model: int) -> None:
         super().__init__()
         self.attn_score = nn.Linear(d_model, 1)

-    def forward(
-        self, input_tensor: Tensor, mask: Tensor = None
-    ) -> tuple[Tensor, Tensor]:
+    def forward(self, input_tensor: Tensor, mask: Tensor = None) -> tuple[Tensor, Tensor]:
         attn_weights = torch.softmax(self.attn_score(input_tensor).squeeze(-1), dim=-1)

         if mask is not None:
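For reference, a sketch of calling the reflowed one-line signature without a mask; the (pooled, weights) return order is assumed from the doctest, and the weights form a softmax distribution over time steps:

# Illustrative only: assumes forward returns (pooled, weights) in that order.
import torch
pool = AttentionPooling(8)
pooled, weights = pool(torch.rand(4, 10, 8))
assert weights.shape == (4, 10)  # matches the doctest above
assert torch.allclose(weights.sum(-1), torch.ones(4), atol=1e-6)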
@@ -264,7 +267,6 @@ class EEGTransformer(nn.Module):
     >>> out.shape
     torch.Size([2, 1])
     """
-
     def __init__(
         self,
         feature_dim: int,
@@ -286,16 +288,9 @@ def __init__(
         self.pooling = AttentionPooling(d_model)
         self.output_layer = nn.Linear(d_model, output_dim)

-    def forward(
-        self, input_tensor: Tensor, mask: Tensor = None
-    ) -> tuple[Tensor, Tensor]:
+    def forward(self, input_tensor: Tensor, mask: Tensor = None) -> tuple[Tensor, Tensor]:
         b, t, _ = input_tensor.size()
-        t_idx = (
-            torch.arange(t, device=input_tensor.device)
-            .view(1, t, 1)
-            .expand(b, t, 1)
-            .float()
-        )
+        t_idx = torch.arange(t, device=input_tensor.device).view(1, t, 1).expand(b, t, 1).float()
         time_emb = self.time2vec(t_idx)
         x = self.input_proj(input_tensor) + time_emb
         x = self.encoder(x, mask)
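A small sketch of what the collapsed t_idx chain builds (device argument dropped for brevity): a [batch, seq_len, 1] float tensor of step indices that is fed to Time2Vec and added to the projected input:

# Illustrative only: reproduces the t_idx expression on CPU with arbitrary sizes.
import torch
b, t = 2, 5
t_idx = torch.arange(t).view(1, t, 1).expand(b, t, 1).float()
assert t_idx.shape == (2, 5, 1)
assert t_idx[0, :, 0].tolist() == [0.0, 1.0, 2.0, 3.0, 4.0]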
