I have two DNNs. The first one returns two outputs, and I want to use one of these outputs in a second class that represents another DNN, as in the following example.
I want to pass the output (x) to the second class so it can be concatenated with another variable (v). I found a workaround that makes x a global variable, but I need a cleaner, more efficient solution.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        ...
    def forward(self, x):
        ...
        return x, z

class Net2(nn.Module):
    def __init__(self):
        ...
    def forward(self, v):
        y = torch.cat((v, x), dim=1)  # x is not defined in this scope
        return y
Answer
You should not have to rely on global variables; this is solved by following common practice. Pass both v and x as arguments to the forward method of Net2. Something like:
import torch
import torch.nn as nn

class Net(nn.Module):
    def forward(self, x):
        z = x**2
        return x, z

class Net2(nn.Module):
    def forward(self, x, v):
        # x now arrives as an argument instead of a global
        y = torch.cat((v, x), dim=1)
        return y
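A note on the torch.cat call in Net2: it takes a tuple (or list) of tensors plus a dim, and every size except the one along dim must match. The sketch below assumes a hypothetical batch of 4 samples, with v having 20 features and x having 10, and shows why dim=1 (the feature dimension) is the right choice here rather than dim=0:

```python
import torch

# Hypothetical batch of 4 samples: v has 20 features, x has 10 features.
v = torch.rand(4, 20)
x = torch.rand(4, 10)

# Concatenate along dim=1 (features); the batch sizes along dim 0 must match.
y = torch.cat((v, x), dim=1)
print(y.shape)  # torch.Size([4, 30])

# Concatenating along dim=0 would require matching feature sizes (20 vs 10),
# so PyTorch raises a RuntimeError.
try:
    torch.cat((v, x), dim=0)
    dim0_ok = True
except RuntimeError as err:
    dim0_ok = False
    print("dim=0 failed:", err)
```

The order inside the tuple matters as well: with (v, x), the first 20 columns of y come from v and the last 10 from x.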
With dummy data:
>>> net = Net()
>>> net2 = Net2()
>>> input1 = torch.rand(1,10)
>>> input2 = torch.rand(1,20)
First inference:
>>> x, z = net(input1)
Second inference:
>>> out = net2(x, input2)
>>> out.shape
torch.Size([1, 30])
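If both networks are trained together, one common option (not the only one) is to wrap them in a parent nn.Module whose forward passes x from Net into Net2 explicitly. The following is a minimal sketch under stated assumptions: the wrapper name Combined, the nn.Linear layer sizes, the MSE loss, the random target and the SGD learning rate are all hypothetical. It shows that a single optimizer over model.parameters() covers both sub-networks and that the loss gradient flows back through x into Net:

```python
import torch
import torch.nn as nn
import torch.optim as optim

class Net(nn.Module):
    # Hypothetical first network: 10 input features -> two outputs x and z.
    def __init__(self):
        super().__init__()
        self.fc_x = nn.Linear(10, 10)
        self.fc_z = nn.Linear(10, 1)

    def forward(self, inp):
        return self.fc_x(inp), self.fc_z(inp)

class Net2(nn.Module):
    # Hypothetical second network: concatenates v (20 features) with x (10 features).
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(30, 1)

    def forward(self, x, v):
        return self.fc(torch.cat((v, x), dim=1))

class Combined(nn.Module):
    # Hypothetical wrapper: registers both networks as submodules and
    # hands x from Net to Net2 as a plain argument.
    def __init__(self):
        super().__init__()
        self.net = Net()
        self.net2 = Net2()

    def forward(self, input1, input2):
        x, z = self.net(input1)
        return self.net2(x, input2), z

torch.manual_seed(0)
model = Combined()
# One optimizer over the parameters of both sub-networks.
optimizer = optim.SGD(model.parameters(), lr=0.1)

input1 = torch.rand(8, 10)
input2 = torch.rand(8, 20)
target = torch.rand(8, 1)  # hypothetical regression target

out, z = model(input1, input2)
loss = nn.functional.mse_loss(out, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# The loss depends on x, so gradients reach Net's fc_x as well as Net2's fc.
print(model.net.fc_x.weight.grad is not None)  # True
print(model.net2.fc.weight.grad is not None)   # True
```

Because z is not used in the loss in this sketch, Net's fc_z layer receives no gradient; if z should also be trained, its own loss term would need to be added.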