GANLearner won't train, "expected scalar type Long but found Float"

Hi, I’m having trouble getting a GANLearner to train. I’ve tried a lot of variations of the generator, critic, and GANLearner parameters, but can’t get it to work. It currently says “expected scalar type Long but found Float.” Does anyone have any idea what I’m missing?

Generator structure:

# Generator: a U-Net learner built on a resnet18 backbone with n_out=3,
# using MSELossFlat as the loss and reporting mse as the metric.
learn = unet_learner(dls, models.resnet18,
                     n_out=3,
                     metrics=mse,
                     loss_func=MSELossFlat())

Critic structure:

# Critic: gan_critic() model with CrossEntropyLossFlat as its loss.
# In the traceback below this loss ends up in F.cross_entropy -> nll_loss,
# which raises "expected scalar type Long but found Float": GANLearner's
# _loss_C builds its targets with real_pred.new_ones / fake_pred.new_zeros,
# i.e. from the predictions themselves rather than from integer class labels.
learn = Learner(crit_dls, gan_critic(), loss_func=CrossEntropyLossFlat())

Code:

# Image-to-image data: both the input and the target are loaded as PILImage.
img = DataBlock(
    blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImage)),
    get_items=get_image_files,
    get_y=labelFunc,
    splitter=RandomSplitter(),
    # Resize's second positional parameter is the resize method, not a height;
    # a single int already produces a square 200x200 output.
    item_tfms=Resize(200),
    batch_tfms=aug_transforms(),
)
input_dls = img.dataloaders(Path("inputPics/"), bs=2, device=torch.device("cpu"))

input_dls.show_batch()

# Load the saved generator and critic learners, then point the generator
# at the freshly built dataloaders.
learn_gen = load_learner("gen-pre2")
learn_gen.dls = input_dls
learn_crit = load_learner('/home/luke/Documents/fastai/unet1/GANtrainingFolder/critic-pre2')

switcher = FixedGANSwitcher()

# Combine both learners into a single GANLearner; the switcher decides when
# training alternates between generator and critic.
learn = GANLearner.from_learners(learn_gen, learn_crit, switcher=switcher)

learn.fit(1, 1e-3)

Error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-6-5fa11e9d3275> in <module>
----> 1 learn.fit(1, 1e-3)

~/.local/lib/python3.8/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    204             self.opt.set_hypers(lr=self.lr if lr is None else lr)
    205             self.n_epoch = n_epoch
--> 206             self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
    207 
    208     def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None

~/.local/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    153 
    154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
    156         except ex: self(f'after_cancel_{event_type}')
    157         finally:   self(f'after_{event_type}')        ;final()

~/.local/lib/python3.8/site-packages/fastai/learner.py in _do_fit(self)
    195         for epoch in range(self.n_epoch):
    196             self.epoch=epoch
--> 197             self._with_events(self._do_epoch, 'epoch', CancelEpochException)
    198 
    199     def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):

~/.local/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    153 
    154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
    156         except ex: self(f'after_cancel_{event_type}')
    157         finally:   self(f'after_{event_type}')        ;final()

~/.local/lib/python3.8/site-packages/fastai/learner.py in _do_epoch(self)
    189 
    190     def _do_epoch(self):
--> 191         self._do_epoch_train()
    192         self._do_epoch_validate()
    193 

~/.local/lib/python3.8/site-packages/fastai/learner.py in _do_epoch_train(self)
    181     def _do_epoch_train(self):
    182         self.dl = self.dls.train
--> 183         self._with_events(self.all_batches, 'train', CancelTrainException)
    184 
    185     def _do_epoch_validate(self, ds_idx=1, dl=None):

~/.local/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    153 
    154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
    156         except ex: self(f'after_cancel_{event_type}')
    157         finally:   self(f'after_{event_type}')        ;final()

~/.local/lib/python3.8/site-packages/fastai/learner.py in all_batches(self)
    159     def all_batches(self):
    160         self.n_iter = len(self.dl)
--> 161         for o in enumerate(self.dl): self.one_batch(*o)
    162 
    163     def _do_one_batch(self):

~/.local/lib/python3.8/site-packages/fastai/learner.py in one_batch(self, i, b)
    177         self.iter = i
    178         self._split(b)
--> 179         self._with_events(self._do_one_batch, 'batch', CancelBatchException)
    180 
    181     def _do_epoch_train(self):

~/.local/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    153 
    154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
    156         except ex: self(f'after_cancel_{event_type}')
    157         finally:   self(f'after_{event_type}')        ;final()

~/.local/lib/python3.8/site-packages/fastai/learner.py in _do_one_batch(self)
    164         self.pred = self.model(*self.xb)
    165         self('after_pred')
--> 166         if len(self.yb): self.loss = self.loss_func(self.pred, *self.yb)
    167         self('after_loss')
    168         if not self.training or not len(self.yb): return

~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~/.local/lib/python3.8/site-packages/fastai/vision/gan.py in forward(self, *args)
     19 
     20     def forward(self, *args):
---> 21         return self.generator(*args) if self.gen_mode else self.critic(*args)
     22 
     23     def switch(self, gen_mode=None):

~/.local/lib/python3.8/site-packages/fastai/vision/gan.py in critic(self, real_pred, input)
    110         fake = self.gan_model.generator(input).requires_grad_(False)
    111         fake_pred = self.gan_model.critic(fake)
--> 112         self.crit_loss = self.crit_loss_func(real_pred, fake_pred)
    113         return self.crit_loss
    114 

~/.local/lib/python3.8/site-packages/fastai/vision/gan.py in _loss_C(real_pred, fake_pred)
    287         ones  = real_pred.new_ones (real_pred.shape[0])
    288         zeros = fake_pred.new_zeros(fake_pred.shape[0])
--> 289         return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2
    290 
    291     return _loss_G, _loss_C

~/.local/lib/python3.8/site-packages/fastai/losses.py in __call__(self, inp, targ, **kwargs)
     31         if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
     32         if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
---> 33         return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)
     34 
     35 # Cell

~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~/.local/lib/python3.8/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    959 
    960     def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 961         return F.cross_entropy(input, target, weight=self.weight,
    962                                ignore_index=self.ignore_index, reduction=self.reduction)
    963 

~/.local/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2460         tens_ops = (input, target)
   2461         if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
-> 2462             return handle_torch_function(
   2463                 cross_entropy, tens_ops, input, target, weight=weight,
   2464                 size_average=size_average, ignore_index=ignore_index, reduce=reduce,

~/.local/lib/python3.8/site-packages/torch/overrides.py in handle_torch_function(public_api, relevant_args, *args, **kwargs)
   1058         # Use `public_api` instead of `implementation` so __torch_function__
   1059         # implementations can do equality/identity comparisons.
-> 1060         result = overloaded_arg.__torch_function__(public_api, types, args, kwargs)
   1061 
   1062         if result is not NotImplemented:

~/.local/lib/python3.8/site-packages/fastai/torch_core.py in __torch_function__(self, func, types, args, kwargs)
    317 #         if func.__name__[0]!='_': print(func, types, args, kwargs)
    318 #         with torch._C.DisableTorchFunction(): ret = _convert(func(*args, **(kwargs or {})), self.__class__)
--> 319         ret = super().__torch_function__(func, types, args=args, kwargs=kwargs)
    320         if isinstance(ret, TensorBase): ret.set_meta(self, as_copy=True)
    321         return ret

~/.local/lib/python3.8/site-packages/torch/tensor.py in __torch_function__(cls, func, types, args, kwargs)
    993 
    994         with _C.DisableTorchFunction():
--> 995             ret = func(*args, **kwargs)
    996             return _convert(ret, cls)
    997 

~/.local/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2466     if size_average is not None or reduce is not None:
   2467         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2468     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   2469 
   2470 

~/.local/lib/python3.8/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   2262                          .format(input.size(0), target.size(0)))
   2263     if dim == 2:
-> 2264         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   2265     elif dim == 4:
   2266         ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

RuntimeError: expected scalar type Long but found Float

I finally got it to train. I don’t know yet if it’s actually working (I think I need to change the input data), but after changing the loss functions as shown below, it now completes an epoch.

# Workaround: override the loss functions on the loaded generator and critic
# learners (presumably before GANLearner.from_learners is called — confirm).
# NOTE(review): this makes the run complete, but the critic loss is still a
# cross-entropy variant; fastai's GAN examples pair gan_critic() with
# AdaptiveLoss(nn.BCEWithLogitsLoss()) — worth comparing before trusting results.
learn_gen.loss_func = MSELossFlat()
learn_crit.loss_func = LabelSmoothingCrossEntropyFlat()