
Share the output of one class with another class in Python

I have two DNNs, and the first one returns two outputs. I want to use one of these outputs in a second class that represents the other DNN, as in the following example:

I want to pass the output (x) to the second class so that it can be concatenated with another variable (v). One workaround I found is to make x a global variable, but I am looking for a cleaner, more efficient solution.

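import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        ...  # layers omitted
    def forward(self, x):
        ...  # computation omitted
        return x, z


class Net2(nn.Module):
    def __init__(self):
        super().__init__()
        ...  # layers omitted
    def forward(self, v):
        # x (the output of Net) is not defined in this scope: this is the problem
        y = torch.cat((v, x), dim=1)
        return y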


Answer

You should not need to rely on global variables; the usual practice is to pass data between modules explicitly. You can pass both v and x as arguments to the forward method of Net2, for example:

import torch
import torch.nn as nn

class Net(nn.Module):
    def forward(self, x):
        z = x**2
        return x, z

class Net2(nn.Module):
    def forward(self, x, v):
        # concatenate v and x along the feature dimension (dim=1)
        y = torch.cat((v, x), dim=1)
        return y
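If the output of Net always feeds into Net2, a common alternative is to compose both networks inside a parent nn.Module, so a single forward call wires x from the first network into the second. The sketch below assumes the toy Net and Net2 from this answer; the Pipeline class name is hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def forward(self, x):
        z = x ** 2
        return x, z


class Net2(nn.Module):
    def forward(self, x, v):
        # Concatenate along the feature dimension (dim=1)
        return torch.cat((v, x), dim=1)


class Pipeline(nn.Module):
    """Hypothetical wrapper that owns both networks and passes Net's output to Net2."""

    def __init__(self):
        super().__init__()
        # Assigning sub-modules as attributes registers them, so their
        # parameters are visible through Pipeline.parameters()
        self.net = Net()
        self.net2 = Net2()

    def forward(self, input1, input2):
        x, z = self.net(input1)    # first network returns two outputs
        y = self.net2(x, input2)   # only x is handed to the second network
        return y, z


model = Pipeline()
y, z = model(torch.rand(1, 10), torch.rand(1, 20))
print(y.shape)  # torch.Size([1, 30])
```

Because both networks are registered as sub-modules, calling model.parameters() (for an optimizer) or model.to(device) covers both of them at once.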

With dummy data:

>>> net = Net()
>>> net2 = Net2()

>>> input1 = torch.rand(1,10)
>>> input2 = torch.rand(1,20)

First inference:

>>> x, z = net(input1)

Second inference:

>>> out = net2(x, input2)
>>> out.shape
torch.Size([1, 30])
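Passing x as an ordinary argument also keeps it in the autograd graph, so a loss computed on Net2's output backpropagates into Net as well. The sketch below checks this with small trainable layers; the nn.Linear sizes, batch size, learning rate and MSE loss are illustrative assumptions, not part of the original question.

```python
import torch
import torch.nn as nn
import torch.optim as optim


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 10)  # layer size is an illustrative assumption

    def forward(self, inp):
        x = torch.relu(self.fc(inp))
        z = x ** 2
        return x, z


class Net2(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(30, 1)  # 20 features from v + 10 features from x

    def forward(self, x, v):
        y = torch.cat((v, x), dim=1)
        return self.fc(y)


torch.manual_seed(0)
net, net2 = Net(), Net2()
# A single optimizer over the parameters of both networks
optimizer = optim.SGD(list(net.parameters()) + list(net2.parameters()), lr=0.01)

input1 = torch.rand(4, 10)  # dummy batch of 4 samples
input2 = torch.rand(4, 20)
target = torch.rand(4, 1)

optimizer.zero_grad()
x, z = net(input1)
out = net2(x, input2)
loss = nn.functional.mse_loss(out, target)
loss.backward()  # gradients flow back through Net2 and into Net
optimizer.step()

print(net.fc.weight.grad is not None)  # True
```

If Net should not be updated by Net2's loss, pass x.detach() instead of x; the values are the same, but the gradient stops at that point.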
User contributions licensed under: CC BY-SA