It worked without the learn.model.float(), but at the beginning I got this strange error:
Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
After checking the weight types of the model's modules with
learn.model[0][0], learn.model[0][0].weight.type()
Out: (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
'torch.cuda.FloatTensor')
and finding that they were already of type torch.cuda.FloatTensor, I just recreated the learner with a freshly created databunch, and it worked!
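Roughly, the workaround looked like this (just a minimal sketch, assuming an image classification setup in fastai v1; the path, transforms, architecture, and saved weights name are placeholders, not the exact values from my run):

```python
from fastai.vision import *

path = Path('data/my_dataset')  # placeholder path

# Freshly created databunch that has never been passed to an FP16 learner
data_fp32 = (ImageDataBunch.from_folder(path, valid_pct=0.2,
                                        ds_tfms=get_transforms(), size=224)
             .normalize(imagenet_stats))

# Build a fresh FP32 learner from it instead of reusing the old databunch
learn = cnn_learner(data_fp32, models.resnet34, metrics=accuracy)
learn.load('stage-1')  # placeholder name for the previously saved weights
```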
With this setup I could run learn.get_preds(), learn.validate(), learn.TTA(), and ClassificationInterpretation.from_learner(learn) without problems.
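For completeness, a sketch of those calls on the fresh FP32 learner (unpacking validate() like this assumes a single metric):

```python
# All of these ran without the HalfTensor/FloatTensor mismatch
preds, targets = learn.get_preds()
val_loss, val_acc = learn.validate()
tta_preds, tta_targets = learn.TTA()

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
```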
Therefore, it seems that the databunch gets converted to FP16 when the FP16 learner is created with it, and this then causes problems later on when an FP32 learner is created with the same old databunch?
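A quick way to check this hypothesis would be to compare the dtype of a batch drawn from the old databunch with one from the new databunch (sketch; old_data and data_fp32 are placeholder names for the two databunches):

```python
# 'old_data' = databunch that was passed to the FP16 learner,
# 'data_fp32' = freshly created databunch
x_old, _ = old_data.one_batch()
x_new, _ = data_fp32.one_batch()
# If the old batch comes back as torch.float16 and the new one as
# torch.float32, that would support the hypothesis.
print(x_old.dtype, x_new.dtype)
```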