Fastai: latest available pre-trained base arch [imbalanced dataset]

(Himanshu) #1

What pre-trained torchvision models are available for image classification? What would work best for an imbalanced dataset?

Training set (3662 rows):
Class - Proportion (%)
0 - 49.29
2 - 27.28
1 - 10.10
4 - 8.06
3 - 5.27

(Kushajveer Singh) #2

The choice of model architecture has nothing to do with the ratio of the classes in your dataset. If your dataset is imbalanced, you need other techniques, such as oversampling, data augmentation, or a different loss function.
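As one illustration of the "different loss function" option, here is a minimal pure-Python sketch of class-weighted cross-entropy, assuming inverse-frequency weights N / (K · n_c). The per-class counts are approximate, back-calculated from the proportions in the first post, and the helper names are hypothetical. In PyTorch the same weights would be passed as a tensor to torch.nn.CrossEntropyLoss(weight=...), whose default mean reduction likewise divides by the summed weights of the targets in the batch.

```python
import math

# Approximate per-class image counts, back-calculated from the class
# proportions in the first post (3662 training images in total).
COUNTS = {0: 1805, 1: 370, 2: 999, 3: 193, 4: 295}

def inverse_frequency_weights(counts):
    """Weight class c by N / (K * n_c): rarer classes get larger weights.

    On a perfectly balanced dataset every weight would be exactly 1.0.
    """
    total = sum(counts.values())
    k = len(counts)
    return [total / (k * counts[c]) for c in sorted(counts)]

def log_softmax(logits):
    # Subtract the max before exponentiating, for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def weighted_cross_entropy(batch_logits, targets, weights):
    """Weighted mean: sum_i w[y_i] * (-log p_i[y_i]) / sum_i w[y_i]."""
    num = 0.0
    den = 0.0
    for logits, y in zip(batch_logits, targets):
        w = weights[y]
        num += w * -log_softmax(logits)[y]
        den += w
    return num / den

weights = inverse_frequency_weights(COUNTS)
print([round(w, 2) for w in weights])  # [0.41, 1.98, 0.73, 3.79, 2.48]
```

A mistake on the rarest class (3) now costs roughly nine times as much as one on the majority class (0), which pushes the model to stop ignoring the small classes.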

(Himanshu) #3

Thanks for the reply, Kushaj.

I get your point about the imbalanced dataset. Coming to the architecture: I know ResNet and DenseNet are the best-known pre-trained architectures in fastai. Is there a way to use the architectures from the link below? I suspect I would mostly have to implement them myself and then use them.

(Zachary Mueller) #4

Here’s a notebook discussing how to do that.

(Himanshu) #5

Thanks @muellerzr, I will give it a shot.

#6

Two things.

  1. Some of the Cadene models are actually available in fastai. See here.
  2. Oversampling, which counteracts class imbalance, is something I recently implemented and added to the fastai library! If you check here, you can see the code for the callback. It should be available in the next release; until then, you can copy and paste the relevant sections of the oversampling code into your own code.

(Himanshu) #7

Thanks, buddy, @ilovescience.

Unfortunately, I am confused by the callbacks in the fastai repo (I’m still learning).

Could you help me with the oversampling code (i.e., how would you increase the number of training samples in a DataBunch)?

Thanks,
Himanshu

#8

It’s not hard to use callbacks. Since the oversampling callback is currently only in the development version of fastai, here are two simple steps for adding it to your code:

  1. Add this piece of code to your program before defining your Learner:
```python
# fastai v1: brings in Learner, cnn_learner, models, accuracy, np and torch
from fastai.vision import *
from fastai.basic_train import LearnerCallback
from torch.utils.data.sampler import BatchSampler, WeightedRandomSampler

class OverSamplingCallback(LearnerCallback):
    def __init__(self, learn: Learner, weights: torch.Tensor = None):
        super().__init__(learn)
        # Class index of every training item
        labels = self.learn.data.train_dl.dataset.y.items.astype(int)
        # Items per class (assumes every class index 0..c-1 occurs)
        _, counts = np.unique(labels, return_counts=True)
        # Weight each item by 1 / (size of its class), so all classes
        # carry the same total weight
        counts = 1. / counts
        self.weights = (weights if weights is not None else torch.DoubleTensor(counts[labels]))

    def on_train_begin(self, **kwargs):
        # Swap in a batch sampler that draws len(train set) items,
        # with replacement, according to self.weights
        self.learn.data.train_dl.dl.batch_sampler = BatchSampler(
            WeightedRandomSampler(self.weights, len(self.learn.data.train_dl.dataset)),
            self.learn.data.train_dl.batch_size, False)
```
  2. Then pass the callback to the callback_fns argument when creating a learner:
```python
learn = cnn_learner(data, base_arch=models.resnet50, metrics=[accuracy],
                    callback_fns=[OverSamplingCallback])
```

That’s it!

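The oversampling callback gives each training image a weight equal to the inverse of its class size and then uses PyTorch’s WeightedRandomSampler to draw the training set according to those weights, so that all classes end up with a similar number of samples.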
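Why does inverse-frequency weighting balance the classes? Each of the n_c items in class c gets weight 1/n_c, so every class carries a total weight of n_c · 1/n_c = 1, and a single weighted draw picks each of the K classes with probability 1/K. A small pure-Python check of that arithmetic, using random.choices as a stand-in for PyTorch’s WeightedRandomSampler (both draw with replacement here); the class sizes are back-calculated from the first post and the function names are hypothetical:

```python
from collections import Counter
import random

def oversampling_weights(labels):
    """Per-item weight, as in the callback: 1 / (number of items in that item's class)."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]

def class_probabilities(labels, weights):
    """Probability that one weighted draw lands in each class."""
    total = sum(weights)
    probs = {}
    for y, w in zip(labels, weights):
        probs[y] = probs.get(y, 0.0) + w / total
    return probs

def simulate_epoch(labels, seed=0):
    """Draw len(labels) items with replacement, weighted like the callback."""
    rng = random.Random(seed)
    drawn = rng.choices(labels, weights=oversampling_weights(labels), k=len(labels))
    return Counter(drawn)

# Approximate class sizes from the first post (3662 images in total)
labels = [0] * 1805 + [1] * 370 + [2] * 999 + [3] * 193 + [4] * 295
print(class_probabilities(labels, oversampling_weights(labels)))  # each class close to 0.2
print(simulate_epoch(labels))
```

Because the draws are with replacement, minority-class images show up several times per epoch while some majority-class images are skipped, which is why data augmentation matters so much when oversampling.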

(Himanshu) #9

@ilovescience: Thank you. I will test this and see how it goes.

(Hao) #10

Hi,
Class imbalance is an interesting topic to explore. Based on my understanding, there are two approaches you can try.

One is what people have discussed here: oversampling. You can do the oversampling manually with pandas (if your input is a CSV file). The idea is to resample the classes so that they all have roughly the same amount of data; data augmentation keeps you from overfitting when the same image is shown many times in one epoch. You can also see the downside: if one class has 1000 items and another has only 1, oversampling probably won’t give good results, because the small class would just be that single image repeated about 1000 times.
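A minimal sketch of the manual pandas route, assuming a labels DataFrame with hypothetical image and label columns. Rather than dropping anything, it keeps every original row and tops up each smaller class with rows drawn with replacement until all classes match the largest one. It should be applied to the training rows only, so the validation set keeps the real class distribution.

```python
import pandas as pd

def oversample_df(df, label_col="label", seed=0):
    """Top up every class to the size of the largest class (extra rows drawn with replacement)."""
    target = int(df[label_col].value_counts().max())
    parts = []
    for _, group in df.groupby(label_col):
        # n is 0 for the largest class, so it is left unchanged
        extra = group.sample(n=target - len(group), replace=True, random_state=seed)
        parts.append(pd.concat([group, extra]))
    # Shuffle so that duplicated rows are not grouped together by class
    return pd.concat(parts).sample(frac=1, random_state=seed).reset_index(drop=True)

# Hypothetical toy data: the 'image' and 'label' column names are placeholders
df = pd.DataFrame({
    "image": [f"img_{i}.png" for i in range(6)],
    "label": [0, 0, 0, 0, 1, 2],
})
balanced = oversample_df(df)
print(balanced["label"].value_counts().sort_index())
```

Here classes 1 and 2 each consist of a single image repeated four times, which is exactly the 1000-vs-1 problem described above on a small scale.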

The other workaround is focal loss; see Part 2 (2018) of the fastai course for details.
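For reference, a minimal pure-Python sketch of the usual multi-class focal loss, FL = -(1 - p_t)^γ · log(p_t), where p_t is the predicted probability of the true class. It omits the optional per-class α weight, γ = 2 is just a commonly used default, and the function name is hypothetical; a real training loop would implement the same formula on PyTorch tensors.

```python
import math

def log_softmax(logits):
    # Subtract the max before exponentiating, for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def focal_loss(batch_logits, targets, gamma=2.0):
    """Mean focal loss over a batch: -(1 - p_t)**gamma * log(p_t)."""
    total = 0.0
    for logits, y in zip(batch_logits, targets):
        log_p_t = log_softmax(logits)[y]
        p_t = math.exp(log_p_t)
        # (1 - p_t)**gamma is close to 0 for confidently correct examples
        total += -((1.0 - p_t) ** gamma) * log_p_t
    return total / len(targets)
```

With gamma=0 this is plain cross-entropy; larger gamma shrinks the loss of examples the model already gets right, so training concentrates on the hard ones, which in an imbalanced dataset are often the minority classes.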

Sometimes a Siamese network (one-shot learning) also works for this kind of problem.
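A Siamese network passes two images through the same (weight-shared) backbone and learns an embedding in which images of the same class lie close together. The post does not name a training objective; one common choice, assumed here, is the contrastive loss. A minimal pure-Python sketch for a single pair of embedding vectors (the names and the margin of 1.0 are illustrative):

```python
import math

def euclidean(a, b):
    """Euclidean distance between two embedding vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def contrastive_loss(emb_a, emb_b, same, margin=1.0):
    """Contrastive loss for one pair of embeddings.

    same=True:  pull the pair together, loss = d**2
    same=False: push the pair apart until they are at least `margin`
                apart, loss = max(0, margin - d)**2
    """
    d = euclidean(emb_a, emb_b)
    if same:
        return d ** 2
    return max(0.0, margin - d) ** 2
```

At prediction time a new image can be assigned the class of its nearest reference embedding, which is why this approach can cope with classes that have only a handful of examples.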

Hope this points you in the right direction :slight_smile:

(Himanshu) #11

Thanks @heye0507 for the pointers. Do you have any idea how a Siamese network (one-shot learning) could be implemented in fastai? Any input would be much appreciated.

@ilovescience, @heye0507: A few more questions:

  1. Have you ever come across ordinal regression? What would be the best way to tackle it with fastai, in terms of metric, objective function, handling outliers and so on?

  2. I’ll keep digging, but if you know: how can you train on the same problem as both a classification and an ordinal regression task using fastai?

Thanks again, guys, for your genuine help and kind support.
“The best way to learn the subject is to teach people about the subject” :slight_smile:

Himanshu

(Hao) #12

For the Siamese network, I have implemented some sample code; here is the link:

You can also check Radek’s implementation; most of my computer vision study path follows his ideas (classification, multi-label classification, imbalanced class classification, object detection, etc.).

No, I am not very sure about ordinal regression, but here is what I think.

  1. If my understanding is right, ordinal regression is predicting a value on a scale, like a 1-5 ranking.
    Regardless of the data and model, I would first try different loss functions. If we are only predicting values from a fixed set, cross-entropy seems ideal: each example has one and only one correct result, so you can treat it as a classification problem and use cross-entropy to make predictions.

  2. Alternatively, you can treat it as a classic regression problem, using MSE or smooth L1 as the loss function.
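The two framings differ mainly in how the model’s raw output becomes a grade on the 0-4 scale from the first post. A hypothetical pure-Python sketch: with cross-entropy the prediction is simply the most probable class, while with MSE or smooth L1 the continuous output has to be cut into grades. The cut-points (0.5, 1.5, 2.5, 3.5) amount to rounding to the nearest grade and clipping to 0-4; tuning them on the validation set is an assumption, not something from the thread.

```python
def predict_from_classifier(probs):
    """Cross-entropy view: return the index of the most probable class."""
    return max(range(len(probs)), key=lambda c: probs[c])

def predict_from_regressor(value, cut_points=(0.5, 1.5, 2.5, 3.5)):
    """Regression view: the grade is the number of cut-points the output reaches.

    With the default cut-points this rounds to the nearest grade (halves
    round up) and clips anything below 0 or above 4 to the end grades.
    """
    return sum(value >= t for t in cut_points)

print(predict_from_classifier([0.1, 0.6, 0.2, 0.05, 0.05]))  # 1
print(predict_from_regressor(2.7))   # 3
print(predict_from_regressor(-0.4))  # 0
print(predict_from_regressor(6.2))   # 4
```

The regression view keeps the ordering information (predicting 3 for a true 4 is penalized less than predicting 0), which plain cross-entropy ignores.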

Compare the two results on your validation/test set: is their accuracy close, or is one better than the other? Then ensemble the two models using estimated weights (5:5, 3:7 or something else) and see whether the final model gives you a better result.
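To ensemble a classifier with a regressor, their outputs first need to be on the same scale. One way to do that, which is an assumption rather than something stated in the thread, is to turn the classifier’s probabilities into an expected grade Σ c · p_c and take a weighted average with the regressor’s output, so a 5:5 or 3:7 split is just the weight. A hypothetical pure-Python sketch, including a tiny search over candidate weights on validation data:

```python
import math

def expected_grade(probs):
    """Convert classifier probabilities into a continuous grade: sum of c * p_c."""
    return sum(c * p for c, p in enumerate(probs))

def blend(class_probs, reg_value, w_cls=0.5):
    """Weighted average of the two models (w_cls=0.5 is 5:5, w_cls=0.3 is 3:7 classifier:regressor)."""
    return w_cls * expected_grade(class_probs) + (1.0 - w_cls) * reg_value

def to_grade(value, num_classes=5):
    """Round half up to the nearest grade and clip to 0..num_classes-1."""
    return min(num_classes - 1, max(0, math.floor(value + 0.5)))

def choose_weight(val_probs, val_reg, val_targets, candidates=(0.3, 0.5, 0.7)):
    """Return the candidate classifier weight with the best validation accuracy."""
    def accuracy(w):
        preds = [to_grade(blend(pr, rv, w)) for pr, rv in zip(val_probs, val_reg)]
        return sum(pred == t for pred, t in zip(preds, val_targets)) / len(val_targets)
    return max(candidates, key=accuracy)
```

Choose the weight on the validation set and keep the test set for the final check; otherwise the reported score will be optimistic.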

The above is just the general approach I would try for a Kaggle competition or similar. Since we are discussing methods without first looking at the dataset, it might not be accurate.

Hopefully it points you in the right direction. :slight_smile:

Best,

(Himanshu) #13

Thanks @heye0507. I will check this out and get back.
