How to implement batch normalization merging in Python?

I have defined the model as in the code below, and I used batch normalization merging to fold 3 layers into 1 linear layer.

  • The first layer of the model is a linear layer with no bias.
  • The second layer of the model is a batch normalization layer with no weight and bias (affine is False).
  • The third layer of the model is a linear layer.

The variables named new_weight and new_bias are the weight and bias of the newly created linear layer, respectively.
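
For context, the folding I am trying to implement is the following (my own notation, treating x as a column vector, with W1, W3, b3 the linear weights and bias, and mean, var, eps the running statistics of layer2, which in eval mode just computes (x - mean) / sqrt(var + eps)):

    layer3(layer2(layer1(x)))
        = W3 @ ((W1 @ x - mean) / sqrt(var + eps)) + b3
        = (W3 @ diag(1 / sqrt(var + eps)) @ W1) @ x + (b3 - W3 @ (mean / sqrt(var + eps)))

so, if my algebra is right, new_weight should be W3 @ diag(1 / sqrt(var + eps)) @ W1 and new_bias should be b3 - W3 @ (mean / sqrt(var + eps)).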

My question is: why are the outputs of the two print statements at the end different, and which part of the code below the # batch merge comment is wrong?

import torch
import torch.nn as nn
import torch.optim as optim

learning_rate = 0.01
in_nodes = 20
internal_nodes = 8
out_nodes = 9
batch_size = 100

# model define
class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()

        self.layer1 = nn.Linear(in_nodes, internal_nodes, bias=False)
        self.layer2 = nn.BatchNorm1d(internal_nodes, affine=False)
        self.layer3 = nn.Linear(internal_nodes, out_nodes)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x


# optimizer and criterion
model = M()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()


# training
for batch_num in range(1000):
    model.train()
    optimizer.zero_grad()

    input = torch.randn(batch_size, in_nodes)
    target = torch.ones(batch_size, out_nodes)
    
    output = model(input)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()


# batch merge
# layer2 (affine=False) applies x -> (x - running_mean) / sqrt(running_var + eps),
# which can be written as a diagonal matrix w_bn plus a constant shift b_bn
divider = torch.sqrt(model.layer2.eps + model.layer2.running_var)

w_bn = torch.diag(torch.ones(internal_nodes) / divider)
new_weight = torch.mm(w_bn, model.layer1.weight)        # fold layer1 into the normalization
new_weight = torch.mm(model.layer3.weight, new_weight)  # then fold layer3 on top

b_bn = - model.layer2.running_mean / divider
new_bias = model.layer3.bias + torch.squeeze(torch.mm(model.layer3.weight, b_bn.reshape(-1, 1)))



input = torch.randn(batch_size, in_nodes)
print(model(input))
print(torch.t(torch.mm(new_weight, torch.t(input))) + new_bias)


Answer

Short Answer: As far as I can tell, you need a model.eval() before the line

input = torch.randn(batch_size, in_nodes)

such that the end looks like this

...
model.eval()
input = torch.randn(batch_size, in_nodes)
print(model(input))
print(torch.t(torch.mm(new_weight, torch.t(input))) + new_bias)

With that (I tested it) the two print statements should produce the same output, because eval mode keeps the batch normalization's running statistics fixed.
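
If you prefer a programmatic check over comparing printed tensors, a minimal sketch (reusing model, new_weight and new_bias from the question; merged is just a name introduced here) would be:

model.eval()
with torch.no_grad():
    input = torch.randn(batch_size, in_nodes)
    merged = torch.t(torch.mm(new_weight, torch.t(input))) + new_bias
    print(torch.allclose(model(input), merged, atol=1e-6))  # expected: True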

Long Answer:

When using batch normalization, according to the PyTorch documentation, a default momentum of 0.1 is used to compute running_mean and running_var. The momentum defines how much weight the current estimate and the newly observed batch value each get when the running statistics are updated.
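
Roughly, the update applied on every training-mode forward pass is the exponential moving average described in the documentation (written here as pseudocode, not PyTorch's actual source; batch_mean and batch_var are the statistics of the current batch):

running_mean = (1 - momentum) * running_mean + momentum * batch_mean
running_var  = (1 - momentum) * running_var  + momentum * batch_var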

Now, when you don't add a model.eval() statement, the batch normalization layer stays in training mode: it normalizes with the statistics of the current batch rather than with the stored running estimates your merged layer was built from, and, due to the momentum, it also computes an updated running_mean and running_var in the line

print(model(input))
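
A minimal, self-contained way to see the difference between the two modes (a standalone sketch, not tied to the model above):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(8, affine=False)         # same configuration as layer2
x = torch.randn(100, 8)

bn.train()
before = bn.running_mean.clone()
bn(x)                                        # training mode: running statistics are updated
print(torch.equal(before, bn.running_mean))  # False

bn.eval()
before = bn.running_mean.clone()
bn(x)                                        # eval mode: running statistics stay fixed
print(torch.equal(before, bn.running_mean))  # True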

For further details and/or confirmation: Related Question, PyTorch documentation.
