I recently finished lesson 3, so I've been trying to take my catsdogsredux model and add batch normalization to the fully connected layers. My original model had ~97% accuracy, but after a single training pass over my precomputed features with the batch normalization model, it tanks to ~45%.
For my original model I first trained the final dense layer, then all the dense layers, then some of the final convolutional layers. Evaluating that final model gave me about 0.97 accuracy. Next I took that model, split off the conv layers, and precomputed my train/validation features with conv_model.predict_generator. After that I set up a fully connected model with batch normalization using:
def get_bn_layers(p):
    return [
        MaxPooling2D(input_shape=conv_layers[-1].output_shape[1:]),
        Flatten(),
        Dense(4096, activation='relu'),
        BatchNormalization(),
        Dropout(p),
        Dense(4096, activation='relu'),
        BatchNormalization(),
        Dropout(p),
        Dense(1000, activation='softmax')
    ]

p = 0.6

def get_bn_model():
    bn_model = Sequential(get_bn_layers(p))
    load_fc_weights_from_vgg16bn(bn_model)
    for l in bn_model.layers:
        if type(l) == Dense:
            l.set_weights(proc_wgts(l, 0.5, p))
    bn_model.pop()
    for layer in bn_model.layers:
        layer.trainable = False
    bn_model.add(Dense(2, activation='softmax'))
    bn_model.compile(Adam(), 'categorical_crossentropy', metrics=['accuracy'])
    return bn_model
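For context, my understanding of the proc_wgts helper (from the lesson notebook) is that it just rescales the dense weights to compensate for changing the dropout rate from 0.5 to p. A minimal numpy sketch of that rescaling, assuming the helper works on a raw weight list rather than a layer (the exact notebook code may differ):

```python
import numpy as np

def rescale_for_dropout(weights, prev_p, new_p):
    """Rescale weights trained under dropout prev_p for use under dropout new_p.

    Sketch of my understanding of proc_wgts: multiply every weight array
    by (1 - prev_p) / (1 - new_p) so expected activations stay the same
    after the dropout rate changes.
    """
    scale = (1 - prev_p) / (1 - new_p)
    return [w * scale for w in weights]

# Moving from p=0.5 to p=0.6 scales weights by 0.5/0.4 = 1.25.
w = [np.ones((2, 2)), np.zeros(2)]
new_w = rescale_for_dropout(w, 0.5, 0.6)
```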
I then fit this batchnorm model on my precomputed train features, validating on the validation features, which resulted in:
bn_model.fit(trn_features, trn_labels, nb_epoch=1,
             validation_data=(val_features, val_labels))

Train on 23000 samples, validate on 2000 samples
Epoch 1/1
23000/23000 [==============================] - 14s - loss: 2.0176 - acc: 0.4947 - val_loss: 1.1338 - val_acc: 0.4640
As far as I can tell I'm doing pretty much the same thing as in the lesson 2/3 notebooks, so I'm unsure what's causing the drastic drop in accuracy. I suspect it's because my features were generated from a model whose later convolutional layers I had fine-tuned a bit, while the batchnorm model's weights come from vanilla VGG16 with batch norm; but that looks like the same thing Jeremy does in the lesson notebooks.
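To illustrate what I mean by the mismatch: batch norm at inference time normalizes with stored running statistics, so if those statistics came from one feature distribution but the incoming features come from a differently fine-tuned model, the normalized activations end up shifted and rescaled. A toy numpy sketch (all names and numbers here are made up for illustration, not from my model):

```python
import numpy as np

def batchnorm_inference(x, mean, var, gamma=1.0, beta=0.0, eps=1e-5):
    # Test-time batch norm: normalize with stored running statistics.
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.RandomState(0)
# Pretend features from the fine-tuned conv model: mean ~2, std ~1.
feats = rng.randn(1000) + 2.0

# Normalizing with matching statistics gives roughly zero-mean output...
matched = batchnorm_inference(feats, mean=2.0, var=1.0)
# ...but statistics from a different model shift everything the dense
# layers see, e.g. mean 0 / var 4 pushes the output mean to ~(2-0)/2 = 1.
mismatched = batchnorm_inference(feats, mean=0.0, var=4.0)
```

So even with sensible dense weights, every activation downstream of the batch norm layer is off, which seems consistent with the near-chance accuracy.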