
Python : TypeError: __init__() takes from 1 to 2 positional arguments but 3 were given

I ran into a problem calling a class as if it were a function. It seems I am not calling it correctly. Can you please point out the issue? Thank you!

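import torch
import torch.nn as nn

class NTXentLoss(nn.Module):
    def __init__(self, temp=0.5):
        super(NTXentLoss, self).__init__()
        self.temp = temp  # temperature used to scale the similarities

    def forward(self, zi, zj):
        batch_size = zi.shape[0]
        # Stack both views: rows 0..N-1 come from zi, rows N..2N-1 from zj
        z_proj = torch.cat((zi, zj), dim=0)
        cos_sim = torch.nn.CosineSimilarity(dim=-1)
        # Pairwise cosine similarity between all rows, shape (2N, 2N)
        sim_mat = cos_sim(z_proj.unsqueeze(1), z_proj.unsqueeze(0))
        sim_mat_scaled = torch.exp(sim_mat/self.temp)
        # Positive pairs lie on the diagonals offset by +/- batch_size
        r_diag = torch.diag(sim_mat_scaled, batch_size)
        l_diag = torch.diag(sim_mat_scaled, -batch_size)
        pos = torch.cat([r_diag, l_diag])
        # Self-similarity term exp(1/temp), subtracted from each row sum
        # (.cuda() assumes the inputs are already on the GPU)
        diag_mat = torch.exp(torch.ones(batch_size * 2)/self.temp).cuda()
        logit = -torch.log(pos/(sim_mat_scaled.sum(1) - diag_mat))
        loss = logit.mean()
        return loss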

sent_A = l2norm(recov_A, dim=1)
sent_emb_A = l2norm(imgs_A, dim=1)
sent_B = l2norm(recov_B, dim=1)
sent_emb_B = l2norm(imgs_B, dim=1)

G_cons = NTXentLoss(sent_A,sent_emb_A) + NTXentLoss(sent_B,sent_emb_B)

What is wrong with this? I only passed two positional arguments. Or should I call it like this instead?

G_cons = NTXentLoss.forward(sent_A,sent_emb_A) + NTXentLoss.forward(sent_B,sent_emb_B)

Answer

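You need to first instantiate an NTXentLoss object before you can call it. Writing NTXentLoss(sent_A, sent_emb_A) passes both tensors to __init__, which only accepts self and an optional temp, so Python counts three positional arguments (self plus the two tensors) and raises the TypeError. For instance: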

ntx = NTXentLoss()
G_cons = ntx(sent_A,sent_emb_A) + ntx(sent_B,sent_emb_B)
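
To see the mechanism in isolation, here is a minimal sketch that needs no torch. ScaledDiff and error_of are hypothetical names made up for this illustration, and the __call__ method is only a simplified stand-in for what nn.Module does when it dispatches a call to forward. It reproduces both failing calls from the question, then shows the working pattern.

```python
# Plain-Python stand-in for NTXentLoss (hypothetical class, no torch needed).
class ScaledDiff:
    def __init__(self, temp=0.5):
        # The constructor takes configuration only: self plus an optional temp.
        self.temp = temp

    def forward(self, a, b):
        return (a - b) / self.temp

    def __call__(self, a, b):
        # Simplified stand-in: nn.Module.__call__ also ends up calling forward.
        return self.forward(a, b)


def error_of(fn):
    """Return the TypeError message raised by fn(), or None if it succeeds."""
    try:
        fn()
    except TypeError as exc:
        return str(exc)
    return None


# Data passed to the class goes to __init__: self + 2 values = 3 positionals.
ctor_error = error_of(lambda: ScaledDiff(3.0, 1.0))

# Calling forward on the class binds 3.0 to self, so b is left unfilled.
unbound_error = error_of(lambda: ScaledDiff.forward(3.0, 1.0))

# Working pattern: instantiate once, then call the instance with the data.
loss_fn = ScaledDiff()
result = loss_fn(3.0, 1.0)

print(ctor_error)
print(unbound_error)
print(result)
```

The second failing call mirrors NTXentLoss.forward(sent_A, sent_emb_A) from the question: without an instance, the first tensor is taken as self and the last parameter goes missing, so that form raises a TypeError as well.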
User contributions licensed under: CC BY-SA