This doesn't directly answer your question, but here is an example for anyone who wants to apply weighted cross-entropy as the loss function on an imbalanced dataset.
Weighted Cross-Entropy
```python
from collections import Counter
# L, torch, F, and partial used below all come in with fastai's usual
# star import: from fastai.text.all import *

# Compute per-class weights from the class distribution in the training data
def get_weights(dls):
    # For a text classifier, dls.vocab is a pair:
    # the 0th index is the vocab from the text, the 1st index is the vocab from the classes
    classes = dls.vocab[1]
    # Get the label ids from the dataset using map:
    #   train_lb_ids = L(map(lambda x: x[1], dls.train_ds))
    # then get the actual labels from the label ids & the vocab:
    #   train_lbls = L(map(lambda x: classes[x], train_lb_ids))
    # Combined into a single expression:
    train_lbls = L(map(lambda x: classes[x[1]], dls.train_ds))
    label_counter = Counter(train_lbls)
    n_most_common_class = max(label_counter.values())
    print(f'Occurrences of the most common class: {n_most_common_class}')
    # Weight = count of the most common class / count of this class, taken in
    # vocab order so that weights[i] lines up with class id i (this assumes
    # every class occurs at least once in the training set).
    # Source: https://discuss.pytorch.org/t/what-is-the-weight-values-mean-in-torch-nn-crossentropyloss/11455/9
    weights = [n_most_common_class / label_counter[c] for c in classes]
    return weights
```
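To see what the formula produces, here is a minimal, self-contained sketch with a made-up 9:1 label split (the class names and counts are purely illustrative):

```python
from collections import Counter

# Hypothetical label distribution: 900 'negative' vs. 100 'positive' examples
label_counter = Counter({'negative': 900, 'positive': 100})
n_most_common_class = max(label_counter.values())  # 900

# Same formula as in get_weights, in (hypothetical) vocab order
weights = [n_most_common_class / label_counter[c] for c in ['negative', 'positive']]
print(weights)  # [1.0, 9.0] -> mistakes on 'positive' cost 9x more
```

The most common class always ends up with weight 1.0, and rarer classes get proportionally larger weights.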
```python
# Get the weights from the classification dataloader
weights = get_weights(dls_cls)
class_weights = torch.FloatTensor(weights).to(dls_cls.device)
learn_cls.loss_func = partial(F.cross_entropy, weight=class_weights)
```
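If you prefer to stay within fastai's loss classes instead of a raw `partial`, `CrossEntropyLossFlat` wraps `nn.CrossEntropyLoss` and passes keyword arguments like `weight` through to it; a sketch, assuming fastai v2:

```python
from fastai.losses import CrossEntropyLossFlat

# Equivalent weighted loss via fastai's flattened wrapper around nn.CrossEntropyLoss
learn_cls.loss_func = CrossEntropyLossFlat(weight=class_weights)
```

This also keeps the loss compatible with fastai's activation/decoding hooks (e.g. for `Learner.predict` and interpretation), which a bare `F.cross_entropy` partial doesn't provide.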