In loss_fun():
pos = torch.exp(torch.div(torch.bmm(q.view(N, 1, C), k.view(N, C, 1)).view(N, 1), self.args.moco_t))  # shape (N, 1)
neg = torch.sum(torch.exp(torch.div(torch.mm(q.view(N, C), torch.t(self.queue)), self.args.moco_t)), dim=1)  # shape (N,)
# denominator is the sum over pos and neg
denominator = pos + neg
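By this calculation, pos.size() is (N, 1) and neg.size() is (N,), so broadcasting makes denominator.size() equal to (N, N). This is not right. As a result, when 'torch.mean(-torch.log(torch.div(pos, denominator)))' runs, the loss for one sample uses not only its own positive score and negative sum: entry (i, j) of the denominator is pos[i] + neg[j], so sample i's positive score is also paired with the negative sums of every other sample in the batch. In any case, this does not match the contrastive loss.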
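The shape problem can be reproduced in isolation. The following is a minimal sketch, not the repository's code: the sizes N, C, K, the variable `temperature` (standing in for self.args.moco_t), and the random `queue` tensor (standing in for self.queue, assumed to have shape (K, C)) are all illustrative assumptions. It shows the (N, N) broadcast and one possible fix, giving neg an explicit column dimension, then cross-checks the result against `torch.nn.functional.cross_entropy` with the positive logit in column 0.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, K = 4, 8, 16      # batch size, feature dim, queue length (illustrative values)
temperature = 0.07      # assumed stand-in for self.args.moco_t

q = F.normalize(torch.randn(N, C), dim=1)
k = F.normalize(torch.randn(N, C), dim=1)
queue = F.normalize(torch.randn(K, C), dim=1)  # assumed stand-in for self.queue, shape (K, C)

# Same computation as in the issue: pos is (N, 1), neg is (N,)
pos = torch.exp(torch.bmm(q.view(N, 1, C), k.view(N, C, 1)).view(N, 1) / temperature)
neg = torch.sum(torch.exp(torch.mm(q, queue.t()) / temperature), dim=1)

# (N, 1) + (N,) broadcasts to (N, N): entry (i, j) is pos[i] + neg[j]
bad_denominator = pos + neg

# One possible fix: make neg a column so the sum stays per-sample
denominator = pos + neg.unsqueeze(1)            # (N, 1)
loss = torch.mean(-torch.log(pos / denominator))

# Cross-check: the same loss written as cross-entropy over [positive, negatives]
logits = torch.cat([(q * k).sum(dim=1, keepdim=True), q @ queue.t()], dim=1) / temperature
labels = torch.zeros(N, dtype=torch.long)       # the positive is always at index 0
reference = F.cross_entropy(logits, labels)

print(bad_denominator.shape, denominator.shape)
print(loss.item(), reference.item())
```

Passing keepdim=True to torch.sum when computing neg would have the same effect as the unsqueeze. The cross-entropy form is likely the safer choice in practice, since it avoids exponentiating large logits directly.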