[HELP] Cannot train Unet learner with half precision

(Kennedy Oung) #1

Hi, I’ve been trying to train a super-resolution model (similar to lesson 7 of Jeremy’s Practical Deep Learning for Coders course) with half precision.
I’ve managed to train the generator and critic, and am now at the step where I implement the GAN learner.
Since this step takes a lot of memory, I’ve been trying to use half precision for it, but I keep getting the same runtime error:

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same

Here is how I set up half precision:

learn_crit = create_critic_learner(data_crit, metrics=None).load('critic-pre128to512-lc').to_fp16()
learn_gen = create_gen_learner(data_gen).load('1b-128to512-lc').to_fp16()
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.,50.), show_img=True, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0.,0.99)), wd=wd).to_fp16()
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))
lr = 1e-4
learn.fit(10, lr)
learn_gen.save('3a-enhance-lc')

Can anyone give me tips to make this learner work? :slight_smile:


(Navid Panchi) #2

It seems your data is still full-precision floating point, while your models have half-precision weights.

Try training the generator and critic with half precision as well, save them, and load them later on. See if that helps.
In short, try using half precision everywhere, not just in the last step.
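
One way to confirm this kind of mismatch before calling fit is to compare the dtype of the model's parameters with the dtype of an input batch. Below is a minimal sketch in plain PyTorch, not fastai code: `dtype_report`, `toy_model` and `toy_batch` are made-up names for illustration, and the single conv layer is only a stand-in for the real critic or generator.

```python
import torch
from torch import nn


def dtype_report(model: nn.Module, batch: torch.Tensor) -> dict:
    """Compare the dtypes of a model's floating-point parameters with an input batch."""
    param_dtypes = {p.dtype for p in model.parameters() if p.is_floating_point()}
    return {
        "param_dtypes": param_dtypes,
        "input_dtype": batch.dtype,
        # True only if every parameter shares the input's dtype
        "match": param_dtypes == {batch.dtype},
    }


# Hypothetical stand-in for the critic: one conv layer cast to half precision.
toy_model = nn.Conv2d(3, 8, kernel_size=3).half()
# A float32 batch, i.e. data that was never converted to half precision.
toy_batch = torch.randn(2, 3, 16, 16)

before = dtype_report(toy_model, toy_batch)         # float32 input vs float16 weights
after = dtype_report(toy_model, toy_batch.half())   # both float16

print(before["match"], after["match"])
```

If the same check on the real model and one batch from the data reports a mismatch, that would point to the input not being cast to half precision rather than to the saved weights.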
