I have a batch of images with shape
torch.Size([10, 512, 512, 3])
I can loop over the batch and view all 10 images. But to feed this batch to PyTorch, I have to convert it to
torch.Size([10, 3, 512, 512])
I have tried many approaches but have not found a solution.
How can this be done?
Answer
Use permute:
import torch
x = torch.rand(10, 512, 512, 3)
y = x.permute(0, 3, 1, 2)
x.shape: torch.Size([10, 512, 512, 3])
y.shape: torch.Size([10, 3, 512, 512])
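
A small follow-up sketch, in case it helps: permute returns a view over the same data rather than a copy, so the result is not contiguous in memory, and reshape is not a substitute because it keeps the element order and only changes the shape. The variable names y_c and z below are just illustrative, and whether you need the contiguous() call depends on what you do with the tensor afterwards.

```python
import torch

# Batch in channels-last layout: (batch, height, width, channels).
x = torch.rand(10, 512, 512, 3)

# Reorder the axes to (batch, channels, height, width).
y = x.permute(0, 3, 1, 2)

# Each channel plane of y matches the corresponding channel slice of x.
print(torch.equal(y[0, 1], x[0, :, :, 1]))

# permute only changes strides, so y is a non-contiguous view of x.
print(y.is_contiguous())

# contiguous() copies the data into a standard (N, C, H, W) memory layout.
# Some operations, such as view(), require this.
y_c = y.contiguous()
print(y_c.is_contiguous())

# reshape gives the same shape but does NOT move the channel axis:
# it reinterprets the flat data, which scrambles the images.
z = x.reshape(10, 3, 512, 512)
print(z.shape == y.shape, torch.equal(z, y))
```

If the images start out as a NumPy array, the same permute call should work after converting with torch.from_numpy.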