
How to efficiently draw a plot of a torch.nn model?

I’m exploring neural networks, and I want to model some pictures with a neural network. A picture is a function that maps pixel coordinates to a color, so my network also has 2 input variables (x, y) and either 1 output (shade) or 3 outputs (R, G, B). For example, like this:
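A network of this shape could look like the following minimal sketch. The hidden width, the `Tanh` activations and the final `Sigmoid` are assumptions for illustration, not the poster's actual architecture; for a single shade, the last `nn.Linear` would have 1 output instead of 3.

```python
import torch
import torch.nn as nn

# Hypothetical architecture: 2 inputs (x, y) -> 3 outputs (R, G, B).
# Hidden width and activations are illustrative assumptions.
net = nn.Sequential(
    nn.Linear(2, 32),
    nn.Tanh(),
    nn.Linear(32, 32),
    nn.Tanh(),
    nn.Linear(32, 3),
    nn.Sigmoid(),  # squashes each color channel into [0, 1]
)

# One pixel coordinate in, one RGB triple out: shape (3,).
rgb = net(torch.tensor([0.5, -0.25]))
```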


Now, I plot it like this:


But the result looks ugly, and it is slow because it uses Python lists instead of NumPy arrays or tensors.
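The slow pattern described above presumably amounts to one forward pass per pixel. The sketch below (the function name `net_to_image_slow` and the [-1, 1] coordinate range are hypothetical) shows that pattern: width x height separate calls to the network, each on a single coordinate pair, with every result converted back into nested Python lists.

```python
import torch
import torch.nn as nn


def net_to_image_slow(net: nn.Module, width: int, height: int) -> list:
    """Evaluate `net` pixel by pixel: width * height separate forward passes."""
    # Hypothetical coordinate range [-1, 1] for both axes.
    xs = torch.linspace(-1.0, 1.0, width).tolist()
    ys = torch.linspace(-1.0, 1.0, height).tolist()
    rows = []
    with torch.no_grad():
        for y in ys:
            row = []
            for x in xs:
                out = net(torch.tensor([x, y]))  # one coordinate pair -> (out_features,)
                row.append(out.tolist())         # back to a Python list for every pixel
            rows.append(row)
    return rows  # nested lists: height x width x out_features
```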

I have another version of the code that draws images from functions; it looks better and is 100x faster:
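Assuming `draw_image2` evaluates the function on a NumPy meshgrid, which is what would make it fast, its core could look like the hypothetical `eval_function_image` below. The key point is that `f` receives two whole (height, width) arrays rather than two scalars, so `f` must be vectorized.

```python
import numpy as np


def eval_function_image(f, width: int = 200, height: int = 200) -> np.ndarray:
    """Evaluate a vectorized function f(x, y) on a whole pixel grid at once."""
    # Hypothetical coordinate range [-1, 1] for both axes.
    xs = np.linspace(-1.0, 1.0, width)
    ys = np.linspace(-1.0, 1.0, height)
    grid_x, grid_y = np.meshgrid(xs, ys)  # each of shape (height, width)
    return f(grid_x, grid_y)              # one vectorized call, no Python loop
```

With matplotlib, `plt.imshow(eval_function_image(lambda x, y: x + y))` would then display the result.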


It works for functions that use NumPy operations (like lambda x, y: x + y), but when I plug in my net the same way as the previous function (draw_image2(lambda x, y: net(torch.Tensor([x, y])).item())), I get RuntimeError: mat1 and mat2 shapes cannot be multiplied (400x200 and 2x2), which I understand as my neural net complaining that it wants to be fed data in smaller pieces.
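The shapes in the error message can be reproduced with a short sketch. It assumes a 200x200 image and a first layer `nn.Linear(2, 2)`; the "2x2" in the message is the transposed weight of shape (in_features, out_features), so both sizes are 2. Stacking two (200, 200) grids puts the coordinates in the first dimension, and `nn.Linear` then reads the last dimension (size 200) as its features.

```python
import numpy as np
import torch
import torch.nn as nn

# Assumption: a first layer nn.Linear(2, 2), as the "2x2" in the error suggests.
layer = nn.Linear(2, 2)

# Assumption: a 200x200 image, so each coordinate grid has shape (200, 200).
xs, ys = np.meshgrid(np.linspace(-1, 1, 200), np.linspace(-1, 1, 200))

# torch.Tensor([x, y]) stacks the grids along a NEW FIRST axis: shape (2, 200, 200).
stacked = torch.tensor(np.stack([xs, ys]), dtype=torch.float32)
print(stacked.shape)  # torch.Size([2, 200, 200])

# nn.Linear treats the LAST axis (size 200) as the feature axis; the remaining
# 2 * 200 = 400 rows give a 400x200 matrix that cannot multiply a 2x2 weight.
try:
    layer(stacked)
    error = None
except RuntimeError as exc:
    error = exc
print(error)
```

Even with the shape fixed, the trailing .item() would fail, because .item() only works on tensors with exactly one element.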

Is there a proper way to plot the output of a PyTorch neural network?


Answer

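To feed a whole batch into nn.Linear(i, o), the input typically has the shape (b, i), where b is the batch size. If we take a look at the documentation, you can actually use additional “batch” dimensions in front: nn.Linear accepts any input of shape (*, i) and returns shape (*, o), as long as the features are in the last dimension. That is also why your call fails: torch.Tensor([x, y]) puts the two coordinates in the first dimension instead of the last. Since PyTorch was primarily made for deep learning based on stochastic gradient descent, pretty much all of its modules are designed around having at least one batch dimension.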
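A small sketch of the (*, i) to (*, o) rule for nn.Linear; the sizes here are arbitrary. Any number of leading dimensions is treated as batch dimensions, and only the last one must match the layer's input size.

```python
import torch
import torch.nn as nn

layer = nn.Linear(2, 3)  # i = 2 input features, o = 3 output features

batch = torch.randn(5, 2)    # (b, i)
grid = torch.randn(4, 6, 2)  # extra batch dimensions in front: (*, i)

print(layer(batch).shape)  # torch.Size([5, 3])
print(layer(grid).shape)   # torch.Size([4, 6, 3])
```

For an image, a coordinate tensor of shape (height, width, 2) therefore yields an output of shape (height, width, o) in one call.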

So you could easily modify your second plotting function to something like:
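One possible version, sketched under assumptions: the hypothetical helper `net_to_image` builds the coordinate grid itself, maps both axes to [-1, 1], and evaluates every pixel in a single forward pass with the coordinates in the last dimension. To use a GPU, pass `device="cuda" if torch.cuda.is_available() else "cpu"`.

```python
import numpy as np
import torch
import torch.nn as nn


def net_to_image(net: nn.Module, width: int = 200, height: int = 200,
                 device: str = "cpu") -> np.ndarray:
    """Evaluate `net` on every pixel in one batched forward pass.

    Returns an array of shape (height, width, out_features).
    """
    # Hypothetical normalization: both axes mapped to [-1, 1].
    xs = torch.linspace(-1.0, 1.0, width)
    ys = torch.linspace(-1.0, 1.0, height)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")  # each (height, width)

    # Coordinates in the LAST dimension: (height, width, 2).
    coords = torch.stack([grid_x, grid_y], dim=-1).to(device)

    net.to(device)   # Module.to moves the parameters in place
    net.eval()       # note: this switches the caller's network to eval mode
    with torch.no_grad():   # no autograd graph is needed for plotting
        out = net(coords)   # (height, width, out_features)
    return out.cpu().numpy()


def draw_net_image(net: nn.Module, width: int = 200, height: int = 200) -> None:
    import matplotlib.pyplot as plt  # only needed for display
    img = net_to_image(net, width, height)
    if img.shape[-1] == 1:
        plt.imshow(img[..., 0], cmap="gray")  # single shade channel
    else:
        plt.imshow(np.clip(img, 0.0, 1.0))    # imshow expects float RGB in [0, 1]
    plt.axis("off")
    plt.show()
```

Usage would be draw_net_image(net) for a net with 1 or 3 outputs; compared to a per-pixel loop, the network is called once instead of width x height times.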


Note that with torch.no_grad() is not necessary for it to work, but it saves time and memory because no autograd graph is built. Depending on your network architecture, it might also be worth setting your network to eval mode (net.eval()) first, for example if it contains dropout or batch normalization layers. Finally, the .to(device)/.cpu() calls are also unnecessary if you’re not using your GPU.
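The effect of both settings can be seen in a small sketch; the network with a Dropout layer is hypothetical. In the default training mode the output is tracked by autograd and dropout is random; after net.eval() and inside torch.no_grad(), the output is untracked and repeated calls agree.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical network that contains dropout, so eval mode matters.
net = nn.Sequential(nn.Linear(2, 16), nn.Dropout(p=0.5), nn.Linear(16, 3))
coords = torch.rand(10, 2)

# Default (training) mode with autograd: output is part of a graph, dropout is active.
out_train = net(coords)
print(out_train.requires_grad)  # True

# Inference: eval() disables dropout, no_grad() skips building the autograd graph.
net.eval()
with torch.no_grad():
    out_a = net(coords)
    out_b = net(coords)
print(out_a.requires_grad)        # False
print(torch.equal(out_a, out_b))  # True
```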

User contributions licensed under: CC BY-SA