learn.TTA(is_test=True) not supporting half precision models?


(Bruce Yang) #1

Hi,
I ran into the following error when trying to run learn.TTA(is_test=True):
‘RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same’
It looks like the input from learn.data.test_dl is not converted to half precision.
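
For context, this is roughly what I'm running (a sketch, not my exact code; path and the data setup are placeholders, and I'm on a fastai v1 version where TTA still takes is_test):

    from fastai.vision import *

    # Hypothetical data with a test set attached; `path` is a placeholder
    data = ImageDataBunch.from_folder(path, test='test', size=224)
    learn = create_cnn(data, models.resnet34).to_fp16()
    learn.fit_one_cycle(1)

    # Fails: test batches arrive in fp32 while the model weights are fp16
    preds, _ = learn.TTA(is_test=True)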

I tried passing a CallbackHandler([mp_cb]) as cb_handler to the validate() call inside get_preds() in tta.py, where mp_cb is the MixedPrecision callback from learn.callbacks, hoping the callback handler would convert the test image data to fp16, but the error is still there.

Can someone please confirm whether this observation is correct, and suggest how TTA on the test set can be done with fp16? Thanks


(Asimo) #2

This is how I do it so far… maybe there is a better way. A rough sketch in code follows the steps:

  1. Train your model in FP16
  2. Save the weights
  3. Create a new learner (not FP16 this time)
  4. Load the saved weights
  5. Make predictions
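
Roughly, in code (just a sketch; it assumes fastai v1 with create_cnn, and 'stage-1' is a made-up checkpoint name):

    from fastai.vision import *

    # Steps 1-2: train in mixed precision, then save the weights
    learn = create_cnn(data, models.resnet34).to_fp16()
    learn.fit_one_cycle(4)
    learn.save('stage-1')

    # Steps 3-4: rebuild the learner in full precision, load the saved weights
    learn_fp32 = create_cnn(data, models.resnet34)
    learn_fp32.load('stage-1')

    # Step 5: predict in fp32, so TTA on the test set works
    preds, _ = learn_fp32.TTA(is_test=True)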

#3

There is a better way :wink:
Just type data.test_dl.add_tfm(to_half) to have your test dataloader convert its tensors to half precision. I'll add this to the MixedPrecision callback so that the bug is fixed.
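
In context, that looks like this (a sketch; to_half lives in fastai.torch_core and is also pulled in by the usual star imports):

    from fastai.torch_core import to_half

    # Cast each test batch to fp16 before it reaches the half-precision model
    learn.data.test_dl.add_tfm(to_half)
    preds, _ = learn.TTA(is_test=True)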


(Asimo) #4

Awesome, thanks!