How do you do this?
Just creating a model and loading your weights (which are half tensors) does not work.
Is there a trick like learn.data.valid_dl.add_tfm(to_half),
but one that transforms everything back to float32?
With a clean learner, just loading the half weights makes the output half tensors as well:
learn32 = Learner(data, arch, metrics=[accuracy_thresh, f1])
learn32.loss_func = FocalLoss()
learn32.load('dk_se_64')
p_v, t_v = learn32.get_preds()
p_v.dtype
# -> torch.float16
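One possible workaround, sketched in plain PyTorch rather than through any fastai helper: cast every floating-point tensor in the saved state dict to float32 before loading, call .float() on the model so all parameters and buffers end up in fp32, and cast incoming batches back to fp32 too. The toy nn.Sequential model and the helper names state_dict_to_fp32 / load_half_weights_as_fp32 are hypothetical stand-ins. With fastai the module would be learn32.model, which is an ordinary nn.Module, assuming you can get at the saved state dict.

```python
import torch
import torch.nn as nn


def state_dict_to_fp32(state_dict):
    """Hypothetical helper: copy of state_dict with floating-point tensors cast to float32.

    Integer buffers (e.g. BatchNorm's num_batches_tracked) are left unchanged.
    """
    return {k: v.float() if torch.is_floating_point(v) else v
            for k, v in state_dict.items()}


def load_half_weights_as_fp32(model, half_state_dict):
    """Hypothetical helper: load fp16 weights into a model and force it to fp32."""
    model.load_state_dict(state_dict_to_fp32(half_state_dict))
    # .float() converts every floating-point parameter and buffer of the module.
    return model.float()


def make_model():
    # Toy architecture standing in for the real `arch` (assumption).
    return nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))


torch.manual_seed(0)
half_model = make_model().half()
half_sd = half_model.state_dict()  # what a mixed-precision save might contain

model32 = load_half_weights_as_fp32(make_model(), half_sd)
model32.eval()

x = torch.randn(3, 4).half()  # batch still in fp16, e.g. after a to_half transform
with torch.no_grad():
    preds = model32(x.float())  # cast the inputs back to fp32 as well
print(preds.dtype)
```

If the batches still go through a to_half transform, the x.float() cast is what keeps the fp32 model from hitting a dtype mismatch between its weights and its inputs.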