
Confusion when displaying a PyTorch image tensor with matplotlib.pyplot

I have this error: TypeError: Invalid shape (28, 28, 1) for image data

Here is my code:

import torch
import torchvision
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
%matplotlib inline

# Load dataset

!wget www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxvf MNIST.tar.gz

dataset = MNIST(root = './', train=True, download=True, transform=ToTensor())
#val_data = MNIST(root = './', train=False, download=True, transform=transform)

image, label = dataset[0]
print('image.shape:', image.shape)
plt.imshow(image.permute(1, 2, 0), cmap='gray') # HELP WITH THIS LINE
print('Label:', label)

I know that PyTorch stores images as C x H x W, while matplotlib expects H x W x C, yet when I permute the tensor into matplotlib's layout I still get the error above. Am I missing something? Why does this happen?
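A minimal sketch with NumPy shows what the permute actually produces. It assumes a zero-filled array as a stand-in for the 1 x 28 x 28 tensor that ToTensor() returns for an MNIST digit; NumPy's transpose(1, 2, 0) reorders axes the same way Tensor.permute(1, 2, 0) does. The channel axis moves to the end, but it is still there, with length 1:

```python
import numpy as np

# Stand-in (assumption) for the tensor ToTensor() returns for one MNIST digit:
# channels-first (C x H x W) with a single grayscale channel.
chw = np.zeros((1, 28, 28), dtype=np.float32)

# Same axis reordering as image.permute(1, 2, 0): C x H x W -> H x W x C.
hwc = chw.transpose(1, 2, 0)

print(chw.shape)  # (1, 28, 28)
print(hwc.shape)  # (28, 28, 1)
```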


Answer

plt.imshow() expects either a 2D array of shape (H, W), which it maps through the colormap, or a 3D array whose last dimension is 3 (RGB) or 4 (RGBA). In your case the permuted array has shape (28, 28, 1): it is 3D, but its last dimension is 1, so it matches neither form.
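The shape rule above can be written down as a small check. This is a hypothetical helper for illustration only (it is not part of matplotlib, and the exact checks matplotlib performs may differ between versions); it simply encodes the accepted layouts described in this answer:

```python
def is_valid_imshow_shape(shape):
    """Hypothetical helper: True if `shape` is one of the layouts described
    above, i.e. 2-D (H, W), or 3-D (H, W, 3) RGB / (H, W, 4) RGBA."""
    if len(shape) == 2:
        return True
    return len(shape) == 3 and shape[-1] in (3, 4)


print(is_valid_imshow_shape((28, 28)))     # True
print(is_valid_imshow_shape((28, 28, 3)))  # True
print(is_valid_imshow_shape((28, 28, 1)))  # False: 3-D, but only one channel
print(is_valid_imshow_shape((1, 28, 28)))  # False: channels-first layout
```

Both the unpermuted (1, 28, 28) tensor and the permuted (28, 28, 1) one fail the check, which is why permuting alone does not fix the error.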

So the length-1 last dimension has to be squeezed out, leaving a 2D (28, 28) array that meets imshow()'s requirements:

plt.imshow(np.squeeze(image.permute(1, 2, 0), axis = 2), cmap='gray')
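Since there is only one channel, the permute is not strictly needed: dropping the channel axis directly gives the same 2D array. A minimal NumPy sketch, assuming a random array as a stand-in for the (1, 28, 28) tensor, checks that both routes agree. With a real PyTorch tensor the shorter form would be image.squeeze(0) (or image[0]), e.g. plt.imshow(image.squeeze(0), cmap='gray').

```python
import numpy as np

# Stand-in (assumption) for the (1, 28, 28) tensor from ToTensor(); random
# values make the equality check below meaningful.
rng = np.random.default_rng(0)
chw = rng.random((1, 28, 28), dtype=np.float32)

# The answer's route: move channels last, then drop the length-1 axis.
via_permute = np.squeeze(chw.transpose(1, 2, 0), axis=2)

# Shorter route: drop the channel axis straight away.
via_squeeze = np.squeeze(chw, axis=0)

print(via_permute.shape, via_squeeze.shape)      # (28, 28) (28, 28)
print(np.array_equal(via_permute, via_squeeze))  # True
```

For multi-channel RGB images the permute is still required, since only a length-1 axis can be squeezed away.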
User contributions licensed under: CC BY-SA