I have this error: TypeError: Invalid shape (28, 28, 1) for image data
Here is my code:
    import torch
    import torchvision
    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt
    import torch.nn as nn
    import torch.nn.functional as F
    from torchvision.datasets import MNIST
    from torchvision.transforms import ToTensor
    from torchvision.utils import make_grid
    from torch.utils.data.dataloader import DataLoader
    from torch.utils.data import random_split
    %matplotlib inline

    # Load dataset
    !wget www.di.ens.fr/~lelarge/MNIST.tar.gz
    !tar -zxvf MNIST.tar.gz

    from torchvision.datasets import MNIST

    dataset = MNIST(root = './', train=True, download=True, transform=ToTensor())
    #val_data = MNIST(root = './', train=False, download=True, transform=transform)

    image, label = dataset[0]
    print('image.shape:', image.shape)
    plt.imshow(image.permute(1, 2, 0), cmap='gray') # HELP WITH THIS LINE
    print('Label:', label)
I know that PyTorch stores images as C x H x W and that matplotlib expects H x W x C, yet when I permute the tensor to matplotlib’s layout, I get the error above. Am I missing something? Why does this happen?
Answer
plt.imshow() expects either a 2D array (H x W, scalar data drawn through the colormap) or a 3D array whose last dimension is 3 (RGB) or 4 (RGBA). In your case the array has shape (28, 28, 1), which is treated as a 3D array, and 1 is not a valid size for its last dimension. So the last dimension should be squeezed out to match imshow()'s requirements.
plt.imshow(np.squeeze(image.permute(1, 2, 0), axis = 2), cmap='gray')
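The same rule can be wrapped in a small helper so that it works for both grayscale and color images. This is only a sketch: the name to_imshow_array is hypothetical (it is not part of PyTorch or matplotlib), and it assumes the input is array-like in C x H x W order, for example the result of image.numpy() on a CPU tensor.

```python
import numpy as np


def to_imshow_array(chw):
    """Convert a C x H x W array to a layout that plt.imshow accepts.

    Hypothetical helper for illustration only.
    """
    arr = np.asarray(chw)
    if arr.ndim != 3:
        raise ValueError(f"expected a 3D C x H x W array, got shape {arr.shape}")
    hwc = np.transpose(arr, (1, 2, 0))  # C x H x W -> H x W x C
    if hwc.shape[2] == 1:
        # Single channel: drop the trailing axis so imshow sees a 2D array.
        return hwc[:, :, 0]
    # 3 (RGB) or 4 (RGBA) channels can be passed through unchanged.
    return hwc


# Stand-ins for an MNIST-sized grayscale image and an RGB image.
gray = np.zeros((1, 28, 28), dtype=np.float32)
rgb = np.zeros((3, 28, 28), dtype=np.float32)

print(to_imshow_array(gray).shape)  # (28, 28)
print(to_imshow_array(rgb).shape)   # (28, 28, 3)
```

With the code from the question, the call would then look like plt.imshow(to_imshow_array(image.numpy()), cmap='gray'). For a single-channel tensor, plt.imshow(image.squeeze(0), cmap='gray') should give the same result without the permute.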