
How to efficiently draw a plot of a torch.nn model?

I’m exploring neural networks, and I want to model some pictures with a neural network. A picture is a function that maps pixel coordinates to a color, so I build my network with 2 input variables (x, y) and 1 (shade) to 3 (R, G, B) outputs. For example, like this:

import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Linear(2, 2),
    nn.Sigmoid(),
    nn.Linear(2, 1),
)

Now, I plot it like this:

import matplotlib.pyplot as plt
import numpy as np

def draw_image1(f):
    image = []
    y = 1
    delta = 0.005
    while y > 0:
        x = 0
        row = []
        while x < 1:
            row.append(f(x, y))
            x += delta
        image.append(row)
        y -= delta
 
    plt.imshow(image, extent=[0, 1, 0, 1], cmap='winter')
    plt.draw()

draw_image1(lambda x, y: net(torch.Tensor([x, y])).item())

But it looks ugly and is slow because it uses Python lists instead of numpy arrays or tensors.

I have another version of code that draws images from functions, which looks better and is 100x faster:

def draw_image2(f):
    x = np.linspace(0, 1, num = 200)
    y = np.linspace(0, 1, num = 200)
    X, Y = np.meshgrid(x, y)

    image = f(X, Y)

    # row 0 of the meshgrid holds y = 0, so draw it at the bottom to match extent
    plt.imshow(image, extent=[0, 1, 0, 1], cmap='winter', origin='lower')
    plt.draw()

It works for functions that use numpy operations (like lambda x, y: x + y), but when I plug in my net the same way as for the previous function (draw_image2(lambda x, y: net(torch.Tensor([x, y])).item())), I get RuntimeError: mat1 and mat2 shapes cannot be multiplied (400x200 and 2x2), which I understand as my neural net complaining that it wants to be fed data in smaller pieces.

Is there a proper way to plot the output of a PyTorch neural network?


Answer

To feed a whole batch into nn.Linear(i, o), the input typically has the shape (b, i), where b is the batch size. According to the documentation, you can also use additional “batch” dimensions in front, as long as the feature dimension of size i comes last. Your lambda builds a tensor of shape (2, 200, 200), whose last dimension is 200 rather than 2; the layer flattens it to (400, 200) and cannot multiply that by its 2x2 weight matrix, hence the error. Since PyTorch was made primarily for deep learning based on stochastic gradient descent, most of its modules are designed to operate on batches.
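A quick shape check can make this concrete. The sketch below is only an illustration: the standalone layer and the random tensors are made up for the demonstration and are not part of the question's network.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # only so this sketch is reproducible

layer = nn.Linear(2, 1)  # illustrative layer: 2 input features, 1 output

flat = torch.rand(40000, 2)     # classic (b, i) batch
grid = torch.rand(200, 200, 2)  # two leading "batch" dimensions, features last

out_flat = layer(flat)
out_grid = layer(grid)
print(out_flat.shape)  # torch.Size([40000, 1])
print(out_grid.shape)  # torch.Size([200, 200, 1])

# Applying the layer to the grid matches flattening, applying, and reshaping back.
reshaped = layer(grid.reshape(-1, 2)).reshape(200, 200, 1)
same = torch.allclose(reshaped, out_grid, atol=1e-5)
print(same)

# With the feature dimension first, the last dimension is 200 instead of 2,
# so the matrix multiplication inside the layer fails.
bad = torch.rand(2, 200, 200)
failed = False
try:
    layer(bad)
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)
```

The only requirement is that the last dimension equals the layer's number of input features; every dimension before it is treated as a batch dimension.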

So you could easily modify your second plotting function to something like:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

net = nn.Sequential(
    nn.Linear(2, 2),
    nn.Sigmoid(),
    nn.Linear(2, 1),
)

def draw_image2(f):
    device = torch.device('cpu')  # or use your gpu alternatively
    with torch.no_grad():  # skip building the computation graph, no gradients are needed
        x = torch.linspace(0, 1, 200)
        y = torch.linspace(0, 1, 200)
        # indexing='xy' (PyTorch 1.10+) matches np.meshgrid: rows follow y, columns follow x
        X, Y = torch.meshgrid(x, y, indexing='xy')
        # the data dimension should be the last (2), as per documentation
        inp = torch.stack([X, Y], dim=2).to(device)  # shape = (200, 200, 2)
        image = f(inp)  # shape = (200, 200, 1)
        image = image[..., 0].detach().cpu()  # shape = (200, 200)
    # origin='lower' draws the y = 0 row at the bottom, consistent with extent
    plt.imshow(image, extent=[0, 1, 0, 1], cmap='winter', origin='lower')
    plt.show()
    return image

draw_image2(net)

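Note that with torch.no_grad() is not necessary for this to work, but it saves time and memory because no computation graph is built. Depending on your network architecture (for example, if it contains dropout or batch normalization layers), it might also be worth setting your network to eval mode (net.eval()) first. Finally, .to(device)/.cpu() is not necessary if you’re not using your GPU; if you are, the network itself has to be moved to the same device as well (net.to(device)).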
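If you would rather keep the numpy-based draw_image2 from the question unchanged, one option is to wrap the network so that it behaves like a numpy function of (X, Y). The sketch below is a minimal illustration under that assumption; as_numpy_function and net_rgb are hypothetical names introduced here, and the RGB network is just an example with a final Sigmoid so that its outputs land in the [0, 1] range that imshow expects for color images.

```python
import numpy as np
import torch
import torch.nn as nn

def as_numpy_function(model):
    """Hypothetical helper: wrap a torch module as f(X, Y) -> numpy array."""
    def f(X, Y):
        # Stack the coordinate grids so the feature dimension (2) comes last.
        coords = np.stack([X, Y], axis=-1).astype(np.float32)  # (H, W, 2)
        with torch.no_grad():
            out = model(torch.from_numpy(coords))  # (H, W, C)
        out = out.numpy()
        # Drop the channel axis for single-channel output, keep it for RGB.
        return out[..., 0] if out.shape[-1] == 1 else out
    return f

net = nn.Sequential(nn.Linear(2, 2), nn.Sigmoid(), nn.Linear(2, 1))

# Example RGB variant (an assumption, not from the question).
net_rgb = nn.Sequential(nn.Linear(2, 8), nn.Sigmoid(), nn.Linear(8, 3), nn.Sigmoid())

x = np.linspace(0, 1, num=200)
y = np.linspace(0, 1, num=200)
X, Y = np.meshgrid(x, y)

image = as_numpy_function(net)(X, Y)
rgb = as_numpy_function(net_rgb)(X, Y)
print(image.shape)  # (200, 200)
print(rgb.shape)    # (200, 200, 3)
```

With such a wrapper, the question's function could be called as draw_image2(as_numpy_function(net)); for the RGB array, the cmap argument of imshow is ignored.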

User contributions licensed under: CC BY-SA