I am trying to convert this code passing it with pysyft refference
like this :
JavaScript
x
33
33
1
class SyNet(sy.Module):
2
def __init__(self,embedding_size, num_numerical_cols, output_size, layers, p ,torch_ref):
3
super(SyNet, self ).__init__( embedding_size, num_numerical_cols , output_size , layers , p=0.4 ,torch_ref=torch_ref )
4
self.all_embeddings=self.torch_ref.nn.ModuleList([nn.Embedding(ni, nf) for ni, nf in embedding_size])
5
self.embedding_dropout=self.torch_ref.nn.Dropout(p)
6
self.batch_norm_num=self.torch_ref.nn.BatchNorm1d(num_numerical_cols)
7
8
all_layers= []
9
num_categorical_cols = sum((nf for ni, nf in embedding_size))
10
input_size = num_categorical_cols + num_numerical_cols
11
12
for i in layers:
13
all_layers.append(self.torch_ref.nn.Linear(input_size,i))
14
all_layers.append(self.torch_ref.nn.ReLU(inplace=True))
15
all_layers.append(self.torch_ref.nn.BatchNorm1d(i))
16
all_layers.append(self.torch_ref.nn.Dropout(p))
17
input_size = i
18
19
all_layers.append(self.torch_ref.nn.Linear(layers[-1], output_size))
20
21
self.layers = self.torch_ref.nn.Sequential(*all_layers)
22
23
def forward(self, x_categorical, x_numerical):
24
embeddings= []
25
for i,e in enumerate(self.all_embeddings):
26
embeddings.append(e(x_categorical[:,i]))
27
28
x_numerical = self.batch_norm_num(x_numerical)
29
x = self.torch_ref.cat([x, x_numerical], 1)
30
x = self.layers(x)
31
return x
32
33
But when I try to create a instance of the model
JavaScript
1
3
1
model = SyNet( categorical_embedding_sizes, numerical_data.shape[1], 2, [200,100,50], p=0.4 ,torch_ref= th)
2
3
I got a TypeError
TypeError: multiple values for argument ‘torch_ref’
I tried to change the order of the arguments but i got an error about positional arguments . Can you help me , I am not very experienced in classes and functions (oop)
Thank you in advance !
Advertisement
Answer
Looking at PySyft source code for Module
. The constructor of your class parent only takes a single argument: torch_ref
.
You should therefore call the super constructor with:
JavaScript
1
2
1
super(SyNet, self).__init__(torch_ref=torch_ref) # line 3
2
removing all arguments but torch_ref
from the call.