IndexError with Focal Loss

I’m trying to use Focal Loss for an imbalanced binary classification problem with ULMFiT. The code is essentially:

import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    def __init__(self, alpha=1., gamma=1.):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets, **kwargs):
        # Per-item cross-entropy, then down-weight easy (confident) items
        CE_loss = nn.CrossEntropyLoss(reduction='none')(inputs, targets)
        pt = torch.exp(-CE_loss)  # predicted probability of the true class
        F_loss = self.alpha * ((1-pt)**self.gamma) * CE_loss
        return F_loss.mean()  # always averaged; extra kwargs are ignored

# Same computation as FocalLoss, only with different default alpha/gamma
class WeightedFocalLoss(nn.Module):
    def __init__(self, alpha=.25, gamma=2.):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets, **kwargs):
        CE_loss = nn.CrossEntropyLoss(reduction='none')(inputs, targets)
        pt = torch.exp(-CE_loss)
        F_loss = self.alpha * ((1-pt)**self.gamma) * CE_loss
        return F_loss.mean()

learn = text_classifier_learner(data_clas,
                                arch=AWD_LSTM,
                                loss_func=WeightedFocalLoss(),
                                drop_mult=0.3,
                                metrics=[accuracy, f1]).to_fp16()
learn.load_encoder('lm_encoder')
...

When trying to view my predictions on the validation data, I used:

learn.load('final_model')
learn.data.add_test(df_val['text'])
preds, y, losses = learn.get_preds(ds_type=DatasetType.Test, ordered=True, with_loss=True)

I get the following error:

/usr/local/lib/python3.6/dist-packages/fastai/text/learner.py in <listcomp>(.0)
     93             sampler = [i for i in self.dl(ds_type).sampler]
     94             reverse_sampler = np.argsort(sampler)
---> 95             preds = [p[reverse_sampler] for p in preds]
     96         return preds
     97 

IndexError: too many indices for tensor of dimension 0
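My reading of the last frame (not confirmed): with ordered=True, every returned tensor, including the losses, is indexed with reverse_sampler to restore dataset order. If the loss comes back already averaged to a single number (a 0-dim tensor), that indexing fails. A minimal reproduction in plain PyTorch, with made-up values rather than my real data:

```python
import numpy as np
import torch

# Order in which the sampler visited the items (made-up values for illustration).
sampler = [2, 0, 3, 1]
reverse_sampler = np.argsort(sampler)  # positions that restore dataset order

per_item_losses = torch.tensor([0.3, 0.1, 0.4, 0.2])  # one loss per item, in sampler order
mean_loss = per_item_losses.mean()                     # 0-dim tensor, like F_loss.mean()

print(per_item_losses[reverse_sampler])  # reorders fine: one value per item

try:
    mean_loss[reverse_sampler]  # a scalar has no dimension to index into
except IndexError as err:
    print('IndexError:', err)   # same exception type as in the traceback
```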

But there is no error with preds, y = learn.get_preds(ds_type=DatasetType.Test, ordered=True) (i.e. without with_loss=True), or if I change the loss function to learn.loss_func = FlattenedLoss(LabelSmoothingCrossEntropy, axis=-1).
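If I understand fastai v1's basic_train.py correctly (I haven't checked this against my installed version), get_preds(with_loss=True) evaluates the loss inside a NoneReduceOnCPU context manager: if the loss object has a reduction attribute, it temporarily sets it to 'none'; otherwise it calls the loss with a reduction='none' keyword. FlattenedLoss exposes reduction, while my focal loss swallows the keyword in **kwargs and still returns .mean(). The sketch below imitates that logic with a hypothetical none_reduce helper, using nn.CrossEntropyLoss as a stand-in for a loss that has a reduction attribute.

```python
from contextlib import contextmanager
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F


@contextmanager
def none_reduce(loss_func):
    """Hypothetical helper modelled on my reading of fastai v1's NoneReduceOnCPU."""
    if hasattr(loss_func, 'reduction'):
        old = loss_func.reduction
        loss_func.reduction = 'none'  # ask for one loss per item
        try:
            yield loss_func
        finally:
            loss_func.reduction = old  # restore the training setting
    else:
        # No attribute to flip, so pass the request as a keyword instead.
        yield partial(loss_func, reduction='none')


class KwargsSwallowingLoss(nn.Module):
    """Stand-in for the focal losses in my post: accepts **kwargs but always averages."""
    def forward(self, inputs, targets, **kwargs):
        return F.cross_entropy(inputs, targets, reduction='none').mean()


logits = torch.randn(4, 2)            # 4 items, 2 classes
targets = torch.tensor([0, 1, 1, 0])

with none_reduce(nn.CrossEntropyLoss()) as lf:
    print(lf(logits, targets).shape)  # torch.Size([4]): per-item losses

with none_reduce(KwargsSwallowingLoss()) as lf:
    print(lf(logits, targets).shape)  # torch.Size([]): still averaged
```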

This shouldn’t matter when I’m actually making predictions, but I would like to see the losses so I can evaluate the results with TextClassificationInterpretation and look at the top losses, and that also raises IndexError: too many indices for tensor of dimension 0.
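One thing I could try, assuming the cause is that the loss ignores a reduction='none' request: a focal loss that honors a reduction setting the way PyTorch's built-in losses do. This is an untested sketch; the class name ReductionAwareFocalLoss is my own, and I haven't confirmed that it fixes TextClassificationInterpretation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ReductionAwareFocalLoss(nn.Module):
    """Hypothetical focal loss that honors a `reduction` setting.

    Assumption: callers request per-item losses either by setting
    `self.reduction = 'none'` or by passing `reduction='none'` to forward.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets, reduction=None, **kwargs):
        # A per-call `reduction` argument overrides the stored attribute.
        reduction = reduction or self.reduction
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')  # shape (N,)
        pt = torch.exp(-ce_loss)  # predicted probability of the true class
        focal = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if reduction == 'none':
            return focal  # one loss per item, so it can be reordered or ranked
        if reduction == 'sum':
            return focal.sum()
        return focal.mean()


logits = torch.randn(4, 2)            # 4 items, 2 classes
targets = torch.tensor([0, 1, 1, 0])
loss_fn = ReductionAwareFocalLoss()
print(loss_fn(logits, targets).shape)                    # torch.Size([])
print(loss_fn(logits, targets, reduction='none').shape)  # torch.Size([4])
```

Training calls the loss without a reduction argument, so the default 'mean' should keep training behavior the same as WeightedFocalLoss.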
