I ran into the following error when calling learn.TTA(is_test=True):
‘RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same’
It looks like the input from learn.data.test_dl is not converted to half precision.
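To illustrate what I mean, the mismatch boils down to the batch dtype not matching the model's parameter dtype. A minimal sketch in plain PyTorch (CPU tensors here just to show the principle; the names are illustrative, not fastai internals):

```python
import torch
import torch.nn as nn

# A model whose weights were converted to fp16, as MixedPrecision does
model = nn.Linear(4, 2).half()

# A batch left in fp32, like the images coming from test_dl
batch = torch.randn(3, 4)

weight_dtype = next(model.parameters()).dtype
print(weight_dtype, batch.dtype)  # torch.float16 torch.float32 -> mismatch
```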
I tried passing a CallbackHandler([mp_cb]) as cb_handler to the validate() call inside get_preds() in tta.py, where mp_cb is the MixedPrecision callback from learn.callbacks, hoping the callback handler would convert the test image batches to fp16, but the error persists.
Can someone please confirm whether this observation is correct, and suggest how TTA on the test dataset can be done with fp16? Thanks!
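For reference, the kind of fix I'd expect is casting each test batch to the model's weight dtype before the forward pass, which is what MixedPrecision does for train/valid batches. A sketch of that idea (to_model_dtype is a hypothetical helper, not a fastai API, and I haven't verified this against the fastai source):

```python
import torch
import torch.nn as nn

def to_model_dtype(batch, model):
    """Cast floating-point batches to the dtype of the model's weights.
    (Illustrative helper, not part of fastai.)"""
    weight_dtype = next(model.parameters()).dtype
    if batch.is_floating_point():
        return batch.to(weight_dtype)
    return batch  # leave integer batches (e.g. labels) untouched

# Example: an fp16 model and an fp32 test batch
model = nn.Linear(4, 2).half()
batch = torch.randn(3, 4)
batch = to_model_dtype(batch, model)
print(batch.dtype)  # torch.float16
```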