
Commit ef463fe

add L2 norm for embeddings, plus a minor modification to make it more consistent with the TF version
1 parent 9f55822 commit ef463fe

2 files changed, 10 insertions(+), 4 deletions(-)

main.py

Lines changed: 7 additions & 2 deletions
@@ -23,7 +23,7 @@ def str2bool(s):
 parser.add_argument('--num_blocks', default=2, type=int)
 parser.add_argument('--num_epochs', default=201, type=int)
 parser.add_argument('--num_heads', default=1, type=int)
-parser.add_argument('--dropout_rate', default=0.5, type=float)
+parser.add_argument('--dropout_rate', default=0.2, type=float)
 parser.add_argument('--l2_emb', default=0.0, type=float)
 parser.add_argument('--device', default='cpu', type=str)
 parser.add_argument('--inference_only', default=False, type=str2bool)
@@ -91,9 +91,14 @@ def str2bool(s):
 indices = np.where(pos != 0)
 loss = bce_criterion(pos_logits[indices], pos_labels[indices])
 loss += bce_criterion(neg_logits[indices], neg_labels[indices])
+for param in model.item_emb.parameters(): loss += args.l2_emb * torch.norm(param)
+for param in model.abs_pos_K_emb.parameters(): loss += args.l2_emb * torch.norm(param)
+for param in model.abs_pos_V_emb.parameters(): loss += args.l2_emb * torch.norm(param)
+for param in model.time_matrix_K_emb.parameters(): loss += args.l2_emb * torch.norm(param)
+for param in model.time_matrix_V_emb.parameters(): loss += args.l2_emb * torch.norm(param)
 loss.backward()
 adam_optimizer.step()
-# print("loss in epoch {} iteration {}: {}".format(epoch, step, loss.item())) # expected 0.4~0.6 after init few epochs
+print("loss in epoch {} iteration {}: {}".format(epoch, step, loss.item())) # expected 0.4~0.6 after init few epochs

 if epoch % 20 == 0:
     model.eval()
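
The added lines fold an L2 penalty on every embedding table into the BCE loss before backprop, scaled by the --l2_emb argument. A minimal sketch of the same pattern, using a hypothetical stand-in module whose embedding names mirror the diff (sizes are assumed):

import torch
import torch.nn as nn

# Hypothetical stand-in for the real model; only the embedding names mirror the diff.
class TinyTiSAS(nn.Module):
    def __init__(self, num_items=1000, maxlen=50, time_span=256, hidden=50):
        super().__init__()
        self.item_emb = nn.Embedding(num_items + 1, hidden, padding_idx=0)
        self.abs_pos_K_emb = nn.Embedding(maxlen, hidden)
        self.abs_pos_V_emb = nn.Embedding(maxlen, hidden)
        self.time_matrix_K_emb = nn.Embedding(time_span + 1, hidden)
        self.time_matrix_V_emb = nn.Embedding(time_span + 1, hidden)

model = TinyTiSAS()
l2_emb = 5e-5                      # plays the role of args.l2_emb; 0.0 disables the penalty
loss = torch.tensor(0.61)          # stand-in for the BCE loss on pos/neg logits

# Same pattern as the diff: add l2_emb * ||W|| for every embedding table.
for emb in (model.item_emb, model.abs_pos_K_emb, model.abs_pos_V_emb,
            model.time_matrix_K_emb, model.time_matrix_V_emb):
    for param in emb.parameters():
        loss = loss + l2_emb * torch.norm(param)   # unsquared Frobenius norm

loss.backward()                    # gradients now include the regularization term

Note that torch.norm(param) is the plain (unsquared) Frobenius norm, so this is not identical to the squared-L2 penalty an optimizer's weight_decay would apply, and it only touches the embedding tables, not the attention or feed-forward weights.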

model.py

Lines changed: 3 additions & 2 deletions
@@ -64,19 +64,20 @@ def forward(self, queries, keys, time_mask, attn_mask, time_matrix_K, time_matri

 time_mask = time_mask.unsqueeze(-1).expand(attn_weights.shape[0], -1, attn_weights.shape[-1])
 attn_mask = attn_mask.unsqueeze(0).expand(attn_weights.shape[0], -1, -1)
-paddings = torch.ones(attn_weights.shape) * -1e23 # float('-inf')
+paddings = torch.ones(attn_weights.shape) * (-2**32+1) # -1e23 # float('-inf')
 paddings = paddings.to(self.dev)
 attn_weights = torch.where(time_mask, paddings, attn_weights) # True:pick padding
 attn_weights = torch.where(attn_mask, paddings, attn_weights) # enforcing causality

 attn_weights = self.softmax(attn_weights) # code as below invalids pytorch backward rules
+# attn_weights = torch.where(time_mask, paddings, attn_weights) # weird query mask in tf impl
 # https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/4
 # attn_weights[attn_weights != attn_weights] = 0 # rm nan for -inf into softmax case
 attn_weights = self.dropout(attn_weights)

 outputs = attn_weights.matmul(V_)
 outputs += attn_weights.matmul(abs_pos_V_)
-outputs += attn_weights.unsqueeze(-2).matmul(time_matrix_V_).reshape(outputs.shape)
+outputs += attn_weights.unsqueeze(2).matmul(time_matrix_V_).reshape(outputs.shape).squeeze(2)

 # (num_head * N, T, C / num_head) -> (N, T, C)
 outputs = torch.cat(torch.split(outputs, Q.shape[0], dim=0), dim=2) # div batch_size
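
Two details of this hunk are worth spelling out: the mask fill value becomes the large finite constant -2**32+1 (apparently to match the TF version; unlike float('-inf') it keeps the softmax finite even when a whole row is masked), and the time-interval values are applied through an extra singleton dimension so each query position gets its own (T, C) value matrix. A shape-checking sketch with small assumed sizes; the tensor names mirror the diff, the values are random:

import torch

N, T, C = 2, 4, 8                                   # (num_heads * batch), seq len, per-head dim (assumed)
attn_weights = torch.randn(N, T, T)
V_ = torch.randn(N, T, C)
abs_pos_V_ = torch.randn(N, T, C)
time_matrix_V_ = torch.randn(N, T, T, C)            # one value vector per (query, key) pair

time_mask = torch.zeros(N, T, T, dtype=torch.bool)
time_mask[:, -1, :] = True                          # pretend the last query position is padding
attn_mask = torch.triu(torch.ones(T, T), diagonal=1).bool().expand(N, -1, -1)

# Large finite fill instead of float('-inf'): an all-masked row softmaxes to a uniform
# distribution instead of NaN, so no post-hoc NaN scrubbing is needed.
paddings = torch.ones_like(attn_weights) * (-2**32 + 1)
attn_weights = torch.where(time_mask, paddings, attn_weights)
attn_weights = torch.where(attn_mask, paddings, attn_weights)
attn_weights = torch.softmax(attn_weights, dim=-1)
assert not torch.isnan(attn_weights).any()

outputs = attn_weights.matmul(V_)                   # (N, T, C)
outputs = outputs + attn_weights.matmul(abs_pos_V_)
# (N, T, 1, T) @ (N, T, T, C) -> (N, T, 1, C); dropping the singleton dim gives (N, T, C),
# which is what the diff's reshape(outputs.shape) achieves before the (here no-op) squeeze(2).
outputs = outputs + attn_weights.unsqueeze(2).matmul(time_matrix_V_).squeeze(2)
print(outputs.shape)                                # torch.Size([2, 4, 8])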
