I ran into a problem calling a method of a class: it seems I cannot call it correctly. Can you please point out the issue? Thank you!
Here is my code (Python):
class NTXentLoss(nn.Module):
    """NT-Xent (normalized temperature-scaled cross-entropy) contrastive loss.

    Row ``k`` of ``zi`` and row ``k`` of ``zj`` are treated as a positive
    pair; every other row of the concatenated batch is treated as a negative.

    Usage: build the module once, then call the *instance* (calling the class
    itself would pass the tensors to ``__init__``)::

        criterion = NTXentLoss(temp=0.5)
        loss = criterion(zi, zj)

    Args:
        temp: Temperature; cosine similarities are divided by it before
            exponentiation. Defaults to 0.5.
    """

    def __init__(self, temp=0.5):
        super().__init__()
        self.temp = temp

    def forward(self, zi, zj):
        """Compute the mean NT-Xent loss over both views.

        Args:
            zi: Embeddings of the first view. Dim 0 is read as the batch
                size; presumably shape ``(batch, dim)`` -- confirm with callers.
            zj: Embeddings of the second view; concatenated with ``zi``
                along dim 0, so it must match ``zi`` in every other dim.

        Returns:
            A 0-d tensor on the same device as the inputs.
        """
        batch_size = zi.shape[0]
        z_proj = torch.cat((zi, zj), dim=0)
        n_rows = z_proj.shape[0]

        # All-pairs cosine similarity by broadcasting; for 2-D inputs this
        # compares (n_rows, 1, d) against (1, n_rows, d) -> (n_rows, n_rows).
        cos_sim = torch.nn.CosineSimilarity(dim=-1)
        sim_mat = cos_sim(z_proj.unsqueeze(1), z_proj.unsqueeze(0))
        sim_mat_scaled = torch.exp(sim_mat / self.temp)

        # Positives lie on the +/-batch_size off-diagonals: entry (k, k + B)
        # pairs zi[k] with zj[k], and (k + B, k) pairs zj[k] with zi[k].
        # Concatenating them in this order lines pos[i] up with row i.
        r_diag = torch.diag(sim_mat_scaled, batch_size)
        l_diag = torch.diag(sim_mat_scaled, -batch_size)
        pos = torch.cat([r_diag, l_diag])

        # Drop each row's self-similarity from the denominator by masking the
        # main diagonal. Subtracting a constant exp(1/temp) would assume the
        # self term is exactly exp(1/temp). That is false for zero-norm rows
        # (cosine similarity 0), giving a wrong, possibly negative, denominator
        # and so NaN. A hard-coded .cuda() would also break CPU inputs and
        # tensors living on a non-current GPU; the mask follows the inputs'
        # device instead.
        self_mask = torch.eye(n_rows, dtype=torch.bool, device=sim_mat_scaled.device)
        denom = sim_mat_scaled.masked_fill(self_mask, 0.0).sum(dim=1)

        logit = -torch.log(pos / denom)
        return logit.mean()
# The asker's original snippet, kept as posted to show the bug. The bare
# numbers between statements are line-number residue from the page's code
# viewer. l2norm, recov_A/B and imgs_A/B come from code not shown here;
# presumably l2norm normalises along dim=1 -- not confirmed by this snippet.
sent_A = l2norm(recov_A, dim=1)
21
sent_emb_A = l2norm(imgs_A, dim=1)
22
sent_B = l2norm(recov_B, dim=1)
23
sent_emb_B = l2norm(imgs_B, dim=1)
24
25
# BUG (the subject of the question): this calls the *class*, so both tensors
# are handed to NTXentLoss.__init__(self, temp=0.5), which takes only one
# argument besides self -> TypeError. Instantiate the module first, then call
# the instance with (zi, zj).
G_cons = NTXentLoss(sent_A,sent_emb_A) + NTXentLoss(sent_B,sent_emb_B)
What’s wrong with this? I only passed two positional arguments. I also tried calling `forward` directly:
# Also fails: forward is looked up on the class, so it is a plain function.
# sent_A binds to `self` and sent_emb_A to `zi`, leaving `zj` unsupplied
# -> TypeError.
G_cons = NTXentLoss.forward(sent_A,sent_emb_A) + NTXentLoss.forward(sent_B,sent_emb_B)
Answer
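You need to instantiate an `NTXentLoss` object before you can call it. Neither attempt supplies the tensors as `zi` and `zj`:

- `NTXentLoss(sent_A, sent_emb_A)` passes the tensors to the constructor, whose only parameter besides `self` is `temp`.
- `NTXentLoss.forward(sent_A, sent_emb_A)` looks `forward` up on the class, so `sent_A` becomes `self` and `zj` is never supplied.

Create the module once and call the instance; `nn.Module.__call__` then dispatches to `forward(zi, zj)`. For instance: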
# Build the loss module once (default temperature), then call the instance:
# nn.Module.__call__ forwards the two tensors to forward(zi, zj).
ntx = NTXentLoss()
loss_A = ntx(sent_A, sent_emb_A)
loss_B = ntx(sent_B, sent_emb_B)
G_cons = loss_A + loss_B