
PyTorch custom loss function vs. nn.CrossEntropyLoss

After studying autograd, I tried to write a loss function myself. Here it is:

import torch
import torch.nn.functional as F

def myCEE(outputs, targets):
    # log of the softmax denominator per sample: log(sum_j exp(l_j))
    exp = torch.exp(outputs)
    A = torch.log(torch.sum(exp, dim=1))

    # logit of the target class per sample, selected with a one-hot mask
    hadamard = F.one_hot(targets, num_classes=10).float() * outputs
    B = torch.sum(hadamard, dim=1)
    return torch.sum(A - B)

and I compared it with torch.nn.CrossEntropyLoss.

Here are the results:

# grab one batch from the training DataLoader
for i, j in train_dl:
    inputs = i
    targets = j
    break

outputs = model(inputs)

myCEE(outputs, targets)                                 : tensor(147.5397, grad_fn=<SumBackward0>)
nn.CrossEntropyLoss(reduction='sum')(outputs, targets)  : tensor(147.5397, grad_fn=<NllLossBackward>)

The values were the same.
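
For reference, the same agreement can be reproduced with a small self-contained check (the batch size, class count, and random logits below are illustrative, not from my actual model):

import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8, 10, requires_grad=True)   # fake batch: 8 samples, 10 classes
labels = torch.randint(0, 10, (8,))

custom  = myCEE(logits, labels)
builtin = nn.CrossEntropyLoss(reduction='sum')(logits, labels)
print(torch.allclose(custom, builtin))             # True for well-scaled logits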

I thought that, since these are different functions, the grad_fn would be different, but that it wouldn't cause any problems.

But something happened!

After 4 epochs, the loss values turned to nan.

In contrast to myCEE, training with nn.CrossEntropyLoss went well.

So, I wonder if there is a problem with my function.

After reading some posts about nan problems, I added more convolution layers to the model.

As a result, a 39-epoch training run did not produce an error.

Nevertheless, I'd like to know the difference between myCEE and nn.CrossEntropyLoss.


Answer

torch.nn.CrossEntropyLoss differs from your implementation because it uses a trick to counter the unstable computation of the exponential for numerically large values. Given the output logits {l_1, ..., l_j, ..., l_n}, the softmax is defined as:

softmax(l_i) = exp(l_i) / sum_j(exp(l_j))
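
To see the instability concretely, here is a small sketch (the logit value of 1000 is just an illustrative extreme): exponentiating a large logit directly overflows to inf, whereas PyTorch's torch.logsumexp, which applies the trick described below, stays finite:

import torch

logits = torch.tensor([[1000.0, 0.0, 0.0]])     # one very large logit
print(torch.exp(logits))                         # tensor([[inf, 1., 1.]])  <- exp overflows
print(torch.log(torch.exp(logits).sum(dim=1)))   # tensor([inf])            <- naive log-sum-exp
print(torch.logsumexp(logits, dim=1))            # tensor([1000.])          <- stable version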

The trick is to multiply both the numerator and the denominator by exp(-β):

softmax(l_i) = exp(l_i)*exp(-β) / [sum_j(exp(l_j))*exp(-β)]
             = exp(l_i-β) / sum_j(exp(l_j-β))

Then the log-softmax comes down to:

logsoftmax(l_i) = l_i - β - log[sum_j(exp(l_j-β))]

In practice, β is chosen as the highest logit value, i.e. β = max_j(l_j).
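
Applied to myCEE from the question, a stable variant could look like the sketch below. This is only an illustration of the trick (nn.CrossEntropyLoss applies it internally, but not with this exact code):

import torch
import torch.nn.functional as F

def stable_CEE(outputs, targets):
    # beta = per-sample maximum logit, subtracted before exponentiating
    beta = outputs.max(dim=1, keepdim=True).values
    A = beta.squeeze(1) + torch.log(torch.sum(torch.exp(outputs - beta), dim=1))
    B = torch.sum(F.one_hot(targets, num_classes=outputs.size(1)).float() * outputs, dim=1)
    return torch.sum(A - B)

# even with huge logits the stable version stays finite
big_logits = torch.tensor([[1000.0, 0.0], [0.0, 500.0]])
labels = torch.tensor([0, 1])
print(stable_CEE(big_logits, labels))                                   # tensor(0.)
print(torch.nn.CrossEntropyLoss(reduction='sum')(big_logits, labels))   # tensor(0.)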

You can read more about it in this question: Numerically Stable Softmax.
