@@ -46,11 +46,17 @@ def forward(self, time_steps: np.ndarray) -> np.ndarray:
 # Positionwise FeedForward
 # --------------------------------------------------
 class PositionwiseFeedForward:
-    def __init__(self, d_model: int, hidden: int, drop_prob: float = 0.0, seed: int | None = None):
+    def __init__(
+        self, d_model: int, hidden: int, drop_prob: float = 0.0, seed: int | None = None
+    ):
         self.rng = np.random.default_rng(seed)
-        self.w1 = self.rng.standard_normal((d_model, hidden)) * math.sqrt(2.0 / (d_model + hidden))
+        self.w1 = self.rng.standard_normal((d_model, hidden)) * math.sqrt(
+            2.0 / (d_model + hidden)
+        )
         self.b1 = np.zeros(hidden)
-        self.w2 = self.rng.standard_normal((hidden, d_model)) * math.sqrt(2.0 / (hidden + d_model))
+        self.w2 = self.rng.standard_normal((hidden, d_model)) * math.sqrt(
+            2.0 / (hidden + d_model)
+        )
         self.b2 = np.zeros(d_model)

     def forward(self, x: np.ndarray) -> np.ndarray:
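Aside: the `sqrt(2.0 / (d_model + hidden))` factor being wrapped here is Glorot-normal initialization, which keeps activation variance roughly stable between fan-in and fan-out. A minimal sketch of that property, assuming unit-variance inputs (the `glorot_normal` helper is illustrative, not part of this file):

```python
import math
import numpy as np

def glorot_normal(rng: np.random.Generator, fan_in: int, fan_out: int) -> np.ndarray:
    # Same scaling as the diff: Var(W) = 2 / (fan_in + fan_out).
    return rng.standard_normal((fan_in, fan_out)) * math.sqrt(2.0 / (fan_in + fan_out))

rng = np.random.default_rng(0)
x = rng.standard_normal((4096, 512))          # unit-variance inputs (assumption)
y = x @ glorot_normal(rng, 512, 2048)
print(y.var())  # approx. 2 * 512 / (512 + 2048) = 0.4; neither exploding nor vanishing
```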
@@ -95,13 +101,21 @@ def __init__(self, d_model: int, n_head: int, seed: int | None = None):
     self.d_k = d_model // n_head
     self.rng = np.random.default_rng(seed)

-    self.w_q = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / (2 * d_model))
+    self.w_q = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+        2.0 / (2 * d_model)
+    )
     self.b_q = np.zeros(d_model)
-    self.w_k = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / (2 * d_model))
+    self.w_k = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+        2.0 / (2 * d_model)
+    )
     self.b_k = np.zeros(d_model)
-    self.w_v = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / (2 * d_model))
+    self.w_v = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+        2.0 / (2 * d_model)
+    )
     self.b_v = np.zeros(d_model)
-    self.w_out = self.rng.standard_normal((d_model, d_model)) * math.sqrt(2.0 / (2 * d_model))
+    self.w_out = self.rng.standard_normal((d_model, d_model)) * math.sqrt(
+        2.0 / (2 * d_model)
+    )
     self.b_out = np.zeros(d_model)

     self.attn = ScaledDotProductAttention()
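For these projections `fan_in == fan_out == d_model`, so `2.0 / (2 * d_model)` is the same Glorot term as above. A short sketch of the per-head reshape this layer implies via `d_k = d_model // n_head` (shape conventions assumed, since the attention body sits outside this hunk):

```python
import numpy as np

b, t, d_model, n_head = 2, 5, 8, 2
d_k = d_model // n_head
q = np.random.default_rng(0).standard_normal((b, t, d_model))

# (b, t, d_model) -> (b, n_head, t, d_k): split the channel dim into heads
q_heads = q.reshape(b, t, n_head, d_k).transpose(0, 2, 1, 3)

# inverse merge: (b, n_head, t, d_k) -> (b, t, d_model)
q_merged = q_heads.transpose(0, 2, 1, 3).reshape(b, t, d_model)
assert np.allclose(q, q_merged)
```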
@@ -154,7 +168,9 @@ def forward(self, x: np.ndarray) -> np.ndarray:
 # Transformer Encoder Layer
 # --------------------------------------------------
 class TransformerEncoderLayer:
-    def __init__(self, d_model: int, n_head: int, hidden_dim: int, seed: int | None = None):
+    def __init__(
+        self, d_model: int, n_head: int, hidden_dim: int, seed: int | None = None
+    ):
         self.self_attn = MultiHeadAttention(d_model, n_head, seed=seed)
         self.ffn = PositionwiseFeedForward(d_model, hidden_dim, seed=seed)
         self.norm1 = LayerNorm(d_model)
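The layer owns `norm1`/`norm2` alongside `self_attn` and `ffn`, which suggests the classic post-norm residual wiring; its `forward` is outside this hunk, so the following is only a sketch of that pattern, not the file's actual code:

```python
import numpy as np

def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Normalize over the feature axis, as LayerNorm(d_model) would.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def post_norm_block(x: np.ndarray, sublayer) -> np.ndarray:
    # sublayer -> residual add -> LayerNorm ("post-norm")
    return layer_norm(x + sublayer(x))

x = np.random.default_rng(0).standard_normal((2, 5, 8))
out = post_norm_block(x, lambda h: 0.5 * h)  # stand-in for attention/FFN
print(out.shape)  # (2, 5, 8)
```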
@@ -171,8 +187,18 @@ def forward(self, x: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
 # Transformer Encoder Stack
 # --------------------------------------------------
 class TransformerEncoder:
-    def __init__(self, d_model: int, n_head: int, hidden_dim: int, num_layers: int, seed: int | None = None):
-        self.layers = [TransformerEncoderLayer(d_model, n_head, hidden_dim, seed=seed) for _ in range(num_layers)]
+    def __init__(
+        self,
+        d_model: int,
+        n_head: int,
+        hidden_dim: int,
+        num_layers: int,
+        seed: int | None = None,
+    ):
+        self.layers = [
+            TransformerEncoderLayer(d_model, n_head, hidden_dim, seed=seed)
+            for _ in range(num_layers)
+        ]

     def forward(self, x: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
         out = x
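One thing worth flagging in this constructor: every `TransformerEncoderLayer` receives the same `seed`, and each layer builds its own `default_rng(seed)`, so all layers in the stack start from identical weights. If that is unintended, spawning independent child seeds from one parent is a cheap fix; a sketch:

```python
import numpy as np

seed, num_layers = 42, 4

# Same seed everywhere -> every layer draws the identical stream:
draws = [np.random.default_rng(seed).standard_normal(3) for _ in range(num_layers)]
assert all(np.allclose(d, draws[0]) for d in draws)

# Independent per-layer streams derived from one parent seed:
children = np.random.SeedSequence(seed).spawn(num_layers)
indep = [np.random.default_rng(c).standard_normal(3) for c in children]
assert not np.allclose(indep[0], indep[1])
```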
@@ -190,7 +216,9 @@ def __init__(self, d_model: int, seed: int | None = None):
         self.w = self.rng.standard_normal(d_model) * math.sqrt(2.0 / d_model)
         self.b = 0.0

-    def forward(self, x: np.ndarray, mask: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]:
+    def forward(
+        self, x: np.ndarray, mask: np.ndarray | None = None
+    ) -> tuple[np.ndarray, np.ndarray]:
         scores = np.tensordot(x, self.w, axes=([2], [0])) + self.b
         if mask is not None:
             scores = np.where(mask == 0, -1e9, scores)
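The `-1e9` fill here is the usual way to mask padding before a softmax: those positions end up with effectively zero attention weight. The normalization step itself is outside this hunk, so this is a sketch of the assumed pattern:

```python
import numpy as np

def masked_softmax(scores: np.ndarray, mask: np.ndarray | None) -> np.ndarray:
    # mask == 0 marks padded time steps, matching np.where(mask == 0, -1e9, scores)
    if mask is not None:
        scores = np.where(mask == 0, -1e9, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)  # for numerical stability
    w = np.exp(scores)
    return w / w.sum(axis=-1, keepdims=True)

scores = np.array([[1.0, 2.0, 0.5]])
mask = np.array([[1, 1, 0]])             # last step is padding
print(masked_softmax(scores, mask))      # third weight is ~0
```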
@@ -219,18 +247,26 @@ def __init__(
         self.d_model = d_model
         self.task_type = task_type

-        self.w_in = self.rng.standard_normal((feature_dim, d_model)) * math.sqrt(2.0 / (feature_dim + d_model))
+        self.w_in = self.rng.standard_normal((feature_dim, d_model)) * math.sqrt(
+            2.0 / (feature_dim + d_model)
+        )
         self.b_in = np.zeros(d_model)
         self.time2vec = Time2Vec(d_model, seed=seed)
-        self.encoder = TransformerEncoder(d_model, n_head, hidden_dim, num_layers, seed=seed)
+        self.encoder = TransformerEncoder(
+            d_model, n_head, hidden_dim, num_layers, seed=seed
+        )
         self.pooling = AttentionPooling(d_model, seed=seed)
-        self.w_out = self.rng.standard_normal((d_model, output_dim)) * math.sqrt(2.0 / (d_model + output_dim))
+        self.w_out = self.rng.standard_normal((d_model, output_dim)) * math.sqrt(
+            2.0 / (d_model + output_dim)
+        )
         self.b_out = np.zeros(output_dim)

     def _input_proj(self, x: np.ndarray) -> np.ndarray:
         return np.tensordot(x, self.w_in, axes=([2], [0])) + self.b_in

-    def forward(self, x: np.ndarray, mask: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]:
+    def forward(
+        self, x: np.ndarray, mask: np.ndarray | None = None
+    ) -> tuple[np.ndarray, np.ndarray]:
         b, t, _ = x.shape
         t_idx = np.arange(t, dtype=float)[None, :, None]
         t_idx = np.tile(t_idx, (b, 1, 1))
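`forward` builds integer positions `0..t-1`, tiles them to `(b, t, 1)`, and hands them to `Time2Vec`. The `Time2Vec` body is not in this diff; below is a sketch of the common formulation (one linear channel plus periodic channels), only to illustrate what the `(b, t, 1)` index turns into:

```python
import numpy as np

def time2vec(t_idx: np.ndarray, w: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Common Time2Vec form (an assumption here): channel 0 is linear
    # in time, the remaining channels are sinusoidal.
    v = t_idx * w + b  # (b, t, 1) broadcast against (d_model,) -> (b, t, d_model)
    return np.concatenate([v[..., :1], np.sin(v[..., 1:])], axis=-1)

bsz, t, d_model = 2, 6, 8
t_idx = np.tile(np.arange(t, dtype=float)[None, :, None], (bsz, 1, 1))  # as in forward()
rng = np.random.default_rng(0)
emb = time2vec(t_idx, rng.standard_normal(d_model), rng.standard_normal(d_model))
print(emb.shape)  # (2, 6, 8)
```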