Select PyTorch tensor elements by list of indices

I guess I have a pretty simple problem. Let's take the following tensor of length 6:

t = torch.tensor([10., 20., 30., 40., 50., 60.])

Now I would like to access only the elements at specific indices, let's say at [0, 3, 4]. So I would like to return

# expected output
tensor([10., 40., 50.])

I found torch.index_select, which worked great for a two-dimensional tensor, e.g. one of shape (2, 4), but I could not get it to work for the given t.

How can I access a set of elements based on a given list of indices in a 1-d tensor without using a for loop?

Answer

You can in fact use index_select for this:

import torch

t = torch.tensor([10., 20., 30., 40., 50., 60.])
# index_select expects an integer index tensor; torch.tensor on Python ints gives int64
output = torch.index_select(t, 0, torch.tensor([0, 3, 4]))
# output: tensor([10., 40., 50.])

You just need to pass the dimension (0) as the second argument. A 1-d tensor has only one dimension, so 0 is the only one you can select along.
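As a side note, plain indexing with an index tensor or a Python list should give the same result for a 1-d tensor, which may be more convenient if you do not need index_select specifically. The sketch below compares the three spellings; the variable names (idx, by_index_select, by_indexing, by_list) are only illustrative and are not from the question.

```python
import torch

t = torch.tensor([10., 20., 30., 40., 50., 60.])
idx = torch.tensor([0, 3, 4])  # illustrative name; int64 index tensor

# Explicit call, selecting along dimension 0
by_index_select = torch.index_select(t, 0, idx)

# Indexing with an index tensor
by_indexing = t[idx]

# Indexing with a plain Python list of indices
by_list = t[[0, 3, 4]]

print(by_index_select)  # tensor([10., 40., 50.])
print(torch.equal(by_index_select, by_indexing))  # True
print(torch.equal(by_index_select, by_list))      # True
```

All three return a new tensor holding copies of the selected elements rather than a view into t.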

User contributions licensed under: CC BY-SA