Oversampling in fastai2

This doesn't directly answer your question, but here is an example for anyone who wants to apply weighted cross-entropy as the loss function for an imbalanced dataset.

Weighted Cross-Entropy

# Imports (a fastai star import such as `from fastai.text.all import *`
# also provides L, torch, F, and partial)
from collections import Counter
from functools import partial

import torch
import torch.nn.functional as F
from fastai.text.all import *

# Compute per-class weights from the class distribution in the training data
def get_weights(dls):
    # For a text classification DataLoaders, vocab[0] is the token vocab
    # and vocab[1] is the list of class labels
    classes = dls.vocab[1]

    # Two-step version: get the label ids, then look up the actual labels
    #train_lb_ids = L(map(lambda x: x[1], dls.train_ds))
    #train_lbls = L(map(lambda x: classes[x], train_lb_ids))

    # Combined into a single expression
    train_lbls = L(map(lambda x: classes[x[1]], dls.train_ds))
    label_counter = Counter(train_lbls)
    n_most_common_class = max(label_counter.values())
    print(f'Occurrences of the most common class: {n_most_common_class}')

    # Weight each class by most_common_count / class_count, iterating in
    # vocab order so the weights line up with the class indices the loss
    # function expects (iterating over the Counter directly would not
    # guarantee this). Assumes every class occurs at least once in the
    # training set.
    # Source: https://discuss.pytorch.org/t/what-is-the-weight-values-mean-in-torch-nn-crossentropyloss/11455/9
    weights = [n_most_common_class / label_counter[c] for c in classes]
    return weights
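As a quick sanity check of the weight formula, here is a toy example (the counts are made up for illustration): with 100 examples of class a and 10 of class b, the most common count is 100, so the rarer class ends up contributing 10x to the loss.

from collections import Counter

# Hypothetical label distribution: 100 'a' vs 10 'b'
label_counter = Counter({'a': 100, 'b': 10})
n_most_common_class = max(label_counter.values())  # 100

# most_common_count / class_count, in vocab order ['a', 'b']
weights = [n_most_common_class / label_counter[c] for c in ['a', 'b']]
print(weights)  # [1.0, 10.0]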

# Get the weights from the classification DataLoaders
weights = get_weights(dls_cls)
class_weights = torch.FloatTensor(weights).to(dls_cls.device)
learn_cls.loss_func = partial(F.cross_entropy, weight=class_weights)
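As a side note, patching in a raw partial of F.cross_entropy trains fine, but fastai's CrossEntropyLossFlat also forwards weight through to nn.CrossEntropyLoss and additionally carries the activation/decodes hooks that get_preds and the interpretation classes rely on, so it may be the more idiomatic choice (this sketch reuses the same class_weights tensor as above):

# Equivalent weighted loss using fastai's flattened wrapper
learn_cls.loss_func = CrossEntropyLossFlat(weight=class_weights)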