I have some batched input x of shape [batch, time, feature], and some batched indices i of shape [batch, new_time] which I want to gather into the time dim of x. As output of this operation I want a tensor y of shape [batch, new_time, feature] with values like this:
y[b, t', f] = x[b, i[b, t'], f]
In TensorFlow, I can accomplish this by using the batch_dims argument (an int) of tf.gather: y = tf.gather(x, i, axis=1, batch_dims=1).
In PyTorch, I can think of some functions which do similar things:
- torch.gather of course, but this does not have an argument similar to TensorFlow's batch_dims. The output of torch.gather will always have the same shape as the indices. So I would need to unbroadcast the feature dim into i before passing it to torch.gather.
- torch.index_select, but here, the indices must be one-dimensional. So to make it work I would need to unbroadcast x to add a "batch * new_time" dim, and then reshape the output after torch.index_select.
- torch.nn.functional.embedding. Here, the embedding matrices would correspond to x. But this embedding function does not support batched weights, so I run into the same issue as for torch.index_select (looking at the code, torch.nn.functional.embedding uses torch.index_select under the hood).
Is it possible to accomplish such a gather operation without relying on unbroadcasting, which is inefficient for large dims?
Answer
This is actually the most common case: the input and index tensors do not have the same number of dimensions. You can still use torch.gather, though, because you can rewrite your expression:
y[b, t, f] = x[b, i[b, t], f]
as:
y[b, t, f] = x[b, i[b, t, f], f]
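which ensures all three tensors have an equal number of dimensions. This reveals a third dimension on i, which we can create for free by unsqueezing a trailing dimension and expanding it to the feature size of x. Since expand returns a view, no index data is copied. With i of shape [batch, new_time], you can do so with i[:, :, None].expand(-1, -1, x.size(-1)), which has shape [batch, new_time, feature].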
Here is a minimal example:
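>>> import torch
>>> b = 2; t = 3; new_t = 4; f = 5
>>> x = torch.rand(b, t, f)
>>> i = torch.randint(0, t, (b, new_t))  # indices into the time dim of x
>>> y = x.gather(1, i[:, :, None].expand(-1, -1, f))
>>> y.shape
torch.Size([2, 4, 5])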
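As a sanity check, the rewrite above can be wrapped in a small helper and compared against plain advanced indexing, which should give the same result. This is only a minimal sketch: the helper name batched_gather and the concrete shapes are illustrative assumptions, not part of the PyTorch API.

```python
import torch


def batched_gather(x: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: y[b, t', f] = x[b, i[b, t'], f].

    x has shape [batch, time, feature], i has shape [batch, new_time].
    """
    # Add a trailing axis to the indices and expand it over the feature dim.
    # expand returns a view, so the indices are not copied.
    index = i.unsqueeze(-1).expand(-1, -1, x.size(-1))
    return torch.gather(x, 1, index)


torch.manual_seed(0)
batch, time, new_time, feature = 2, 5, 3, 4
x = torch.rand(batch, time, feature)
i = torch.randint(0, time, (batch, new_time))

y = batched_gather(x, i)

# Cross-check with advanced indexing: a [batch, 1] tensor of batch positions
# broadcasts against i ([batch, new_time]) and selects whole feature vectors.
batch_idx = torch.arange(batch).unsqueeze(-1)
y_ref = x[batch_idx, i]
```

The advanced-indexing form is shorter when there is a single batch dim, while the gather form stays close to the index expression given in the question.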
