I guess I have a pretty simple problem. Let’s take the following tensor of length 6:
t = torch.tensor([10., 20., 30., 40., 50., 60.])
Now I would like to access only the elements at specific indices, let’s say at [0, 3, 4], so I would like to return:
# expected output: tensor([10., 40., 50.])
I found torch.index_select, which worked great for a two-dimensional tensor, e.g. one of shape (2, 4), but not for a 1-d tensor like t above.
How can I access a set of elements of a 1-d tensor, given a list of indices, without using a for loop?
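For reference, here is a minimal sketch of the two-dimensional case described above, showing how the second argument of torch.index_select picks the axis to select along. The (2, 4) tensor m and its values are made up for illustration.

```python
import torch

# Hypothetical (2, 4) tensor; the values are arbitrary
m = torch.tensor([[1., 2., 3., 4.],
                  [5., 6., 7., 8.]])

# dim=0 selects whole rows (valid indices here are 0 and 1)
rows = torch.index_select(m, 0, torch.tensor([1]))
print(rows)  # tensor([[5., 6., 7., 8.]])

# dim=1 selects whole columns (valid indices here are 0 to 3)
cols = torch.index_select(m, 1, torch.tensor([0, 2]))
print(cols)  # tensor([[1., 3.], [5., 7.]]) (printed on two lines)
```

The same call shape applies to a 1-d tensor; it simply has only one axis to select along.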
Answer
You can in fact use index_select for this:
import torch

t = torch.tensor([10., 20., 30., 40., 50., 60.])
# Select the elements at indices 0, 3 and 4 along dimension 0
output = torch.index_select(t, 0, torch.LongTensor([0, 3, 4]))
# output: tensor([10., 40., 50.])
You just need to pass the dimension (0) as the second argument. A 1-d tensor has only one dimension, so 0 (or its negative alias -1) is the only valid choice.
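A minimal sketch comparing index_select with ordinary tensor indexing, which is an alternative the answer does not mention. Assuming a reasonably recent PyTorch, all four expressions below should select the same three elements; the names idx, a, b, c and d are only for illustration.

```python
import torch

t = torch.tensor([10., 20., 30., 40., 50., 60.])
idx = torch.tensor([0, 3, 4])  # integer tensor, dtype defaults to torch.int64

# index_select along the only axis; -1 names the same axis for a 1-d tensor
a = torch.index_select(t, 0, idx)
b = torch.index_select(t, -1, idx)

# Plain indexing with an integer tensor or a Python list gathers the same elements
c = t[idx]
d = t[[0, 3, 4]]

print(a)  # tensor([10., 40., 50.])
```

For a 1-d tensor, t[idx] is usually the more readable option; index_select is handy when the axis is a variable, since the dimension is just an argument.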