Convert 5D tensor to 4D tensor in PyTorch

In PyTorch I have a 5D tensor X of dimensions B x 9 x C x H x W. I want to convert it into a 4D tensor Y of dimensions B x 9C x H x W, such that the nine C x H x W slices are concatenated channel-wise.

To illustrate, let

a = X[1,0,:,:,:]
b = X[1,1,:,:,:]
c = X[1,2,:,:,:]
...
i = X[1,8,:,:,:]

Then, in the tensor Y, a through i should be concatenated along the channel dimension.
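For reference, the target Y described above can be built explicitly with `torch.cat`, slice by slice. This is a minimal sketch; the sizes (B=2, C=3, H=W=4) are arbitrary example values chosen only for illustration, and the result is useful as a ground truth to compare any shortcut against.

```python
import torch

# Example sizes, chosen only for illustration
B, N, C, H, W = 2, 9, 3, 4, 4
X = torch.randn(B, N, C, H, W)

# X[:, k] has shape B x C x H x W; concatenating the 9 slices along dim 1
# stacks a, b, ..., i channel-wise, giving B x 9C x H x W
Y_cat = torch.cat([X[:, k] for k in range(N)], dim=1)

print(Y_cat.shape)  # torch.Size([2, 27, 4, 4])
```

Channels `k*C` to `(k+1)*C - 1` of `Y_cat` hold the slice `X[:, k]`.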


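You don't need to concatenate anything explicitly; simply reshape the tensor with torch.reshape. Because the 9 and C dimensions are adjacent and in that order, merging them keeps the elements in row-major order: the C channels of X[:, 0] come first, then those of X[:, 1], and so on, which is exactly the channel-wise concatenation you describe.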

import torch

b, n, c, h, w = X.shape        # n == 9
Y = X.reshape(b, n * c, h, w)  # merge dims 1 and 2: B x 9C x H x W
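To check the claim that the reshape matches the explicit concatenation, the sketch below compares it against `torch.cat` on random data, and also shows `torch.flatten` over dims 1 and 2 as an equivalent one-liner. The sizes are arbitrary example values chosen only for illustration.

```python
import torch

# Example sizes, chosen only for illustration
B, N, C, H, W = 2, 9, 3, 4, 4
X = torch.randn(B, N, C, H, W)

b, n, c, h, w = X.shape
Y = X.reshape(b, n * c, h, w)  # merge dims 1 and 2

# Explicit channel-wise concatenation of X[:, 0], ..., X[:, 8] for comparison
Y_cat = torch.cat([X[:, k] for k in range(n)], dim=1)

# torch.flatten over dims 1..2 (inclusive) gives the same result
Y_flat = torch.flatten(X, start_dim=1, end_dim=2)

print(torch.equal(Y, Y_cat), torch.equal(Y, Y_flat))  # True True
```

`reshape` returns a view when the memory layout allows it and copies otherwise, so it gives the correct result even if X is non-contiguous (for example after a `permute`), whereas `view` would raise an error in that case.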

Source: stackoverflow