I have this error: TypeError: Invalid shape (28, 28, 1) for image data
Here is my code:
import torch
import torchvision
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
%matplotlib inline

# Load dataset

!wget www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxvf MNIST.tar.gz

from torchvision.datasets import MNIST

dataset = MNIST(root = './', train=True, download=True, transform=ToTensor())
#val_data = MNIST(root = './', train=False, download=True, transform=transform)

image, label = dataset[0]
print('image.shape:', image.shape)
plt.imshow(image.permute(1, 2, 0), cmap='gray') # HELP WITH THIS LINE
print('Label:', label)
I know that PyTorch stores images as C x H x W and that matplotlib expects H x W x C, yet after permuting the tensor to matplotlib's layout I still get this error. Am I missing something? Why does this happen?
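To see where the trailing 1 in the error comes from, here is a minimal sketch using a zero-filled NumPy array as a stand-in for one MNIST sample (the dummy array is an assumption for illustration; `np.transpose` plays the role of `Tensor.permute`). Reordering C x H x W to H x W x C moves the single channel axis to the end but does not remove it:

```python
import numpy as np

# Dummy stand-in for one MNIST sample as produced by ToTensor(): C x H x W
chw = np.zeros((1, 28, 28), dtype=np.float32)

# Reorder axes to H x W x C, the NumPy analogue of image.permute(1, 2, 0)
hwc = np.transpose(chw, (1, 2, 0))

print(chw.shape)  # (1, 28, 28)
print(hwc.shape)  # (28, 28, 1): the channel axis has moved, not disappeared
```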
Answer
plt.imshow() expects either a 2D array of shape (H, W) or a 3D array of shape (H, W, 3) or (H, W, 4), i.e. RGB or RGBA. In your case the permuted tensor has shape (28, 28, 1), so it is treated as a 3D array whose last dimension is neither 3 nor 4. Squeeze out that last dimension to get a (28, 28) array that matches imshow()'s requirements.
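The same rule can be wrapped in a small helper that picks the right conversion based on the channel count. This is a sketch only: `to_imshow_array` is a hypothetical name, and it assumes a NumPy-convertible C x H x W input. One channel is reduced to (H, W); three or four channels are moved last to (H, W, C):

```python
import numpy as np


def to_imshow_array(chw):
    """Convert a C x H x W array into a shape plt.imshow() accepts.

    Hypothetical helper: 1 channel -> (H, W); 3 or 4 channels -> (H, W, C).
    """
    chw = np.asarray(chw)
    if chw.ndim != 3:
        raise ValueError(f"expected a C x H x W array, got shape {chw.shape}")
    channels = chw.shape[0]
    if channels == 1:
        # Grayscale: drop the channel axis so imshow applies the colormap
        return chw[0]
    if channels in (3, 4):
        # RGB / RGBA: move the channel axis last, as imshow expects
        return np.transpose(chw, (1, 2, 0))
    raise ValueError(f"imshow cannot display {channels} channels directly")


print(to_imshow_array(np.zeros((1, 28, 28))).shape)  # (28, 28)
print(to_imshow_array(np.zeros((3, 32, 32))).shape)  # (32, 32, 3)
```

Raising on other channel counts is a design choice; a two-channel array, for example, has no standard imshow interpretation.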
plt.imshow(np.squeeze(image.permute(1, 2, 0), axis = 2), cmap='gray')
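For a single-channel image the permute is not strictly needed: dropping the channel axis first gives the same 2D array. The sketch below checks this on a random NumPy array standing in for the tensor (an assumption for illustration; with a PyTorch tensor the analogous calls would be `image.squeeze(0)` or `image[0]`) and renders it with the non-interactive Agg backend so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

# Dummy grayscale sample in C x H x W order (stand-in for the MNIST tensor)
rng = np.random.default_rng(0)
chw = rng.random((1, 28, 28))

# Option 1: the answer's approach, transpose to H x W x C then squeeze axis 2
a = np.squeeze(np.transpose(chw, (1, 2, 0)), axis=2)
# Option 2: drop the channel axis directly; no transpose needed for one channel
b = chw[0]

assert np.array_equal(a, b)

im = plt.imshow(b, cmap="gray")
print(im.get_array().shape)  # (28, 28)
```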