Learn.unfreeze() causes spike in loss when training Unet on image segmentation challenge


I’m training a Unet (ResNet152 encoder) with a custom decoder for an image segmentation challenge that is currently running on Kaggle. The idea is very similar to Carvana: predict a binary mask for each image.

I naturally thought I’d use the fastai library for this, but every time I unfreeze the encoder I get a large spike in both training and validation loss. My first guess was obviously that the learning rate is too high, but even if I do something like:

lrs = np.array([lr/10000, lr/1000, lr/10])

I get a large spike. Any ideas?
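For context, as I understand it, fastai applies one entry of lrs to each layer group of the model. A rough plain-PyTorch equivalent gives the optimizer one param group per layer group. This is a sketch under assumptions: the three-block model is a hypothetical stand-in (not the actual U-Net), and lr = 1e-2 is only an illustrative value.

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical three-group model mirroring the three entries in lrs.
# It is NOT the real U-Net: two small "encoder" blocks plus a "head".
encoder_early = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
encoder_late = nn.Sequential(nn.Conv2d(8, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
head = nn.Conv2d(16, 1, 1)  # one-channel logits for a binary mask

lr = 1e-2  # illustrative value only
lrs = np.array([lr / 10000, lr / 1000, lr / 10])

# One optimizer param group per layer group, each with its own learning rate.
optimizer = torch.optim.Adam([
    {"params": encoder_early.parameters(), "lr": float(lrs[0])},
    {"params": encoder_late.parameters(), "lr": float(lrs[1])},
    {"params": head.parameters(), "lr": float(lrs[2])},
])

for group in optimizer.param_groups:
    print(group["lr"])
```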



I have seen something similar with classification problems where the images I train on differ significantly from those the model was pretrained on.

I have some ideas why that might be the case, but I don’t want to speculate. In my case the spike happens, but as I continue to train, the loss decreases below what was attainable with the lower layers of the network frozen.

The ball sliding downhill might be a nice analogy for learning about optimization algorithms (it seems to lend itself particularly well to understanding the momentum bit!), but as we gain experience with training, some aspects of it break down.

If your loss spikes and then keeps decreasing, I wouldn’t worry too much about it.

Hi @radek,

Thanks for the reply; what you say makes sense. Two observations, if I may:

  1. Technically, as I let a and b grow large in lrs = np.array([lr/a, lr/b, lr/10]), this should become equivalent to not unfreezing, yet I am still generally seeing a loss spike.
  2. This does raise the question of whether it’s worth unfreezing everything from the start (a clear violation of the recommended wisdom), since the spike usually sets me back to where I was a couple of epochs before unfreezing.

Finally, please do speculate! I am new to using fastai in the wild, and if this has to do with the library’s rough edges I’d like to know (though it’s more likely my lack of understanding).

You are right. I wonder if even small changes to those lower layers have a multiplicative effect on layers higher up the stack; perhaps even very small changes cause a big issue further down the road.

I wonder what would happen if we unfreeze the layers and set their learning rates to 0, if that is even possible.
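In plain PyTorch, at least, an lr of 0 for a param group is allowed. The sketch below (tiny hypothetical linear layers standing in for encoder and head, SGD with momentum as an arbitrary choice) shows that the "unfrozen" group’s weights stay exactly the same while its gradients are still computed, so this only isolates the weight updates, not any other side effect of unfreezing.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.Linear(4, 4)  # hypothetical stand-in for the unfrozen lower layers
head = nn.Linear(4, 1)     # hypothetical stand-in for the head

optimizer = torch.optim.SGD([
    {"params": encoder.parameters(), "lr": 0.0},  # unfrozen, but lr = 0
    {"params": head.parameters(), "lr": 0.1},
], momentum=0.9)

encoder_before = encoder.weight.detach().clone()
head_before = head.weight.detach().clone()

x, y = torch.randn(8, 4), torch.randn(8, 1)
for _ in range(3):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(head(encoder(x)), y)
    loss.backward()  # gradients still flow into the encoder
    optimizer.step()

print(torch.equal(encoder.weight, encoder_before))  # encoder weights untouched
print(torch.equal(head.weight, head_before))        # head weights updated
```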

The behavior we both see could simply be the effect of unfreezing the lower layers when retraining on vastly different data, or it could be something about the library (say, the effect of unfreezing batch norm, if that is what happens when we call unfreeze).
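One PyTorch detail that fits the batch norm hypothesis: BatchNorm running statistics are buffers, not parameters, so they update whenever the layer is in train mode, regardless of requires_grad or the learning rate. If frozen BN layers were kept in eval mode and then switched to train mode on unfreeze, the network would start normalizing with statistics of the new data even at tiny learning rates. Whether fastai does this is not shown here; the sketch only illustrates plain PyTorch behavior, and the shifted input data is a hypothetical stand-in for images unlike the pretraining set.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
for p in bn.parameters():
    p.requires_grad_(False)  # "frozen": gamma/beta receive no gradients

# Data whose statistics differ from the initial running stats (mean 0, var 1).
x = torch.randn(16, 3, 8, 8) * 5 + 10

bn.train()
start = bn.running_mean.clone()
with torch.no_grad():
    bn(x)
moved_in_train = not torch.equal(bn.running_mean, start)

bn.eval()
start = bn.running_mean.clone()
with torch.no_grad():
    bn(x)
moved_in_eval = not torch.equal(bn.running_mean, start)

print(moved_in_train, moved_in_eval)  # True False
```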

It is interesting that this is happening, but at this point I don’t have the bandwidth to look into it further. I would still train the higher layers of the network first before unfreezing the lower layers, though I myself have stopped using discriminative learning rates.
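The two-phase recipe described here (train the head with the encoder frozen, then unfreeze and use a single learning rate everywhere) can be sketched in plain PyTorch. The modules and both learning rates are hypothetical placeholders, and the training loops themselves are omitted.

```python
import torch
import torch.nn as nn

def set_trainable(module: nn.Module, trainable: bool) -> None:
    """Toggle gradient computation for every parameter in a module."""
    for p in module.parameters():
        p.requires_grad_(trainable)

# Hypothetical stand-ins for a pretrained encoder and a freshly initialized head.
encoder = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
head = nn.Conv2d(8, 1, 1)
model = nn.Sequential(encoder, head)

# Phase 1: encoder frozen, only the head's parameters go to the optimizer.
set_trainable(encoder, False)
phase1_opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-3)
# ... train the head for a few epochs here ...

# Phase 2: unfreeze everything and use one learning rate for all layers
# (no discriminative learning rates).
set_trainable(encoder, True)
phase2_opt = torch.optim.Adam(model.parameters(), lr=1e-4)
# ... continue training the whole network here ...
```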

Anyhow, it might be interesting to see whether the same behavior occurs on datasets similar to the ones the network was pretrained on (say, cats vs dogs with a CNN pretrained on ImageNet).

Still, we may be reading more into this behavior than it is worth at this point.

This sounds like a question best answered by experimenting :slight_smile:


It could be the opposite: larger changes in the later layers could have a disproportionate effect on the early layers when their gradients propagate back for the first time.
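One way to check this would be to log per-layer gradient norms on the first few batches after unfreezing and see whether the early layers receive unusually large gradients. A minimal helper, assuming a plain PyTorch model (the small MLP and random data are hypothetical stand-ins, not the U-Net):

```python
import torch
import torch.nn as nn

def grad_norms(model: nn.Module) -> dict:
    """L2 norm of the gradient of every parameter that currently has one."""
    return {name: p.grad.norm().item()
            for name, p in model.named_parameters() if p.grad is not None}

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))  # hypothetical stand-in
x, y = torch.randn(64, 10), torch.randn(64, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()

# Call this right after loss.backward() on the first batches after unfreezing.
for name, norm in grad_norms(model).items():
    print(f"{name:10s} {norm:.4f}")
```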

If the loss still spikes with large divisors for the early layers, I would expect the problem to lie in the later layers. Rather than only adjusting a and b, try varying the overall lr value.

Say you want to train with
lr = 1e-2
lrs = np.array([lr/100, lr/10, lr])

but that keeps diverging on you.
Try starting at a lower lr value and working up to the one you want: do a few epochs at lr = 1e-5, 1e-4, 1e-3, and finally 1e-2.
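The warm-up above amounts to a list of discriminative lr arrays, one per stage, with the base lr stepped up by 10x each time. A small NumPy sketch (the function name and defaults are hypothetical; the divisors match lrs = [lr/100, lr/10, lr]):

```python
import numpy as np

def staged_lrs(target_lr, start_lr=1e-5, divisors=(100, 10, 1)):
    """Discriminative lr arrays for each stage, raising the base lr by 10x per stage."""
    n_stages = int(round(np.log10(target_lr / start_lr))) + 1
    base_lrs = start_lr * 10.0 ** np.arange(n_stages)
    return [base / np.array(divisors, dtype=float) for base in base_lrs]

for stage in staged_lrs(1e-2):
    print(stage)
    # In fastai 0.7 each stage would presumably be a few epochs of something like
    # learn.fit(stage, n_epochs); check the fit signature of your version.
```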

I’ve seen this approach help with weird divergence after unfreezing. My intuition is that it lets the lower layers ‘adjust’ to whatever the head has been retrained to do without actually changing their weights much.

Also, by taking the model through a range of learning rates, you might find that your problem is just the overall learning rate (i.e. 1e-4 is fine but 1e-3 leads to divergence).
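Why a sharp threshold like that can exist is easiest to see on a toy problem (a quadratic, not a neural network). Plain gradient descent on f(x) = c·x²/2 multiplies x by (1 − lr·c) each step, so it converges only when lr < 2/c. The curvature c = 5000 is an arbitrary choice that puts the threshold (4e-4) between 1e-4 and 1e-3:

```python
def gd_on_quadratic(lr, curvature=5000.0, x0=1.0, steps=20):
    """Plain gradient descent on f(x) = curvature * x**2 / 2; returns final |x|."""
    x = x0
    for _ in range(steps):
        x -= lr * curvature * x  # gradient of f is curvature * x
    return abs(x)

for lr in (1e-5, 1e-4, 1e-3, 1e-2):
    print(f"lr={lr:g}  |x| after 20 steps = {gd_on_quadratic(lr):.3g}")
```

Sweeping lr through decades, as suggested above, is a cheap way to find roughly where that threshold sits for the real loss surface.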

Hi @radek and @KarlH, thanks to you both for the sensible comments. I’m away at the moment, so I will get back to you in just over a week with any observations or experimental results.

Thanks again