@@ -23,7 +23,7 @@ def forward(self, inputs):
 
 class TimeAwareMultiHeadAttention(torch.nn.Module):
     # required homebrewed mha layer for Ti/SASRec experiments
-    def __init__(self, hidden_size, head_num, dropout_rate):
+    def __init__(self, hidden_size, head_num, dropout_rate, dev):
         super(TimeAwareMultiHeadAttention, self).__init__()
         self.Q_w = torch.nn.Linear(hidden_size, hidden_size)
         self.K_w = torch.nn.Linear(hidden_size, hidden_size)
@@ -36,6 +36,7 @@ def __init__(self, hidden_size, head_num, dropout_rate):
         self.head_num = head_num
         self.head_size = hidden_size // head_num
         self.dropout_rate = dropout_rate
+        self.dev = dev
 
     def forward(self, queries, keys, time_mask, attn_mask, time_matrix_K, time_matrix_V, abs_pos_K, abs_pos_V):
         Q, K, V = self.Q_w(queries), self.K_w(keys), self.V_w(keys)
@@ -63,7 +64,8 @@ def forward(self, queries, keys, time_mask, attn_mask, time_matrix_K, time_matri
 
         time_mask = time_mask.unsqueeze(-1).expand(attn_weights.shape[0], -1, attn_weights.shape[-1])
         attn_mask = attn_mask.unsqueeze(0).expand(attn_weights.shape[0], -1, -1)
-        paddings = torch.ones(attn_weights.shape) * FLOAT_MIN # float('-inf')
+        paddings = torch.ones(attn_weights.shape) * -1e23 # float('-inf')
+        paddings = paddings.to(self.dev)
         attn_weights = torch.where(time_mask, paddings, attn_weights) # True: pick padding
         attn_weights = torch.where(attn_mask, paddings, attn_weights) # enforcing causality
 
@@ -119,7 +121,8 @@ def __init__(self, user_num, item_num, time_num, args):
 
             new_attn_layer = TimeAwareMultiHeadAttention(args.hidden_units,
                                                          args.num_heads,
-                                                         args.dropout_rate)
+                                                         args.dropout_rate,
+                                                         args.device)
             self.attention_layers.append(new_attn_layer)
 
             new_fwd_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
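
Note on the device-placement fix above: a minimal standalone sketch, not part of this commit, with hypothetical tensor shapes chosen only for illustration. It shows that torch.full_like allocates the padding tensor with the same shape, dtype, and device as attn_weights, so the padding values land on the right device without threading a dev argument through the module.

import torch

# Sketch only: shapes below are illustrative, not the repo's actual tensors.
attn_weights = torch.randn(4, 5, 5)                                      # (batch*heads, seq, seq)
attn_mask = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)   # causal mask
attn_mask = attn_mask.unsqueeze(0).expand(attn_weights.shape[0], -1, -1)
paddings = torch.full_like(attn_weights, -1e23)                          # same device/dtype as attn_weights
attn_weights = torch.where(attn_mask, paddings, attn_weights)            # masked positions get -1e23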