I’ve been experimenting with the MNIST dataset and the ideas @jeremy mentioned in Lesson 3. Following advice on the forums, I chose the largest batch size my GPU supports: 4096 (MNIST is very small). To my surprise, after introducing Batch Normalisation the validation accuracy collapses to 0.1-0.2 (essentially random guessing) and does not improve even after training for tens of epochs. Note that with a batch_size of 4096, MNIST trains one epoch in 15 steps.
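As a sanity check on those step counts (assuming the standard 60,000-image MNIST training set), steps per epoch is just ceil(n_train / batch_size):

```python
import math

n_train = 60000  # standard MNIST training set size

# steps per epoch = ceil(n_train / batch_size)
for batch_size in (2**12, 2**6):
    print(batch_size, math.ceil(n_train / batch_size))
# -> 4096 15   (matches the "15/14" progress bars below)
# -> 64 938    (matches the "938/937" progress bars below)
```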
I have tried multiple combinations of Keras versions, TensorFlow versions, and learning rates. The only changes that restore validation accuracy are removing BatchNormalization entirely or reducing the batch size. (I set axis=1 for Theano and axis=3 for TensorFlow.)
Is there a mathematical limitation that I am missing?
For reference, here is the Keras model I am using:
from keras import layers, losses
from keras.models import Sequential
from keras.optimizers import Adam

if channels_last:
    shape = (28, 28, 1)
    bn_axis = 3
else:
    shape = (1, 28, 28)
    bn_axis = 1

def make_model():
    model = Sequential([
        # norm is my per-pixel normalisation function (defined elsewhere)
        layers.Lambda(norm, input_shape=shape, output_shape=shape),
        layers.ZeroPadding2D(),
        layers.Convolution2D(32, 3, activation="relu"),
        layers.ZeroPadding2D(),
        layers.BatchNormalization(axis=bn_axis),
        layers.Convolution2D(32, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.ZeroPadding2D(),
        layers.BatchNormalization(axis=bn_axis),
        layers.Convolution2D(64, 3, activation="relu"),
        layers.ZeroPadding2D(),
        layers.BatchNormalization(axis=bn_axis),
        layers.Convolution2D(64, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.BatchNormalization(),
        layers.Dense(512, activation="relu"),
        layers.BatchNormalization(),
        layers.Dense(10, activation="softmax"),
    ])
    model.compile(optimizer=Adam(), loss=losses.categorical_crossentropy,
                  metrics=["accuracy"])
    return model

model = make_model()
This model is identical to the MNIST model in the course1 notebooks folder, except for the zero padding. In case you are wondering, removing the padding does not help.
Output from a run with batch_size=2**12 (4096):
Epoch 1/10
15/14 [==============================] - 7s 462ms/step - loss: 0.1141 - acc: 0.9645 - val_loss: 6.3945 - val_acc: 0.1135
Epoch 2/10
15/14 [==============================] - 6s 414ms/step - loss: 0.0870 - acc: 0.9731 - val_loss: 7.9048 - val_acc: 0.1135
Epoch 3/10
15/14 [==============================] - 6s 425ms/step - loss: 0.0714 - acc: 0.9773 - val_loss: 9.0508 - val_acc: 0.1135
Epoch 4/10
15/14 [==============================] - 6s 409ms/step - loss: 0.0648 - acc: 0.9799 - val_loss: 9.3166 - val_acc: 0.1135
Epoch 5/10
15/14 [==============================] - 6s 413ms/step - loss: 0.0584 - acc: 0.9817 - val_loss: 9.5930 - val_acc: 0.1135
Epoch 6/10
15/14 [==============================] - 6s 425ms/step - loss: 0.0549 - acc: 0.9831 - val_loss: 9.8054 - val_acc: 0.1135
Epoch 7/10
15/14 [==============================] - 6s 408ms/step - loss: 0.0498 - acc: 0.9844 - val_loss: 9.2267 - val_acc: 0.1135
Epoch 8/10
15/14 [==============================] - 6s 417ms/step - loss: 0.0455 - acc: 0.9859 - val_loss: 8.5109 - val_acc: 0.1135
Epoch 9/10
15/14 [==============================] - 6s 412ms/step - loss: 0.0429 - acc: 0.9868 - val_loss: 7.9679 - val_acc: 0.1135
Epoch 10/10
15/14 [==============================] - 6s 414ms/step - loss: 0.0412 - acc: 0.9873 - val_loss: 7.6021 - val_acc: 0.1141
The same network with batch_size=2**6 (64):
Epoch 1/10
938/937 [==============================] - 33s 35ms/step - loss: 0.0904 - acc: 0.9717 - val_loss: 0.0294 - val_acc: 0.9904
Epoch 2/10
938/937 [==============================] - 33s 35ms/step - loss: 0.0780 - acc: 0.9760 - val_loss: 0.0268 - val_acc: 0.9902
Epoch 3/10
938/937 [==============================] - 33s 35ms/step - loss: 0.0739 - acc: 0.9773 - val_loss: 0.0203 - val_acc: 0.9930
Epoch 4/10
938/937 [==============================] - 33s 35ms/step - loss: 0.0677 - acc: 0.9795 - val_loss: 0.0223 - val_acc: 0.9932
Epoch 5/10
938/937 [==============================] - 34s 37ms/step - loss: 0.0657 - acc: 0.9804 - val_loss: 0.0197 - val_acc: 0.9942
Epoch 6/10
938/937 [==============================] - 36s 38ms/step - loss: 0.0596 - acc: 0.9825 - val_loss: 0.0205 - val_acc: 0.9939
Epoch 7/10
938/937 [==============================] - 33s 35ms/step - loss: 0.0555 - acc: 0.9828 - val_loss: 0.0161 - val_acc: 0.9954
Epoch 8/10
938/937 [==============================] - 32s 34ms/step - loss: 0.0538 - acc: 0.9832 - val_loss: 0.0165 - val_acc: 0.9945
Epoch 9/10
938/937 [==============================] - 32s 34ms/step - loss: 0.0483 - acc: 0.9851 - val_loss: 0.0225 - val_acc: 0.9941
Epoch 10/10
938/937 [==============================] - 32s 34ms/step - loss: 0.0514 - acc: 0.9843 - val_loss: 0.0140 - val_acc: 0.9949
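One thing I noticed while writing this up (an assumption on my part, not something I have verified against the Keras source for my exact version): BatchNormalization keeps exponential moving averages of the mean and variance for use at inference time, updated as moving = momentum * moving + (1 - momentum) * batch_stat, with a default momentum of 0.99. With only 15 steps per epoch, those averages may barely move from their initial values:

```python
# Sketch of how far the moving statistics get from their initial values
# (moving_mean = 0, moving_variance = 1) after 10 epochs of training.
# After n updates, a fraction momentum**n of the initial value remains.
momentum = 0.99  # Keras default, if I read the docs correctly

for steps_per_epoch in (15, 938):  # batch_size 4096 vs 64
    n_updates = steps_per_epoch * 10
    residual = momentum ** n_updates
    print(f"{n_updates} updates -> {residual:.1%} of the initial value remains")
# With 150 updates, roughly a fifth of the initial (wrong) statistics
# remain; with 9380 updates, essentially nothing does.
```

If that is the mechanism, it would explain why training accuracy (computed with batch statistics) looks fine while validation accuracy (computed with the moving statistics) stays at chance level, but I would appreciate confirmation from someone who knows the internals better.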