I am trying to understand how torchvision interacts with matplotlib to produce a grid of images. It’s easy to generate images and display them iteratively:
    import torch
    import torchvision
    import matplotlib.pyplot as plt

    w = torch.randn(10,3,640,640)
    for i in range(0, 10):
        z = w[i]
        plt.imshow(z.permute(1,2,0))
        plt.show()
However, displaying these images in a grid does not seem to be as straightforward.
    w = torch.randn(10,3,640,640)
    grid = torchvision.utils.make_grid(w, nrow=5)
    plt.imshow(grid)

    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    <ipython-input-61-1601915e10f3> in <module>()
          1 w = torch.randn(10,3,640,640)
          2 grid = torchvision.utils.make_grid(w, nrow=5)
    ----> 3 plt.imshow(grid)

    /anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py in imshow(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, hold, data, **kwargs)
       3203                         filternorm=filternorm, filterrad=filterrad,
       3204                         imlim=imlim, resample=resample, url=url, data=data,
    -> 3205                         **kwargs)
       3206     finally:
       3207         ax._hold = washold

    /anaconda3/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax, *args, **kwargs)
       1853                         "the Matplotlib list!)" % (label_namer, func.__name__),
       1854                         RuntimeWarning, stacklevel=2)
    -> 1855             return func(ax, *args, **kwargs)
       1856
       1857         inner.__doc__ = _add_data_doc(inner.__doc__,

    /anaconda3/lib/python3.6/site-packages/matplotlib/axes/_axes.py in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, **kwargs)
       5485                               resample=resample, **kwargs)
       5486
    -> 5487         im.set_data(X)
       5488         im.set_alpha(alpha)
       5489         if im.get_clip_path() is None:

    /anaconda3/lib/python3.6/site-packages/matplotlib/image.py in set_data(self, A)
        651         if not (self._A.ndim == 2
        652                 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
    --> 653             raise TypeError("Invalid dimensions for image data")
        654
        655         if self._A.ndim == 3:

    TypeError: Invalid dimensions for image data
Even though PyTorch’s documentation indicates that w is the correct shape, Python says that it isn’t. So I tried to permute the indices of my tensor:
    w = torch.randn(10,3,640,640)
    grid = torchvision.utils.make_grid(w.permute(0,2,3,1), nrow=5)
    plt.imshow(grid)

    ---------------------------------------------------------------------------
    RuntimeError                              Traceback (most recent call last)
    <ipython-input-62-6f2dc6313e29> in <module>()
          1 w = torch.randn(10,3,640,640)
    ----> 2 grid = torchvision.utils.make_grid(w.permute(0,2,3,1), nrow=5)
          3 plt.imshow(grid)

    /anaconda3/lib/python3.6/site-packages/torchvision-0.2.1-py3.6.egg/torchvision/utils.py in make_grid(tensor, nrow, padding, normalize, range, scale_each, pad_value)
         83             grid.narrow(1, y * height + padding, height - padding)
         84                 .narrow(2, x * width + padding, width - padding)
    ---> 85                 .copy_(tensor[k])
         86             k = k + 1
         87     return grid

    RuntimeError: The expanded size of the tensor (3) must match the existing size (640) at non-singleton dimension 0
What’s happening here? How can I place a bunch of randomly generated images into a grid and display them?
Answer
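There’s a small mistake in your code. torchvision.utils.make_grid() expects a batch of shape (N, C, H, W), which is why permuting w before the call fails, and it returns a single tensor of shape (C, H, W) containing the grid of images. matplotlib, however, expects the channel dimension last, i.e. (H, W, C), so it is the returned grid that has to be permuted, not the input. Below is the code that works fine: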
    In [107]: import torchvision

    # sample input (10 RGB images containing just Gaussian noise)
    In [108]: batch_tensor = torch.randn(*(10, 3, 256, 256))   # (N, C, H, W)

    # make grid (2 rows and 5 columns) to display our 10 images
    In [109]: grid_img = torchvision.utils.make_grid(batch_tensor, nrow=5)

    # check shape
    In [110]: grid_img.shape
    Out[110]: torch.Size([3, 518, 1292])

    # permute axes and plot (because matplotlib needs channel as the last dimension)
    In [111]: plt.imshow(grid_img.permute(1, 2, 0))
    Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
    Out[111]: <matplotlib.image.AxesImage at 0x7f62081ef080>
which displays the ten images in a grid of 2 rows and 5 columns.