I ran into a problem calling a method of a class: I can't seem to call it correctly. Can you please point out the issue? Thank you!
class NTXentLoss(nn.Module):
    """Normalized temperature-scaled cross-entropy (NT-Xent) contrastive loss.

    Row ``i`` of ``zi`` and row ``i`` of ``zj`` form a positive pair; every
    other row of the concatenated ``(2 * batch_size)`` batch is a negative.

    Args:
        temp: Temperature that divides the cosine similarities before the
            softmax. Defaults to 0.5.
    """

    def __init__(self, temp=0.5):
        super().__init__()
        self.temp = temp

    def forward(self, zi, zj):
        """Return the mean NT-Xent loss over all ``2 * batch_size`` anchors.

        Args:
            zi: First view, a 2-D tensor ``(batch_size, dim)`` (2-D is required
                because the similarity matrix is passed to ``torch.diag``).
            zj: Second view, same shape as ``zi``.

        Returns:
            A scalar tensor on the same device as the inputs.
        """
        batch_size = zi.shape[0]
        z_proj = torch.cat((zi, zj), dim=0)  # (2B, D)
        cos_sim = torch.nn.CosineSimilarity(dim=-1)
        # Broadcast (2B, 1, D) against (1, 2B, D) -> pairwise similarities (2B, 2B).
        logits = cos_sim(z_proj.unsqueeze(1), z_proj.unsqueeze(0)) / self.temp

        # Drop each sample's similarity with itself from its own denominator.
        # Masking is exact even when a self-similarity is not precisely 1
        # (e.g. zero vectors), and the mask is created on the inputs' device,
        # so CPU and any GPU both work.
        self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=logits.device)
        logits = logits.masked_fill(self_mask, float("-inf"))

        # Positive for row i: column i + B (first half) or column i - B (second half).
        pos = torch.cat([torch.diag(logits, batch_size), torch.diag(logits, -batch_size)])

        # -log(exp(pos_i) / sum_{k != i} exp(logit_ik)), evaluated in log space
        # so a small temperature cannot overflow exp() into inf/NaN.
        loss = torch.logsumexp(logits, dim=1) - pos
        return loss.mean()


# Usage from the question (l2norm, recov_*, imgs_* are not shown in this snippet).
sent_A = l2norm(recov_A, dim=1)
sent_emb_A = l2norm(imgs_A, dim=1)
sent_B = l2norm(recov_B, dim=1)
sent_emb_B = l2norm(imgs_B, dim=1)
# NOTE: this is the failing line the question asks about. NTXentLoss(...) invokes
# __init__(self, temp=0.5), which accepts at most one positional argument, so
# passing two tensors raises TypeError. Create an instance first, then call it.
G_cons = NTXentLoss(sent_A, sent_emb_A) + NTXentLoss(sent_B, sent_emb_B)
What's wrong with this? I only passed two positional arguments. I also tried:
# Also fails: calling forward on the class (not on an instance) binds sent_A to
# `self` and sent_emb_A to `zi`, leaving `zj` unfilled, so this raises TypeError.
G_cons = NTXentLoss.forward(sent_A,sent_emb_A) + NTXentLoss.forward(sent_B,sent_emb_B)
Advertisement
Answer
You need to first instantiate an NTXentLoss object and then call that instance; the constructor only accepts the optional `temp` argument, not the input tensors. For instance:
ntx = NTXentLoss()  # build the loss module once (temp defaults to 0.5)
# Call the instance rather than .forward(): nn.Module.__call__ runs forward()
# together with any registered hooks.
G_cons = ntx(sent_A, sent_emb_A) + ntx(sent_B, sent_emb_B)