
Commit c96d440

Update real_time_encoder_transformer.py
1 parent 23c5117 commit c96d440

File tree

1 file changed: +9 -5 lines changed


neural_network/real_time_encoder_transformer.py

Lines changed: 9 additions & 5 deletions
@@ -1,12 +1,12 @@
 # imports
-import torch
-import torch.nn as nn
 import math
+import torch
+from torch import nn
 
 
 # Time2Vec layer for positional encoding of real-time data like EEG
 class Time2Vec(nn.Module):
-    # Encodes time steps into a continuous embedding space so to help the transformer learn temporal dependencies.
+    # Encodes time steps into a continuous embedding space
     def __init__(self, d_model):
         super().__init__()
         self.w0 = nn.Parameter(torch.randn(1, 1))
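
Note: the hunk above only shows the first lines of the Time2Vec class. As a rough reference, here is a minimal Time2Vec sketch following the common formulation (one learned linear term plus d_model - 1 periodic sine terms); it is consistent with the visible self.w0 parameter, but the class name, the remaining parameters (b0, w, b), and the forward pass are assumptions, not code from this file.

import torch
from torch import nn


class Time2VecSketch(nn.Module):
    # Hypothetical sketch; only the first lines of the real Time2Vec class
    # appear in this diff.
    def __init__(self, d_model):
        super().__init__()
        # Linear (non-periodic) component, matching the visible w0 parameter
        self.w0 = nn.Parameter(torch.randn(1, 1))
        self.b0 = nn.Parameter(torch.randn(1, 1))
        # Periodic components covering the remaining d_model - 1 dimensions
        self.w = nn.Parameter(torch.randn(1, d_model - 1))
        self.b = nn.Parameter(torch.randn(1, d_model - 1))

    def forward(self, t):
        # t: (batch, seq_len, 1) time steps
        linear = t * self.w0 + self.b0                # (batch, seq_len, 1)
        periodic = torch.sin(t * self.w + self.b)     # (batch, seq_len, d_model - 1)
        return torch.cat([linear, periodic], dim=-1)  # (batch, seq_len, d_model)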
@@ -174,8 +174,12 @@ def __init__(
 
         # Transformer encoder for sequence modeling
         self.encoder = TransformerEncoder(
-            d_model, n_head, hidden_dim, num_layers, drop_prob
-        )
+            d_model,
+            n_head,
+            hidden_dim,
+            num_layers,
+            drop_prob
+        )
 
         # Attention pooling to summarize time dimension
         self.pooling = AttentionPooling(d_model)
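
Note: AttentionPooling is constructed here, but its definition is not part of this diff. A minimal attention-pooling sketch over the time dimension, assuming a single learned scoring layer, is shown below; the class name AttentionPoolingSketch and the score layer are hypothetical.

import torch
from torch import nn


class AttentionPoolingSketch(nn.Module):
    # Hypothetical sketch; the real AttentionPooling class is not shown in this diff.
    def __init__(self, d_model):
        super().__init__()
        self.score = nn.Linear(d_model, 1)  # one scalar score per time step

    def forward(self, x):
        # x: (batch, seq_len, d_model) encoder outputs
        weights = torch.softmax(self.score(x), dim=1)  # normalize scores over time
        return (weights * x).sum(dim=1)                # (batch, d_model) summary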

0 commit comments
