Model not fitting well... any tips?

I am trying to create a deep learning model for diabetic retinopathy classification using the dataset from the Kaggle competition. The dataset is heavily imbalanced, so I oversampled the minority classes to compensate. The competition metric is quadratic weighted kappa, but I am still using the default loss function.
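The post doesn't say how the oversampling was done. Here is a minimal sketch of plain random oversampling at the index level, assuming integer grade labels 0 to 4 and that it is applied to the training split only. Oversampling before the train/validation split would put duplicated images in both sets and inflate the validation numbers. The function name and the label counts are hypothetical.

```python
import numpy as np

def oversample_indices(labels, seed=0):
    """Random oversampling: repeat minority-class indices (with replacement)
    until every class matches the size of the largest class."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(labels, return_counts=True)
    target = counts.max()
    out = []
    for cls, count in zip(classes, counts):
        idx = np.flatnonzero(labels == cls)
        # keep every original sample, then draw the shortfall with replacement
        extra = rng.choice(idx, size=target - count, replace=True)
        out.append(np.concatenate([idx, extra]))
    idx = np.concatenate(out)
    rng.shuffle(idx)  # mix classes so batches are not sorted by label
    return idx

# Hypothetical training-split labels for the five DR grades (0 = no DR ... 4 = proliferative)
train_labels = np.array([0] * 50 + [1] * 5 + [2] * 12 + [3] * 2 + [4] * 3)
balanced = oversample_indices(train_labels)
print(np.bincount(train_labels[balanced]))
```

The returned indices would then be used to select training items; the validation split stays untouched so it still reflects the real class distribution.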

I first tried training the pretrained model with its body frozen, but the losses were increasing, so I decreased the learning rate from the default to 0.0001. The losses still increased. I then unfroze the pretrained model and got the following learning rate finder plot:


I therefore trained the unfrozen model:

learn.fit_one_cycle(8, max_lr=slice(1e-5, 1e-3))
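Because max_lr is a slice here, the earliest layers train with the lowest learning rate and the new head with the highest. The sketch below reproduces that spreading logic as I understand fastai v1's Learner.lr_range (geometric spacing between slice.start and slice.stop, one value per layer group). It is a re-implementation for illustration, not fastai's code, and the helper names and the 3-group split are assumptions.

```python
import numpy as np

def even_mults(start, stop, n):
    """Geometrically spaced values from start to stop (inclusive), n of them."""
    if n == 1:
        return np.array([stop])
    step = (stop / start) ** (1 / (n - 1))
    return np.array([start * step**i for i in range(n)])

def lrs_from_slice(lr, n_groups):
    """Expand a learning rate or slice into one LR per layer group (fastai v1-style, assumed)."""
    if not isinstance(lr, slice):
        # a plain float: this sketch simply repeats it for every group
        return np.full(n_groups, lr)
    if lr.start:
        return even_mults(lr.start, lr.stop, n_groups)
    # slice(None, x): every group except the last gets x / 10
    return np.array([lr.stop / 10] * (n_groups - 1) + [lr.stop])

# Assumption: 3 layer groups (two for the ResNet body, one for the new head)
print(lrs_from_slice(slice(1e-5, 1e-3), 3))
```

With three groups, slice(1e-5, 1e-3) works out to roughly 1e-5, 1e-4 and 1e-3, so the pretrained early layers move about 100x more slowly than the head.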

Here was the output:

This is phenomenally bad performance. What is going on? Any tips for improving it?

A few follow-up questions (the 3rd is a question/suggestion):

  1. What model architecture are you using (ResNet-34, 50, etc.)?
  2. What image size/resolution and batch size have you specified for your ImageDataBunch?
  3. Have you tried using lr_find() on the frozen model prior to initial training?
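For reference, lr_find is an LR range test: it raises the learning rate exponentially on every mini-batch, records a smoothed loss, and stops once the loss blows up. In fastai v1 that would be learn.freeze(), then learn.lr_find() and learn.recorder.plot(). Below is a NumPy sketch of the mechanics on a toy least-squares problem standing in for the network, using what I believe are fastai v1's defaults (start_lr=1e-7, end_lr=10, num_it=100, smoothing beta=0.98, stop when the smoothed loss exceeds 4x the best). The toy data and the "steepest slope" pick at the end are assumptions; the latter is one common heuristic for reading the plot, not the only one.

```python
import numpy as np

def lr_range_test(loss_and_grad, w0, start_lr=1e-7, end_lr=10.0, num_it=100, beta=0.98):
    """Raise the LR exponentially each step and record a bias-corrected smoothed loss.
    Stops early once the smoothed loss exceeds 4x the best value seen (divergence)."""
    lrs = start_lr * (end_lr / start_lr) ** (np.arange(num_it) / (num_it - 1))
    w = np.array(w0, dtype=float)
    avg, best, losses = 0.0, np.inf, []
    for i, lr in enumerate(lrs):
        loss, grad = loss_and_grad(w)
        avg = beta * avg + (1 - beta) * loss
        smoothed = avg / (1 - beta ** (i + 1))  # bias-corrected moving average
        losses.append(smoothed)
        if not np.isfinite(smoothed) or smoothed > 4 * best:
            break
        best = min(best, smoothed)
        w -= lr * grad  # one SGD step at the current LR
    return lrs[: len(losses)], np.array(losses)

# Hypothetical toy problem: noisy linear regression instead of a real CNN
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = X @ np.array([2.0, -1.0, 0.5]) + 0.1 * rng.normal(size=200)

def loss_and_grad(w):
    r = X @ w - y
    return np.mean(r ** 2), 2 * X.T @ r / len(y)

lrs, losses = lr_range_test(loss_and_grad, np.zeros(3))
# Heuristic pick: the LR where the smoothed loss falls fastest against log10(LR)
suggested = lrs[np.argmin(np.gradient(losses, np.log10(lrs)))]
print(f"ran {len(lrs)} steps, suggested LR ~ {suggested:.2e}")
```

Running the finder on the frozen model first matters because the frozen and unfrozen models have different trainable parameters, so their curves (and good learning rates) usually differ.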

Also, it may just be noise, but I would double-check your implementation of the quadratic-weighted κ score, as it doesn’t seem to track well with error_rate.
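One way to check a κ implementation is against an independent reference on cases with known answers. Below is a small NumPy sketch of quadratic weighted kappa, with weights w_ij = (i - j)^2 / (N - 1)^2 and the expected matrix taken as the outer product of the two label histograms. If scikit-learn is available, sklearn.metrics.cohen_kappa_score(y_true, y_pred, weights="quadratic") should agree. Also worth checking: as far as I recall, fastai v1's KappaScore has a weights attribute that defaults to None (unweighted), so it may need kappa = KappaScore(); kappa.weights = "quadratic" before being passed to metrics.

```python
import numpy as np

def quadratic_weighted_kappa(y_true, y_pred, num_classes):
    """Cohen's kappa with quadratic weights w_ij = (i - j)^2 / (N - 1)^2."""
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    # Observed matrix: O[i, j] = number of samples with true grade i and predicted grade j
    O = np.zeros((num_classes, num_classes))
    np.add.at(O, (y_true, y_pred), 1)
    # Expected matrix under independence, scaled to the same total count as O
    E = np.outer(O.sum(axis=1), O.sum(axis=0)) / O.sum()
    i, j = np.indices((num_classes, num_classes))
    W = (i - j) ** 2 / (num_classes - 1) ** 2
    return 1.0 - (W * O).sum() / (W * E).sum()

grades = [0, 1, 2, 3, 4]
print(quadratic_weighted_kappa(grades, grades, 5))           # perfect agreement
print(quadratic_weighted_kappa(grades, grades[::-1], 5))     # fully reversed grades
print(quadratic_weighted_kappa(grades, [0, 0, 0, 0, 0], 5))  # constant prediction
```

These three cases should give 1, -1 and 0 respectively (up to floating-point rounding), which makes them handy sanity checks for any κ metric.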

  1. ResNet34
  2. sz = 224, bs = 64
  3. Yes, and that’s why I changed the default LR to 0.0001

I already ran some test examples for the KappaScore metric, but I will run a couple more to make sure.

Try ResNet50. Someone else can correct me if their experience differs, but ResNet34 doesn’t typically have sufficient capacity for medical image data.

I tried ResNet50 and decreased the batch size to 4 images. Using the learning rate finder, I found a learning rate that only marginally improves the accuracy, to about 30%, with slowly decreasing loss. Increasing the learning rate leads to increasing loss.
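One thing that may help interpret the last observation: with fit_one_cycle, max_lr is the peak of a schedule rather than a constant learning rate. The sketch below approximates that schedule, assuming fastai v1's defaults as I remember them (pct_start=0.3, div_factor=25, final LR of max_lr / (div_factor * 1e4), cosine annealing in both phases). Raising max_lr raises the whole curve, so the loss can start climbing near the peak even if the first epochs look fine. The function name is hypothetical and momentum cycling is left out.

```python
import math

def one_cycle_lr(step, total_steps, max_lr, pct_start=0.3, div_factor=25.0, final_div=None):
    """Approximate one-cycle LR: cosine warm-up from max_lr/div_factor to max_lr,
    then cosine decay down to max_lr/final_div."""
    if final_div is None:
        final_div = div_factor * 1e4
    warm = int(total_steps * pct_start)

    def cos_anneal(start, end, pct):
        # pct = 0 returns start, pct = 1 returns end
        return end + (start - end) / 2 * (1 + math.cos(math.pi * pct))

    if step < warm:
        return cos_anneal(max_lr / div_factor, max_lr, step / warm)
    return cos_anneal(max_lr, max_lr / final_div, (step - warm) / max(1, total_steps - warm))

sched = [one_cycle_lr(s, 100, 1e-3) for s in range(101)]
peak = max(sched)
print(f"start {sched[0]:.1e}, peak {peak:.1e} at step {sched.index(peak)}, end {sched[-1]:.1e}")
```

In a discriminative setup, each layer group would follow this same shape scaled to its own max_lr from the slice.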

I am surprised it is not working, given that it has been done with fastai 0.7.

I am not sure what else to try. Maybe I should play with the transformations? I am currently using the default values.
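If orientation carries no diagnostic information in these fundus photos (an assumption worth checking for this dataset), random flips plus 90-degree rotations are a cheap augmentation to try beyond the defaults. In fastai v1 I believe get_transforms(flip_vert=True) switches the flips to this dihedral set. The NumPy sketch below just shows what the eight variants are; the helper names are hypothetical.

```python
import numpy as np

def dihedral_variants(img):
    """All 8 flips/90-degree rotations (the dihedral group D4) of an H x W x C image."""
    variants = []
    for k in range(4):
        rot = np.rot90(img, k, axes=(0, 1))  # rotate by k * 90 degrees
        variants.append(rot)
        variants.append(rot[:, ::-1])         # horizontal flip of that rotation
    return variants

def random_dihedral(img, rng):
    """Pick one of the 8 dihedral variants uniformly at random."""
    return dihedral_variants(img)[rng.integers(8)]

rng = np.random.default_rng(0)
image = np.arange(4 * 4 * 3).reshape(4, 4, 3)  # stand-in for a square fundus crop
augmented = random_dihedral(image, rng)
print(augmented.shape)
```

On square crops (such as sz = 224) every variant keeps the same shape, so these can be mixed freely within a batch.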