I have two DNNs, and the first one returns two outputs. I want to use one of these outputs in a second class that represents another DNN, as in the following example.
I want to pass the output `x` to the second class so that it can be concatenated with another variable `v`. One workaround I found is to make `x` a global variable, but I'd like a cleaner, more efficient solution.
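```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Net(nn.Module):
    def __init__(self):
        ...  # layers omitted

    def forward(self, x):
        ...  # computation omitted
        return x, z


class Net2(nn.Module):
    def __init__(self):
        ...  # layers omitted

    def forward(self, v):
        # Problem: x (an output of Net) is not defined in this scope
        y = torch.cat((v, x), dim=1)
        return y
```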
Answer
You should not have to rely on global variables; the common practice is to pass both `v` and `x` as arguments to the `forward` method of `Net2`. Something like:
```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def forward(self, x):
        z = x**2
        return x, z


class Net2(nn.Module):
    def forward(self, x, v):
        # Concatenate along the feature dimension (dim=1)
        y = torch.cat((v, x), dim=1)
        return y
```
With dummy data:
```python
>>> net = Net()
>>> net2 = Net2()
>>> input1 = torch.rand(1, 10)
>>> input2 = torch.rand(1, 20)
```
First inference:
```python
>>> x, z = net(input1)
```
Second inference:
```python
>>> out = net2(x, input2)
>>> out.shape
torch.Size([1, 30])
```
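If the two networks are meant to be trained together, one common pattern is to wrap them in a parent `nn.Module`, so the hand-off of `x` happens inside a single `forward` and autograd tracks it end to end. The sketch below is only an illustration: the layer sizes (`nn.Linear(10, 10)` in `Net`, `nn.Linear(30, 1)` in `Net2`) and the wrapper name `Pipeline` are hypothetical and do not come from the original post.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    # Hypothetical first network: one linear layer, returns two outputs.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        x = self.fc(x)
        z = x ** 2
        return x, z


class Net2(nn.Module):
    # Hypothetical second network: concatenates v with Net's output x.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(30, 1)  # 20 features from v + 10 from x

    def forward(self, x, v):
        y = torch.cat((v, x), dim=1)
        return self.fc(y)


class Pipeline(nn.Module):
    # Hypothetical wrapper that owns both sub-networks.
    def __init__(self):
        super().__init__()
        self.net = Net()
        self.net2 = Net2()

    def forward(self, input1, input2):
        x, z = self.net(input1)  # z is unused in this sketch
        return self.net2(x, input2)


model = Pipeline()
input1 = torch.rand(1, 10)
input2 = torch.rand(1, 20)
out = model(input1, input2)
out.sum().backward()

print(out.shape)                             # torch.Size([1, 1])
# Gradients reach the first network's weights through x.
print(model.net.fc.weight.grad is not None)  # True
```

Because `Pipeline` registers both sub-networks as attributes, `model.parameters()` yields the parameters of both, so a single optimizer (for example `optim.SGD(model.parameters(), lr=0.01)`) can update them together, with no global variable involved.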