Trying to understand cross_entropy loss in PyTorch
That is because the input you give to the cross-entropy function should not be the probabilities, as you did, but the logits; PyTorch converts them into probabilities internally with the softmax formula:
probas = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
So the matrix of probabilities PyTorch will actually use in your case is:
[0.5761168847658291, 0.21194155761708547, 0.21194155761708547]
[0.21194155761708547, 0.5761168847658291, 0.21194155761708547]
[0.21194155761708547, 0.21194155761708547, 0.5761168847658291]
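As a quick sanity check, here is a minimal NumPy sketch (the 3x3 identity logits mirror the x tensor used below) that reproduces this matrix and the 0.5514 loss by hand:
import numpy as np

logits = np.eye(3)  # same values as the x tensor below
probas = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
print(probas)       # matches the matrix above

targets = np.array([0, 1, 2])
# cross entropy = mean of -log(probability assigned to the correct class)
loss = -np.log(probas[np.arange(3), targets]).mean()
print(loss)         # ~0.5514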
The torch.nn.functional.cross_entropy function combines log_softmax (softmax followed by a logarithm) and nll_loss (negative log-likelihood loss) in a single function, i.e. it is equivalent to F.nll_loss(F.log_softmax(x, 1), y).
Code:
import torch
import torch.nn.functional as F

# x holds the logits, y the target class indices
x = torch.FloatTensor([[1., 0., 0.],
                       [0., 1., 0.],
                       [0., 0., 1.]])
y = torch.LongTensor([0, 1, 2])
print(torch.nn.functional.cross_entropy(x, y))
print(F.softmax(x, 1).log())
print(F.log_softmax(x, 1))
print(F.nll_loss(F.log_softmax(x, 1), y))
Output:
tensor(0.5514)
tensor([[-0.5514, -1.5514, -1.5514],
        [-1.5514, -0.5514, -1.5514],
        [-1.5514, -1.5514, -0.5514]])
tensor([[-0.5514, -1.5514, -1.5514],
        [-1.5514, -0.5514, -1.5514],
        [-1.5514, -1.5514, -0.5514]])
tensor(0.5514)
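If you really do start from probabilities rather than logits, one option (a sketch, not the only way) is to take their log and feed that to F.nll_loss directly, since that skips the softmax that cross_entropy would otherwise apply; the probability matrix here is just a made-up example:
import torch
import torch.nn.functional as F

# hypothetical case: you already have row-normalized probabilities, not logits
probs = torch.FloatTensor([[0.90, 0.05, 0.05],
                           [0.05, 0.90, 0.05],
                           [0.05, 0.05, 0.90]])
y = torch.LongTensor([0, 1, 2])
print(F.nll_loss(probs.log(), y))   # tensor(0.1054), i.e. -log(0.9)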
Read more about the torch.nn.functional.cross_entropy loss function in the PyTorch documentation.