import tensorflow as tf

tf.enable_eager_execution()

print(tf.keras.layers.BatchNormalization()(tf.convert_to_tensor([[5.0, 70.0], [5.0, 60.0]])))
print(tf.contrib.layers.batch_norm(tf.convert_to_tensor([[5.0, 70.0], [5.0, 60.0]])))
The output of the above code (in TensorFlow 1.15) is:
tf.Tensor(
[[ 4.99 69.96]
 [ 4.99 59.97]], shape=(2, 2), dtype=float32)
tf.Tensor(
[[ 0.       0.99998]
 [ 0.      -0.99998]], shape=(2, 2), dtype=float32)
My question is why the same operation gives two completely different outputs. I also played with some of the parameters of the two functions, but the results did not change. The second output is the one I want, and PyTorch's batch norm gives the same output as the second one as well. So I'm thinking it's an issue with Keras.
Does anyone know how to fix batchnorm in Keras?
Answer
The BatchNormalization layer behaves differently during training and inference:
- During training (i.e. when using fit() or when calling the layer/model with the argument training=True), the layer normalizes its output using the mean and standard deviation of the current batch of inputs.
- During inference (i.e. when using evaluate() or predict(), or when calling the layer/model with the argument training=False, which is the default), the layer normalizes its output using a moving average of the mean and standard deviation of the batches it has seen during training.
So, the first result comes from tf.keras.layers.BatchNormalization's default training=False, and the second from tf.contrib.layers.batch_norm's default is_training=True.
If you want the same result from both, you may try:
x = tf.convert_to_tensor([[5.0, 70.0], [5.0, 60.0]])
print(tf.keras.layers.BatchNormalization()(x, training=True).numpy().tolist())
print(tf.contrib.layers.batch_norm(x).numpy().tolist())

# output
# [[0.0, 0.9999799728393555], [0.0, -0.9999799728393555]]
# [[0.0, 0.9999799728393555], [0.0, -0.9999799728393555]]
or
x = tf.convert_to_tensor([[5.0, 70.0], [5.0, 60.0]])
print(tf.keras.layers.BatchNormalization()(x).numpy().tolist())
print(tf.contrib.layers.batch_norm(x, is_training=False).numpy().tolist())

# output
# [[4.997501850128174, 69.96502685546875], [4.997501850128174, 59.97002410888672]]
# [[4.997501850128174, 69.96502685546875], [4.997501850128174, 59.97002410888672]]
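Note that in this second variant both layers normalize with moving statistics that have never been updated by any training step (mean 0, variance 1), which is why the output is almost identical to the input apart from the small epsilon correction.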