Batch Normalization with a large batch size breaks validation accuracy

I’ve been experimenting with the MNIST dataset and the ideas @jeremy mentioned in Lesson 3. Following advice on the forums, I chose the largest batch size my GPU supports: 4096 (MNIST is very small). To my surprise, after introducing Batch Normalization the validation accuracy collapses to 0.1-0.2 (seemingly at random) and does not recover even after training for tens of epochs. Note that with a batch_size of 4096, MNIST trains one epoch in about 15 steps.
I have tried multiple combinations of Keras, TensorFlow and learning rate. The only changes that bring validation accuracy back are removing BatchNormalization altogether or reducing the batch size (I set axis=1 for Theano and axis=3 for TensorFlow).

Is there a mathematical limitation that I am missing?

For reference, I am using this Keras model:

import keras.backend as K
from keras import layers, losses
from keras.models import Sequential
from keras.optimizers import Adam

# BatchNormalization must normalize over the channel axis, which depends on
# the image data format of the backend.
if K.image_data_format() == "channels_last":   # TensorFlow default
    shape = (28, 28, 1)
    bn_axis = 3
else:                                          # Theano / channels_first
    shape = (1, 28, 28)
    bn_axis = 1

def make_model():
    model = Sequential([
        # `norm` is the input-normalization function from the course
        # notebooks (mean/std scaling), defined elsewhere.
        layers.Lambda(norm, input_shape=shape, output_shape=shape),

        layers.ZeroPadding2D(),
        layers.Convolution2D(32, 3, activation="relu"),
        layers.ZeroPadding2D(),
        layers.BatchNormalization(axis=bn_axis),
        layers.Convolution2D(32, 3, activation="relu"),
        layers.MaxPooling2D(),

        layers.ZeroPadding2D(),
        layers.BatchNormalization(axis=bn_axis),
        layers.Convolution2D(64, 3, activation="relu"),
        layers.ZeroPadding2D(),
        layers.BatchNormalization(axis=bn_axis),
        layers.Convolution2D(64, 3, activation="relu"),
        layers.MaxPooling2D(),

        layers.Flatten(),
        layers.BatchNormalization(),
        layers.Dense(512, activation="relu"),
        layers.BatchNormalization(),
        layers.Dense(10, activation="softmax"),
    ])
    model.compile(optimizer=Adam(),
                  loss=losses.categorical_crossentropy,
                  metrics=["accuracy"])
    return model

model = make_model()
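
The training call itself is nothing special; roughly the sketch below, where x_train/y_train/x_val/y_val stand for the usual one-hot-encoded MNIST arrays and the data pipeline is simplified:

from keras.preprocessing.image import ImageDataGenerator

batch_size = 2 ** 12   # or 2 ** 6 for the second run below

gen = ImageDataGenerator()
batches = gen.flow(x_train, y_train, batch_size=batch_size)
val_batches = gen.flow(x_val, y_val, batch_size=batch_size, shuffle=False)

# Passing the float ratio as steps_per_epoch is what makes Keras print
# step counts like "15/14" and "938/937" in the logs below.
model.fit_generator(batches,
                    steps_per_epoch=len(x_train) / batch_size,
                    epochs=10,
                    validation_data=val_batches,
                    validation_steps=len(x_val) / batch_size)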

This model is identical to the MNIST model in the course1 notebooks folder, except for the zero padding. Removing the padding does not help, in case you are wondering.

Output from a run with batch_size=2**12:

Epoch 1/10
15/14 [==============================] - 7s 462ms/step - loss: 0.1141 - acc: 0.9645 - val_loss: 6.3945 - val_acc: 0.1135
Epoch 2/10
15/14 [==============================] - 6s 414ms/step - loss: 0.0870 - acc: 0.9731 - val_loss: 7.9048 - val_acc: 0.1135
Epoch 3/10
15/14 [==============================] - 6s 425ms/step - loss: 0.0714 - acc: 0.9773 - val_loss: 9.0508 - val_acc: 0.1135
Epoch 4/10
15/14 [==============================] - 6s 409ms/step - loss: 0.0648 - acc: 0.9799 - val_loss: 9.3166 - val_acc: 0.1135
Epoch 5/10
15/14 [==============================] - 6s 413ms/step - loss: 0.0584 - acc: 0.9817 - val_loss: 9.5930 - val_acc: 0.1135
Epoch 6/10
15/14 [==============================] - 6s 425ms/step - loss: 0.0549 - acc: 0.9831 - val_loss: 9.8054 - val_acc: 0.1135
Epoch 7/10
15/14 [==============================] - 6s 408ms/step - loss: 0.0498 - acc: 0.9844 - val_loss: 9.2267 - val_acc: 0.1135
Epoch 8/10
15/14 [==============================] - 6s 417ms/step - loss: 0.0455 - acc: 0.9859 - val_loss: 8.5109 - val_acc: 0.1135
Epoch 9/10
15/14 [==============================] - 6s 412ms/step - loss: 0.0429 - acc: 0.9868 - val_loss: 7.9679 - val_acc: 0.1135
Epoch 10/10
15/14 [==============================] - 6s 414ms/step - loss: 0.0412 - acc: 0.9873 - val_loss: 7.6021 - val_acc: 0.1141

The same network with batch_size=2**6:

Epoch 1/10
938/937 [==============================] - 33s 35ms/step - loss: 0.0904 - acc: 0.9717 - val_loss: 0.0294 - val_acc: 0.9904
Epoch 2/10
938/937 [==============================] - 33s 35ms/step - loss: 0.0780 - acc: 0.9760 - val_loss: 0.0268 - val_acc: 0.9902
Epoch 3/10
938/937 [==============================] - 33s 35ms/step - loss: 0.0739 - acc: 0.9773 - val_loss: 0.0203 - val_acc: 0.9930
Epoch 4/10
938/937 [==============================] - 33s 35ms/step - loss: 0.0677 - acc: 0.9795 - val_loss: 0.0223 - val_acc: 0.9932
Epoch 5/10
938/937 [==============================] - 34s 37ms/step - loss: 0.0657 - acc: 0.9804 - val_loss: 0.0197 - val_acc: 0.9942
Epoch 6/10
938/937 [==============================] - 36s 38ms/step - loss: 0.0596 - acc: 0.9825 - val_loss: 0.0205 - val_acc: 0.9939
Epoch 7/10
938/937 [==============================] - 33s 35ms/step - loss: 0.0555 - acc: 0.9828 - val_loss: 0.0161 - val_acc: 0.9954
Epoch 8/10
938/937 [==============================] - 32s 34ms/step - loss: 0.0538 - acc: 0.9832 - val_loss: 0.0165 - val_acc: 0.9945
Epoch 9/10
938/937 [==============================] - 32s 34ms/step - loss: 0.0483 - acc: 0.9851 - val_loss: 0.0225 - val_acc: 0.9941
Epoch 10/10
938/937 [==============================] - 32s 34ms/step - loss: 0.0514 - acc: 0.9843 - val_loss: 0.0140 - val_acc: 0.9949

Yes, I have also observed similar results on other datasets and problems with deep CNNs that use many successive Batch Normalization layers. I would also be interested if someone has a mathematical explanation for this behaviour.

I don’t know whether normalizing over that much data, across many successive layers, hurts the network’s capacity to generalize. It looks as if the network ends up fitting the training data on non-generalizable features in order to converge.

Or, as is often the case in computer science, it is just a memory bug in the API …

Large batch sizes tend to hurt generalization accuracy. What you can do instead is start training with a smaller batch size and gradually increase it over the course of training. See this cool paper.
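
In Keras, a rough sketch of the idea is just a few successive fit calls with a growing batch_size (schedule untuned, array names as in the post above):

# Rough sketch: train in stages, increasing the batch size instead of
# decaying the learning rate (the paper's idea). The schedule is untuned.
for batch_size in [64, 256, 1024, 4096]:
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=2,                       # a couple of epochs per stage
              validation_data=(x_val, y_val))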

One more thing: with batch sizes that large, you need to ramp the learning rate up well above Adam’s default of 1e-3.
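
For example, a common (if rough) heuristic is to scale the learning rate linearly with the batch size relative to a baseline that is known to work; this is only a sketch, and at 64x the baseline you would still want a warm-up and some tuning:

from keras import losses
from keras.optimizers import Adam

base_lr = 1e-3        # Adam's default, which works fine at batch_size = 64
base_batch = 64
batch_size = 4096

# Linear scaling heuristic: 4096 / 64 = 64x the baseline learning rate.
lr = base_lr * batch_size / base_batch

model.compile(optimizer=Adam(lr=lr),
              loss=losses.categorical_crossentropy,
              metrics=["accuracy"])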

Do you mean that you can ramp up the learning rate, or that you need to? AFAIK, large batch sizes allow you to use a higher LR; but I am not aware of any problems if you keep to the lower LR (apart from training taking longer than necessary).