learn.TTA(is_test=True) not supporting half precision models?

I ran into the following error when calling learn.TTA(is_test=True):
‘RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same’
It looks like the input from learn.data.test_dl is not converted to half precision.
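A minimal sketch of the mismatch, written in plain PyTorch rather than fastai: a model's weights and its input batch must share a dtype. The helper `match_model_dtype` is hypothetical. It casts a batch to the dtype of the model's first parameter, which is roughly what a to-half transform on the dataloader does for an fp16 model.

```python
import torch
import torch.nn as nn

def match_model_dtype(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cast an input batch to the dtype of the model's weights."""
    weight_dtype = next(model.parameters()).dtype
    return x.to(weight_dtype)

# A half-precision model, as produced by fp16 training
model = nn.Conv2d(3, 8, kernel_size=3).half()
batch = torch.randn(2, 3, 16, 16)        # float32, like the unconverted test_dl batches
fixed = match_model_dtype(model, batch)  # now float16, matching the weights
```

On a GPU, `model(fixed)` should then run without the type error. The forward pass is left out here because half-precision convolutions may not be supported on CPU.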

I tried passing a CallbackHandler([mp_cb]) as cb_handler to the validate() call inside get_preds() in tta.py (where mp_cb is the MixedPrecision callback from learn.callbacks), hoping the callback handler would convert the test images to fp16, but the error is still there.

Can someone please confirm whether this observation is correct, and suggest how to run TTA on the test dataset with fp16? Thanks


This is how I do it so far… Maybe there is a better way…

  1. Train your Model on FP16
  2. Save Weights
  3. Create new Learner (not FP16, this time)
  4. Load saved weights
  5. Make predictions
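The five steps above can be sketched in plain PyTorch, with a hypothetical toy model standing in for the fastai Learner. `load_state_dict` copies each saved tensor into the existing parameter, so fp16 weights end up in the fp32 parameters of the fresh model. In fastai you would use the learner's own save/load instead; this only illustrates the dtype behaviour.

```python
import io
import torch
import torch.nn as nn

def make_model() -> nn.Module:
    # Hypothetical toy architecture standing in for the real network
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

# 1-2. "Train" in FP16 (training omitted) and save the weights
fp16_model = make_model().half()
buffer = io.BytesIO()                            # in-memory stand-in for a weights file
torch.save(fp16_model.state_dict(), buffer)

# 3-4. Create a new full-precision model and load the saved weights into it
buffer.seek(0)
fp32_model = make_model()                        # parameters are float32 by default
fp32_model.load_state_dict(torch.load(buffer))   # values are cast fp16 -> fp32 on copy

# 5. Predict with float32 inputs
fp32_model.eval()
with torch.no_grad():
    preds = fp32_model(torch.randn(5, 4)).softmax(dim=1)
```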

There is a better way :wink:
Just type data.test_dl.add_tfm(to_half) to have your test dataloader convert the tensors to half precision. I’ll add this to the MixedPrecision callback so that the bug is fixed.
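The idea behind add_tfm is that the dataloader applies extra functions to every batch as it is yielded. The sketch below imitates that with a hypothetical `TransformedLoader` wrapper and a hypothetical `input_to_half` transform. Neither is fastai's actual code, and fastai's own `to_half` may differ in which parts of the batch it converts.

```python
import torch

def input_to_half(batch):
    """Hypothetical batch transform: cast the inputs to fp16, leave the targets alone."""
    x, y = batch
    return x.half(), y

class TransformedLoader:
    """Hypothetical wrapper that applies a list of transforms to each batch lazily."""
    def __init__(self, batches):
        self.batches = batches
        self.tfms = []

    def add_tfm(self, tfm):
        self.tfms.append(tfm)

    def remove_tfm(self, tfm):
        if tfm in self.tfms:
            self.tfms.remove(tfm)

    def __iter__(self):
        for batch in self.batches:
            for tfm in self.tfms:
                batch = tfm(batch)
            yield batch

# Two float32 batches of 4 images with integer labels
batches = [(torch.randn(4, 3, 8, 8), torch.randint(0, 10, (4,))) for _ in range(2)]
test_dl = TransformedLoader(batches)

test_dl.add_tfm(input_to_half)
x_half, y = next(iter(test_dl))   # inputs now fp16, labels unchanged

test_dl.remove_tfm(input_to_half)
x_full, _ = next(iter(test_dl))   # back to float32
```

In this sketch the transform is stored on the loader object itself, so anything else that reuses the same loader also receives fp16 batches until the transform is removed.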


Awesome, thanks!

Running learn.validate(data.valid_dl.add_tfm(to_half)) works but when I run learn.TTA(data.valid_dl.add_tfm(to_half)) I get the following error:

~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/functional.py in softmax(input, dim, _stacklevel)
    982     if dim is None:
    983         dim = _get_softmax_dim('softmax', input.dim(), _stacklevel)
--> 984     return input.softmax(dim)

RuntimeError: softmax is not implemented for type torch.HalfTensor

Creating a new learner without FP16 also creates an error:

~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/conv.py in forward(self, input)
    311     def forward(self, input):
    312         return F.conv2d(input, self.weight, self.bias, self.stride,
--> 313                         self.padding, self.dilation, self.groups)

RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same

Am I missing something?
Did something change in the library?
Is there also a trick for ClassificationInterpretation.from_learner(learn)?

Kind regards

To use TTA you have to average the probabilities over different augmented inputs, so you need to transform the last activations into probabilities; that’s why there is a softmax there. You should compute it in full precision anyway, to avoid any numerical instability.
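A minimal sketch of that averaging step, assuming the model returns raw logits (possibly fp16) for each augmented copy of the batch. `tta_average` is a hypothetical helper that takes a plain mean; fastai's own TTA may weight the original and augmented predictions differently.

```python
import torch

def tta_average(logits_per_aug):
    """Average class probabilities over augmented passes.

    Each element is a (batch, classes) tensor of logits, possibly fp16.
    Logits are cast to float32 before the softmax for numerical stability.
    """
    probs = [logits.float().softmax(dim=-1) for logits in logits_per_aug]
    return torch.stack(probs).mean(dim=0)

# Three augmented passes over a batch of 4 images, 10 classes, half-precision logits
torch.manual_seed(0)
logits_per_aug = [torch.randn(4, 10).half() for _ in range(3)]
avg_probs = tta_average(logits_per_aug)
```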

My advice would be to load your model in a clean learner in full precision for this.


How do you do this?
Just creating a model and loading your weights (which are half tensors) does not work.
Do you have a trick like learn.data.valid_dl.add_tfm(to_half), but to transform everything back to float32?

With a clean learner, just loading the half-precision weights still gives outputs as half tensors:

learn32 = Learner(data, arch, metrics=[accuracy_thresh, f1])
p_v, t_v = learn32.get_preds()

I think it’s as simple as learn.model.float().
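In PyTorch, `nn.Module.float()` casts every floating-point parameter and buffer to float32 in place (and returns the module), while integer buffers such as BatchNorm's `num_batches_tracked` keep their type. A small plain-PyTorch check with a hypothetical two-layer model; note that the inputs must be float32 too.

```python
import torch
import torch.nn as nn

def float_dtypes(module: nn.Module):
    """Collect the dtypes of all floating-point parameters and buffers."""
    tensors = list(module.parameters()) + list(module.buffers())
    return {t.dtype for t in tensors if t.is_floating_point()}

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8)).half()
print(float_dtypes(model))   # {torch.float16}

model.float()                # in-place cast back to full precision
model.eval()
with torch.no_grad():
    out = model(torch.randn(2, 3, 16, 16))   # float32 input matches float32 weights
```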


It worked without the learn.model.float() but I got this strange error at the beginning:

Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same

After checking the weight types of the model's modules with

learn.model[0][0], learn.model[0][0].weight.type()
Out: (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),

and finding that they were already of type torch.cuda.FloatTensor, I simply recreated the learner with a freshly created DataBunch, and it worked!

With this setup I could run learn.get_preds(), learn.validate(), learn.TTA(), and ClassificationInterpretation.from_learner(learn) without problems.

So it seems the DataBunch gets converted to FP16 when the FP16 learner is created with it, and this causes problems later for an FP32 learner created from that same old DataBunch?

When using FP16, all your dataloaders are converted to half precision by adding a transform to them:


You can remove it with


Thank you for the explanation, with that everything makes sense. :smiley:

Side question:
Is the FocalLoss() loss function you used a custom one, or is it built into fastai / PyTorch?


Custom. You can use any Pytorch loss function with fastai.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Focal loss for multi-label targets with the same shape as the logits."""
    def __init__(self, gamma=2):
        super().__init__()  # required so that nn.Module is initialised
        self.gamma = gamma

    def forward(self, input, target):
        if not (target.size() == input.size()):
            raise ValueError("Target size ({}) must be the same as input size ({})"
                             .format(target.size(), input.size()))

        # Numerically stable binary cross-entropy with logits
        max_val = (-input).clamp(min=0)
        loss = input - input * target + max_val + \
            ((-max_val).exp() + (-input - max_val).exp()).log()

        # Down-weight easy examples by (1 - p_t) ** gamma, computed in log space
        invprobs = F.logsigmoid(-input * (target * 2.0 - 1.0))
        loss = (invprobs * self.gamma).exp() * loss
        return loss.sum(dim=1).mean()

I used the above function but got an error:

LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.


ValueError                                Traceback (most recent call last)

<ipython-input-23-399ce5aa3598> in <module>()
----> 1 learn.lr_find()
      2 learn.recorder.plot(suggestion=True)


<ipython-input-22-fca0f4f684a5> in forward(self, input, target)
     11         if not (target.size() == input.size()):
     12             raise ValueError("Target size ({}) must be the same as input size ({})"
---> 13                              .format(target.size(), input.size()))
     15         max_val = (-input).clamp(min=0)

ValueError: Target size (torch.Size([16])) must be the same as input size (torch.Size([16, 100]))

What are your sizes?
To check them, use:

x,t = next(iter(data.train_dl))
learn.model(x).shape, t.shape

I ran it and got:

(torch.Size([16, 100]), torch.Size([16]))
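These shapes explain the error: the model outputs one score per class, [16, 100], while the targets are class indices, [16]. The FocalLoss above compares logits and targets element-wise, so it expects one-hot (multi-label style) targets of the same shape. One possible fix, sketched here as an assumption rather than what was done in this thread, is to one-hot encode the indices before calling the loss. The helper `one_hot_targets` is hypothetical, and `binary_cross_entropy_with_logits` is used only to keep the example self-contained.

```python
import torch
import torch.nn.functional as F

def one_hot_targets(target: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Turn class indices of shape [batch] into float one-hot targets of shape [batch, num_classes]."""
    return F.one_hot(target, num_classes=num_classes).float()

logits = torch.randn(16, 100)           # model output, as in the shapes above
target = torch.randint(0, 100, (16,))   # integer class labels
dense = one_hot_targets(target, num_classes=100)

# Element-wise losses such as FocalLoss or BCE-with-logits now see matching shapes
loss = F.binary_cross_entropy_with_logits(logits, dense)
```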

Thanks, that works for me. That’s super helpful!

If I call learn.to_fp16() and then learn.save() or learn.export(), is the model actually saved in fp32 format? And is the model loaded with load_learner in full fp32?

No, you have to call learn.to_fp32() to cast it back to full precision.