
PyTorch softmax: What dimension to use?

The function torch.nn.functional.softmax takes two parameters: input and dim. According to its documentation, the softmax operation is applied to all slices of input along the specified dim, and will rescale them so that the elements lie in the range (0, 1) and sum to 1.
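
For example, here is a minimal sketch (assuming a small 2d input) of what "rescale so they sum to 1" means:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3)
probs = F.softmax(x, dim=1)   # softmax over the last dimension
print(probs)                  # all values lie in (0, 1)
print(probs.sum(dim=1))       # each row sums to 1 -> tensor([1., 1.])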

Let input be:

input = torch.randn((3, 4, 5, 6))

Suppose I want to take the following sum, so that every entry in the resulting tensor is 1:

sum = torch.sum(input, dim = 3, keepdim = True) # sum's size is (3, 4, 5, 1)

How should I apply softmax?

softmax(input, dim = 0) # Way Number 0
softmax(input, dim = 1) # Way Number 1
softmax(input, dim = 2) # Way Number 2
softmax(input, dim = 3) # Way Number 3

My intuition tells me that it is the last one, but I am not sure. English is not my first language, and the use of the word "along" seemed confusing to me because of that.

I am not very clear on what "along" means, so I will use an example that could clarify things. Suppose we have a tensor of size (s1, s2, s3, s4), and I want the entries along the last dimension to sum to 1.


Answer

The easiest way I can think of to explain it is this: say you are given a tensor of shape (s1, s2, s3, s4), and, as you mentioned, you want the sum of all the entries along the last axis to be 1.

sum = torch.sum(input, dim = 3) # input is of shape (s1, s2, s3, s4)

Then you should call the softmax as:

softmax(input, dim = 3)
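
As a quick sanity check, here is a minimal sketch (using the shapes from the question) confirming that the entries then sum to 1 along dim 3:

import torch
import torch.nn.functional as F

input = torch.randn(3, 4, 5, 6)
out = F.softmax(input, dim=3)
print(out.sum(dim=3))        # every entry is 1 (up to floating-point error)
print(out.sum(dim=3).shape)  # torch.Size([3, 4, 5])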

To see this more easily, you can think of a 4d tensor of shape (s1, s2, s3, s4) as a 2d tensor, i.e. a matrix, of shape (s1*s2*s3, s4). Now if you want each column of that matrix to sum to 1 (normalizing along axis 0), or each row to sum to 1 (normalizing along axis 1), you can simply call the softmax function on the 2d tensor as follows:

softmax(input, dim = 0) # normalizes along axis 0, so each column sums to 1
softmax(input, dim = 1) # normalizes along axis 1, so each row sums to 1
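
For instance, a small 2d sketch of the two cases:

import torch
import torch.nn.functional as F

m = torch.randn(2, 3)
print(F.softmax(m, dim=0).sum(dim=0))   # columns sum to 1 -> tensor([1., 1., 1.])
print(F.softmax(m, dim=1).sum(dim=1))   # rows sum to 1 -> tensor([1., 1.])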

You can see the example that Steven mentioned in his answer.
