`torch.gather` without unbroadcasting

I have some batched input x of shape [batch, time, feature], and some batched indices i of shape [batch, new_time] which I want to gather into the time dim of x. As output of this operation I want a tensor y of shape [batch, new_time, feature] with values like this:

y[b, t', f] = x[b, i[b, t'], f]
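
To pin down the intended semantics, here is a minimal loop-based reference sketch. The helper name `gather_time_reference` is hypothetical, and the loops are only meant as a slow correctness check for faster versions, not as a real solution.

```python
import torch

def gather_time_reference(x: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    """Loop-based reference for y[b, t', f] = x[b, i[b, t'], f].

    x: [batch, time, feature], i: [batch, new_time] (integer indices).
    Hypothetical helper, intended only as a slow correctness check.
    """
    batch, new_time = i.shape
    y = x.new_empty(batch, new_time, x.shape[2])
    for b in range(batch):
        for t in range(new_time):
            # Copy the whole feature vector at time step i[b, t].
            y[b, t] = x[b, int(i[b, t])]
    return y

x = torch.arange(24.0).reshape(2, 3, 4)                # [batch=2, time=3, feature=4]
i = torch.tensor([[2, 0, 0, 1, 2], [1, 1, 0, 2, 0]])   # [batch=2, new_time=5]
y = gather_time_reference(x, i)                        # [batch=2, new_time=5, feature=4]
```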

In TensorFlow, I can accomplish this with the batch_dims: int argument of tf.gather: y = tf.gather(x, i, axis=1, batch_dims=1).

In PyTorch, I can think of some functions which do similar things:

  1. torch.gather of course, but this does not have an argument similar to TensorFlow’s batch_dims. The output of torch.gather always has the same shape as the indices, so I would need to unbroadcast the feature dim into i before passing it to torch.gather.

  2. torch.index_select, but here, the indices must be one-dimensional. So to make it work I would need to unbroadcast x to add a “batch * new_time” dim, and then after torch.index_select reshape the output.

  3. torch.nn.functional.embedding. Here, the embedding matrices would correspond to x. But this embedding function does not support batched weights, so I run into the same issue as for torch.index_select (looking at the code, torch.nn.functional.embedding uses torch.index_select under the hood).
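
For illustration, one way to make option 2 work is sketched below. It is not necessarily what the question has in mind: it flattens the batch and time dims of x into one dimension and shifts each batch's indices by b * time, so that torch.index_select can be used with one-dimensional indices. The variable names are illustrative.

```python
import torch

batch, time, feature, new_time = 2, 3, 4, 5
x = torch.rand(batch, time, feature)
i = torch.randint(0, time, (batch, new_time))

# Flatten [batch, time, feature] -> [batch * time, feature].
x_flat = x.reshape(batch * time, feature)

# Batch b starts at row b * time in the flattened tensor.
offsets = torch.arange(batch)[:, None] * time   # [batch, 1]
flat_idx = (i + offsets).reshape(-1)            # [batch * new_time]

# index_select needs 1-D indices; restore the batch dim afterwards.
y_flat = x_flat.index_select(0, flat_idx).reshape(batch, new_time, feature)
```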

Is it possible to accomplish such a gather operation without relying on unbroadcasting, which is inefficient for large dims?

Answer

This is a common situation: the input and index tensors do not have the same number of dimensions. You can still use torch.gather, though, since you can rewrite your expression:

y[b, t, f] = x[b, i[b, t], f]

as:

y[b, t, f] = x[b, i[b, t, f], f]

which ensures all three tensors have an equal number of dimensions. This reveals a third dimension on i, which we can create for free by unsqueezing a trailing dimension and expanding it to the feature size of x. The expansion is a view, so no index data is copied. You can do so with i[:, :, None].expand(-1, -1, x.size(-1)).

Here is a minimal example:

>>> import torch
>>> b = 2; t = 3; f = 4; new_t = 5
>>> x = torch.rand(b, t, f)
>>> i = torch.randint(0, t, (b, new_t))

>>> y = x.gather(1, i[:, :, None].expand(-1, -1, f))
>>> y.shape
torch.Size([2, 5, 4])
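
As a sanity check, the sketch below gathers with an index of shape [batch, new_time] expanded along the feature dimension, confirms that expand returns a view rather than a copy, and compares the result with integer-array indexing. The indexing variant is an alternative added here for cross-checking, not part of the approach described in this answer, and the variable names are illustrative.

```python
import torch

b, t, f, new_t = 2, 3, 4, 5
x = torch.rand(b, t, f)
i = torch.randint(0, t, (b, new_t))

# Expanded index: a view of i with a broadcast feature dimension.
idx = i[:, :, None].expand(-1, -1, f)       # [b, new_t, f]
y_gather = x.gather(1, idx)

# expand() does not copy: the new dimension has stride 0
# and the view shares its storage with i.
assert idx.stride(2) == 0
assert idx.data_ptr() == i.data_ptr()

# Alternative for comparison: integer-array indexing, where the
# batch index [b, 1] broadcasts against i [b, new_t].
y_index = x[torch.arange(b)[:, None], i]    # [b, new_t, f]
```

Both results should match elementwise, which torch.equal(y_gather, y_index) can confirm.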