Hi, I’ve run into this issue a number of times now, so I thought I’d write a short recap of it and the possible solutions to help people in the future.
Issue: The model predicts the same class, out of the 2 (or more) possible classes, for all the data it sees*
Confirming the issue is occurring: Method 1: the model’s accuracy stays around 0.5 while training (or around 1/n, where n is the number of classes, assuming the classes are roughly balanced). Method 2: get the counts of each class in the predictions and confirm that it is predicting all one class.
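For Method 2, here is a minimal sketch of the counting step. It assumes your predictions come as one row of scores per sample with one score per class (e.g. softmax outputs); the function names and the 0.99 threshold are just mine for illustration. If your model has a single sigmoid output, you would threshold at 0.5 instead of taking the arg-max.

```python
from collections import Counter


def predicted_class_counts(probabilities):
    """Count how often each class index has the highest score in a row."""
    counts = Counter()
    for row in probabilities:
        scores = list(row)
        # The index of the highest score is the predicted class for this sample.
        predicted = max(range(len(scores)), key=scores.__getitem__)
        counts[predicted] += 1
    return counts


def is_collapsed(counts, threshold=0.99):
    """True if one class accounts for at least `threshold` of all predictions."""
    total = sum(counts.values())
    return total > 0 and max(counts.values()) / total >= threshold


# Made-up scores for 4 samples and 2 classes: every row favours class 1.
example = [[0.48, 0.52], [0.49, 0.51], [0.45, 0.55], [0.47, 0.53]]
counts = predicted_class_counts(example)
print(counts)
print(is_collapsed(counts))
```

With Keras you would pass in the output of `model.predict(...)` on a batch of your data rather than the made-up `example` list.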
Fixes/Checks (in somewhat of an order):
- Double Check Model Architecture: use `model.summary()` and inspect the model.
- Check Data Labels: make sure the labelling of your train data hasn’t got mixed up somewhere in the preprocessing etc. (it happens!)
- Check Train Data Feeding Is Randomised: make sure you are not feeding your train data to the model one class at a time. For instance, if using `ImageDataGenerator().flow_from_directory(PATH)`, check that the param `shuffle=True` and that `batch_size` is greater than 1.
- Check Pre-Trained Layers Are Not Trainable:** If using a pre-trained model, ensure that any layers that use pre-trained weights are NOT initially trainable. For the first epochs, only the newly added (randomly initialised) layers should be trainable; `for layer in pretrained_model.layers: layer.trainable = False` should be somewhere in your code.
- Ramp Down Learning Rate: Keep reducing your learning rate by factors of 10 and retrying. Note you will have to fully reinitialize the layers you are trying to train each time you try a new learning rate. (For instance, I had this issue that was only solved once I got down to `lr=1e-6`, so keep going!)
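The learning rate ramp-down from the last item can be sketched roughly like this. It is framework-agnostic on purpose: `build_and_train` is a hypothetical callable you would write yourself, which has to build a freshly initialised model, train it at the given learning rate and return the validation accuracy. The starting rate, number of steps and the "clearly above chance" margin are assumptions of mine, not recommended values.

```python
from typing import Callable, Optional, Tuple


def ramp_down_learning_rate(
    build_and_train: Callable[[float], float],
    start_lr: float = 1e-3,
    steps: int = 4,
    chance_accuracy: float = 0.5,
    margin: float = 0.05,
) -> Optional[Tuple[float, float]]:
    """Try start_lr, start_lr/10, start_lr/100, ... in turn.

    Returns the first (learning_rate, accuracy) pair whose accuracy is clearly
    above chance, or None if no tried learning rate got there.
    """
    for i in range(steps):
        lr = start_lr / (10 ** i)
        # build_and_train must create a NEW model each call, so that the
        # trainable layers are re-initialised for every learning rate.
        accuracy = build_and_train(lr)
        print(f"lr={lr:.0e} accuracy={accuracy:.3f}")
        if accuracy > chance_accuracy + margin:
            return lr, accuracy
    return None


def fake_build_and_train(lr: float) -> float:
    """Made-up stand-in for real training: only 'works' at very small rates."""
    return 0.9 if lr < 5e-6 else 0.5


result = ramp_down_learning_rate(fake_build_and_train, start_lr=1e-3, steps=6)
print(result)
```

In real use you would replace `fake_build_and_train` with a function that rebuilds and compiles your model, fits it for a few epochs and returns the final validation accuracy.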
If any of you know of more fixes/checks that could possibly get the model training properly, then please do contribute and I’ll try to update the list.
**Note that it is common to make more of the pretrained model trainable once the new layers have been initially trained “enough”.
*Other names for the issue to help searches get here…
keras tensorflow theano CNN convolutional neural network bad training stuck fixed not static broken bug bugged jammed training only 0.5 accuracy only predicts one single class wont train model stuck on class model resetting itself between epochs