@@ -23,7 +23,7 @@ def str2bool(s):
2323parser .add_argument ('--num_blocks' , default = 2 , type = int )
2424parser .add_argument ('--num_epochs' , default = 201 , type = int )
2525parser .add_argument ('--num_heads' , default = 1 , type = int )
26- parser .add_argument ('--dropout_rate' , default = 0.5 , type = float )
26+ parser .add_argument ('--dropout_rate' , default = 0.2 , type = float )
2727parser .add_argument ('--l2_emb' , default = 0.0 , type = float )
2828parser .add_argument ('--device' , default = 'cpu' , type = str )
2929parser .add_argument ('--inference_only' , default = False , type = str2bool )
@@ -91,9 +91,14 @@ def str2bool(s):
9191 indices = np .where (pos != 0 )
9292 loss = bce_criterion (pos_logits [indices ], pos_labels [indices ])
9393 loss += bce_criterion (neg_logits [indices ], neg_labels [indices ])
94+ for param in model .item_emb .parameters (): loss += args .l2_emb * torch .norm (param )
95+ for param in model .abs_pos_K_emb .parameters (): loss += args .l2_emb * torch .norm (param )
96+ for param in model .abs_pos_V_emb .parameters (): loss += args .l2_emb * torch .norm (param )
97+ for param in model .time_matrix_K_emb .parameters (): loss += args .l2_emb * torch .norm (param )
98+ for param in model .time_matrix_V_emb .parameters (): loss += args .l2_emb * torch .norm (param )
9499 loss .backward ()
95100 adam_optimizer .step ()
96- # print("loss in epoch {} iteration {}: {}".format(epoch, step, loss.item())) # expected 0.4~0.6 after init few epochs
101+ print ("loss in epoch {} iteration {}: {}" .format (epoch , step , loss .item ())) # expected 0.4~0.6 after init few epochs
97102
98103 if epoch % 20 == 0 :
99104 model .eval ()
0 commit comments