Skip to content

Commit 007dcf1

Browse files
authored
Update real_time_encoder_transformer.py
1 parent 80aff7a commit 007dcf1

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

neural_network/real_time_encoder_transformer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@ def forward(self, time_steps: np.ndarray) -> np.ndarray:
3838
>>> out.shape
3939
(1, 3, 4)
4040
"""
41+
42+
4143
linear = self.w0 * time_steps + self.b0
42-
periodic = np.sin(self.w * time_steps + self.b)
44+
periodic = np.sin(time_steps * self.w[:, None, :] + self.b[:, None, :])
4345
return np.concatenate([linear, periodic], axis=-1)
4446

4547

0 commit comments

Comments
 (0)