I have defined the model as in the code below and used batch-normalization merging to fold the three layers into a single linear layer.
- The first layer of the model is a linear layer with no bias.
- The second layer of the model is a batch normalization layer with no weight or bias (affine is False).
- The third layer of the model is a linear layer.
The variables named new_weight and new_bias are the weight and bias of the newly created linear layer, respectively.
My question is: why do the two print statements at the end produce different outputs, and what is wrong in the code below the # batch merge comment?
import torch
import torch.nn as nn
import torch.optim as optim

learning_rate = 0.01
in_nodes = 20
internal_nodes = 8
out_nodes = 9
batch_size = 100

# model define
class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.layer1 = nn.Linear(in_nodes, internal_nodes, bias=False)
        self.layer2 = nn.BatchNorm1d(internal_nodes, affine=False)
        self.layer3 = nn.Linear(internal_nodes, out_nodes)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

# optimizer and criterion
model = M()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# training
for batch_num in range(1000):
    model.train()
    optimizer.zero_grad()

    input = torch.randn(batch_size, in_nodes)
    target = torch.ones(batch_size, out_nodes)

    output = model(input)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

# batch merge
divider = torch.sqrt(model.layer2.eps + model.layer2.running_var)

w_bn = torch.diag(torch.ones(internal_nodes) / divider)
new_weight = torch.mm(w_bn, model.layer1.weight)
new_weight = torch.mm(model.layer3.weight, new_weight)

b_bn = - model.layer2.running_mean / divider
new_bias = model.layer3.bias + torch.squeeze(torch.mm(model.layer3.weight, b_bn.reshape(-1, 1)))

input = torch.randn(batch_size, in_nodes)
print(model(input))
print(torch.t(torch.mm(new_weight, torch.t(input))) + new_bias)
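For reference, here is the algebra the # batch merge block is implementing (assuming the batch norm runs in evaluation mode, i.e. uses its running statistics), with W1 = layer1.weight, W3 = layer3.weight, b3 = layer3.bias and divider = sqrt(running_var + eps):

model(x) = W3 · ((W1 · x − running_mean) / divider) + b3
         = (W3 · diag(1 / divider) · W1) · x + (b3 − W3 · (running_mean / divider))

so new_weight = W3 · diag(1 / divider) · W1 and new_bias = b3 − W3 · (running_mean / divider), which is what the code above computes.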
Answer
Short Answer: As far as I can tell you need a model.eval() before the line input = torch.randn(batch_size, in_nodes), so that the end of the script looks like this:
...
model.eval()

input = torch.randn(batch_size, in_nodes)
test_input = torch.ones(batch_size, internal_nodes) / 100  # leftover test tensor, not used below
print(model(input))
print(torch.t(torch.mm(new_weight, torch.t(input))) + new_bias)
With that (I tested it) the two print statements output the same values: calling model.eval() freezes the batch norm's running statistics, so they match the ones baked into new_weight and new_bias.
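As a quick sanity check (a sketch reusing the variables defined above; the tolerance is an arbitrary choice), the two outputs should then agree up to floating-point error:

model.eval()

input = torch.randn(batch_size, in_nodes)
merged = torch.t(torch.mm(new_weight, torch.t(input))) + new_bias
print(torch.allclose(model(input), merged, atol=1e-5))  # True once eval() is set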
Long Answer:
When using batch normalization, the PyTorch documentation states that a default momentum of 0.1 is used to compute the running_mean and running_var. The momentum defines how much weight the existing estimate and how much the newly observed batch statistics get in each update.
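Concretely, the update for each running statistic follows PyTorch's convention, in which momentum weights the new observation:

running_mean ← (1 − momentum) · running_mean + momentum · batch_mean
running_var  ← (1 − momentum) · running_var  + momentum · batch_var

With momentum = 0.1, every forward pass in training mode pulls the running estimates 10% of the way toward the current batch's statistics.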
Now, when you don't call model.eval(), the batch normalization layer computes an updated running_mean and running_var (due to the momentum) during the forward pass in the line print(model(input)), so the statistics the model uses there are not the ones that were folded into new_weight and new_bias.
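A minimal way to see this (a sketch using the model from the question; the exact values will vary from run to run):

# In training mode a forward pass updates the running statistics ...
model.train()
before = model.layer2.running_mean.clone()
_ = model(torch.randn(batch_size, in_nodes))
print(torch.allclose(before, model.layer2.running_mean))  # False: the stats moved

# ... while in eval mode the forward pass only uses them.
model.eval()
before = model.layer2.running_mean.clone()
_ = model(torch.randn(batch_size, in_nodes))
print(torch.allclose(before, model.layer2.running_mean))  # True: the stats are unchanged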
For further details and/or confirmation: Related Question, PyTorch documentation.