I have some batched input x of shape [batch, time, feature], and some batched indices i of shape [batch, new_time], which I want to gather into the time dim of x. As output of this operation I want a tensor y of shape [batch, new_time, feature] with values like this:
y[b, t', f] = x[b, i[b, t'], f]
In TensorFlow, I can accomplish this by using the batch_dims: int argument of tf.gather: y = tf.gather(x, i, axis=1, batch_dims=1).
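To pin down the target semantics before looking at PyTorch options, here is a deliberately naive loop that computes y[b, t', f] = x[b, i[b, t'], f] element by element. It is only a reference for checking faster versions; the function name gather_time_reference is hypothetical, and i is assumed to hold int64 indices in [0, time).

```python
import torch


def gather_time_reference(x: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    """Naive reference: y[b, t', f] = x[b, i[b, t'], f].

    x: [batch, time, feature], i: [batch, new_time] with int64 entries in [0, time).
    """
    batch, new_time = i.shape
    y = x.new_empty((batch, new_time, x.shape[-1]))
    for b in range(batch):
        for tp in range(new_time):
            # Copy the whole feature vector of the selected time step.
            y[b, tp] = x[b, int(i[b, tp])]
    return y
```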
In PyTorch, I can think of some functions which do similar things:
- torch.gather, of course, but this does not have an argument similar to TensorFlow's batch_dims. The output of torch.gather always has the same shape as the indices, so I would need to unbroadcast the feature dim into i before passing it to torch.gather.
- torch.index_select, but here the indices must be one-dimensional. So to make it work I would need to unbroadcast x to add a "batch * new_time" dim, and then reshape the output after torch.index_select.
- torch.nn.functional.embedding. Here, the embedding matrices would correspond to x. But this embedding function does not support batched weights, so I run into the same issue as for torch.index_select (looking at the code, torch.nn.functional.embedding uses torch.index_select under the hood).
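As an aside, the torch.index_select route does not strictly require copying x: offsetting each row of i by b * time lets the indices address rows of a flattened [batch * time, feature] view of x. The sketch below assumes x is contiguous (so reshape returns a view rather than a copy) and that i holds int64 indices in [0, time); the helper name gather_via_index_select is hypothetical, not a PyTorch API.

```python
import torch


def gather_via_index_select(x: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    """y[b, t', f] = x[b, i[b, t'], f] using one torch.index_select call.

    x: [B, T, F] (assumed contiguous), i: [B, T'] int64 in [0, T).
    """
    B, T, F = x.shape
    # Shift batch b's indices by b * T so they point into the flattened rows.
    offsets = torch.arange(B, device=i.device).unsqueeze(1) * T  # [B, 1]
    flat_idx = (i + offsets).reshape(-1)                          # [B * T']
    # For a contiguous x this reshape is a view, so x itself is not copied.
    rows = torch.index_select(x.reshape(B * T, F), 0, flat_idx)   # [B * T', F]
    return rows.reshape(B, i.size(1), F)
```

Only the small [batch, new_time] index tensor is built here; the output is the only tensor of size batch * new_time * feature.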
Is it possible to accomplish such a gather operation without relying on unbroadcasting, which is inefficient for large dims?
Answer
This is actually a very common case: the input and index tensors don't have the same number of dimensions. You can still use torch.gather, though, since you can rewrite your expression:
y[b, t, f] = x[b, i[b, t], f]
as:
y[b, t, f] = x[b, i[b, t, f], f]
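which ensures all three tensors have the same number of dimensions. This reveals a third dimension on i, which we can create essentially for free by unsqueezing a trailing dimension and expanding it to the feature size of x: expand returns a view with stride 0 along the new dimension, so nothing is copied. You can do so with i[..., None].expand(-1, -1, x.size(-1)), which has shape [batch, new_time, feature].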
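The "for free" part can be checked directly: after adding a trailing axis with i[..., None] (equivalently i.unsqueeze(-1)), expand produces a view whose expanded dimension has stride 0 and which shares i's storage. The shapes below (batch 2, new_time 4, feature 5) are arbitrary illustration values.

```python
import torch

i = torch.randint(0, 3, (2, 4))          # [batch, new_time], contiguous
index = i[..., None].expand(-1, -1, 5)   # [batch, new_time, feature] with feature = 5

print(index.shape)                        # torch.Size([2, 4, 5])
print(index.stride())                     # (4, 1, 0): stride 0 along the expanded axis
print(index.data_ptr() == i.data_ptr())   # True: same storage, nothing copied
```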
Here is a minimal example:
>>> import torch
>>> b, t, new_t, f = 2, 3, 4, 5
>>> x = torch.rand(b, t, f)
>>> i = torch.randint(0, t, (b, new_t))
>>> y = x.gather(1, i[..., None].expand(-1, -1, f))
>>> y.shape
torch.Size([2, 4, 5])
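Wrapped as a small helper, the same approach can be checked against the element-wise definition y[b, t', f] = x[b, i[b, t'], f]. The name gather_time is hypothetical (not part of PyTorch), and i is assumed to hold int64 indices in [0, time).

```python
import torch


def gather_time(x: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    """Return y with y[b, t', f] = x[b, i[b, t'], f].

    x: [batch, time, feature]; i: [batch, new_time], int64 in [0, time).
    """
    # Trailing feature axis on i, expanded as a stride-0 view (no copy of i).
    index = i.unsqueeze(-1).expand(-1, -1, x.size(-1))
    return torch.gather(x, 1, index)  # -> [batch, new_time, feature]


# Usage: compare with an explicit loop over the definition.
torch.manual_seed(0)
b, t, new_t, f = 2, 3, 4, 5
x = torch.rand(b, t, f)
i = torch.randint(0, t, (b, new_t))
y = gather_time(x, i)
expected = torch.stack(
    [torch.stack([x[bb, i[bb, tt]] for tt in range(new_t)]) for bb in range(b)]
)
print(y.shape, torch.equal(y, expected))
```

Note that new_time need not equal time here: the index only has to match x in every dimension except the gathered one.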