I have two DNNs. The first one returns two outputs, and I want to use one of these outputs in a second class that represents another DNN, as in the following example.

I want to pass the output (`x`) to the second class so that it can be concatenated with another variable (`v`). I found a solution that makes `x` a global variable, but I need a cleaner and more efficient solution.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Net(nn.Module):
    def __init__(self):
        ...

    def forward(self, x):
        ...
        return x, z


class Net2(nn.Module):
    def __init__(self):
        ...

    def forward(self, v):
        # x is not defined in this scope: it is the first output of Net.forward
        y = torch.cat(v, x)
        return y
```
Answer
You should not have to rely on global variables; the common practice is to pass both `v` and `x` as parameters of the `forward` method of `Net2`. Something like:
```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def forward(self, x):
        z = x**2
        return x, z


class Net2(nn.Module):
    def forward(self, x, v):
        # torch.cat takes a sequence of tensors; dim=1 joins them along the feature axis
        y = torch.cat((v, x), dim=1)
        return y
```
With dummy data:
```python
>>> net = Net()
>>> net2 = Net2()

>>> input1 = torch.rand(1, 10)
>>> input2 = torch.rand(1, 20)
```
First inference:
```python
>>> x, z = net(input1)
```
Second inference:
```python
>>> out = net2(x, input2)
>>> out.shape
torch.Size([1, 30])
```
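If you want to train both networks end-to-end, a common PyTorch pattern is to register them as submodules of a parent `nn.Module` whose `forward` passes `x` from `Net` into `Net2`. Because `x` travels as an ordinary tensor rather than through a global, autograd still tracks it, so a loss computed on `Net2`'s output back-propagates into `Net`'s parameters. The sketch below is only an illustration under stated assumptions: the `Combined` wrapper, the `nn.Linear` layers, the zero target, the MSE loss and the learning rate are all hypothetical choices, not part of the original code.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 10)  # hypothetical learnable layer

    def forward(self, x):
        x = self.fc(x)
        z = x**2
        return x, z


class Net2(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(30, 1)  # hypothetical head over the 20 + 10 concatenated features

    def forward(self, x, v):
        y = torch.cat((v, x), dim=1)
        return self.fc(y)


class Combined(nn.Module):
    """Hypothetical wrapper that feeds Net's first output into Net2."""

    def __init__(self):
        super().__init__()
        # Assigning submodules as attributes registers their parameters,
        # so model.parameters() covers both networks.
        self.net = Net()
        self.net2 = Net2()

    def forward(self, input1, input2):
        x, z = self.net(input1)
        out = self.net2(x, input2)  # pass x.detach() instead to keep Net frozen
        return out, z


model = Combined()
optimizer = optim.SGD(model.parameters(), lr=0.1)

input1 = torch.rand(1, 10)
input2 = torch.rand(1, 20)
target = torch.zeros(1, 1)  # hypothetical dummy regression target

optimizer.zero_grad()
out, z = model(input1, input2)
loss = nn.functional.mse_loss(out, target)
loss.backward()

# The gradient reached Net's weights through x, with no global variable involved.
print(out.shape)                             # torch.Size([1, 1])
print(model.net.fc.weight.grad is not None)  # True
optimizer.step()
```

Keeping separate `net` and `net2` objects, as in the answer, works just as well; in that case pass `list(net.parameters()) + list(net2.parameters())` to the optimizer.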