I am trying to compute the cross-entropy loss manually in PyTorch for an encoder-decoder model.
I used the code posted here to compute it: Cross Entropy in PyTorch
I updated the code to discard padded tokens (-100). The final code is this:
import torch

class compute_crossentropyloss_manual:
    """
    y0 holds the logits with shape (batch_size, C).
    x has shape (batch_size,); its entries are integers from 0 to C-1.
    """
    def __init__(self, ignore_index=-100) -> None:
        self.ignore_index = ignore_index

    def __call__(self, y0, x):
        loss = 0.
        n_batch, n_class = y0.shape
        # print(n_class)
        for y1, x1 in zip(y0, x):
            class_index = int(x1.item())
            if class_index == self.ignore_index:  # <------ I added this if-statement
                continue
            loss = loss + torch.log(torch.exp(y1[class_index]) / torch.exp(y1).sum())
        loss = -loss / n_batch
        return loss
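(Side note: the fraction inside the loop is just the softmax probability of the gold class, so the same per-token term can also be written with torch.log_softmax, which avoids overflow in torch.exp for large logits. A minimal sketch using the loop variables above:)

# equivalent to torch.log(torch.exp(y1[class_index]) / torch.exp(y1).sum()),
# but computed in a numerically safer way
per_token_log_prob = torch.log_softmax(y1, dim=0)[class_index]
loss = loss + per_token_log_prob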
To verify that it works correctly, I tested it on a text generation task and computed the loss both with the torch.nn implementation and with this code.
The loss values are not identical:
Using nn.CrossEntropyLoss:
Using the code from the link above:
Am I missing something?
I tried to find the source code of nn.CrossEntropyLoss, but I wasn't able to. In this link to nn/functional.py, at line 2955, you can see that the function points to another cross-entropy loss called torch._C._nn.cross_entropy_loss; I can't find this function in the repo.
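For what it's worth, torch._C._nn.cross_entropy_loss appears to be a binding to the C++ (ATen) implementation, so there is no Python body to read there. The documented behaviour of nn.CrossEntropyLoss is equivalent to log_softmax followed by nll_loss, which is easy to check directly; a small sanity-check sketch (not the library's source):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))

# cross_entropy should equal nll_loss applied to the log-softmax of the logits
print(torch.allclose(F.cross_entropy(logits, targets),
                     F.nll_loss(F.log_softmax(logits, dim=1), targets)))  # expected: True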
Edit:
I noticed that the differences appear only when there are -100 tokens in the gold labels.
Demo example:
y = torch.randint(1, 50, (100, 50), dtype=float)  # logits, shape (100, 50)
x = torch.randint(1, 50, (100,))                  # targets, shape (100,)
x[40:] = -100                                     # last 60 targets are padding
print(criterion(y, x).item())
print(criterion2(y, x).item())

> 25.55788695847976
> 10.223154783391905
and when we don't have -100 tokens:
x[40:] = 30  # any valid class index
print(criterion(y, x).item())
print(criterion2(y, x).item())

> 24.684453267596453
> 24.684453267596453
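Interestingly, in the -100 case the two values differ by exactly a factor of 100/40, i.e. the full batch size over the number of non-padded targets. A quick check (assuming criterion is nn.CrossEntropyLoss and criterion2 is the manual class above, as in the outputs):

x[40:] = -100                              # restore the padded targets
print(criterion2(y, x).item() * 100 / 40)  # ≈ 25.55788695847976, same as criterion(y, x)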
Answer
I solved the problem by updating the code. I was already discarding the -100 tokens (the if-statement above), but I forgot to reduce the divisor accordingly (the batch size, called n_batch in the code above). After doing that, the loss values are identical to the nn.CrossEntropyLoss values. The final code:
class CrossEntropyLossManual:
    """
    y0 holds the logits with shape (batch_size, C).
    x has shape (batch_size,); its entries are integers from 0 to C-1,
    or ignore_index for padded positions.
    """
    def __init__(self, ignore_index=-100) -> None:
        self.ignore_index = ignore_index

    def __call__(self, y0, x):
        loss = 0.
        n_batch, n_class = y0.shape
        # print(n_class)
        for y1, x1 in zip(y0, x):
            class_index = int(x1.item())
            if class_index == self.ignore_index:
                n_batch -= 1  # don't count ignored tokens in the average
                continue
            loss = loss + torch.log(torch.exp(y1[class_index]) / torch.exp(y1).sum())
        loss = -loss / n_batch
        return loss
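A quick way to double-check the fix against the built-in loss (a small sanity check with made-up shapes, not taken from the original task):

import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8, 50)                # (batch_size, C)
targets = torch.randint(0, 50, (8,))
targets[5:] = -100                         # mark the tail as padding

manual = CrossEntropyLossManual(ignore_index=-100)
reference = nn.CrossEntropyLoss(ignore_index=-100)

print(manual(logits, targets).item())
print(reference(logits, targets).item())   # should match up to floating-point noise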