Hi, I’m having trouble getting a GANLearner to train. I’ve tried a lot of variations of the generator, critic, and GANLearner parameters, but can’t get it to work. It currently says “expected scalar type Long but found Float.” Does anyone have any idea what I’m missing?
Generator structure:
learn = unet_learner(dls, models.resnet18,
n_out=3,
metrics=mse,
loss_func=MSELossFlat())
Critic structure:
learn = Learner(crit_dls, gan_critic(), loss_func=CrossEntropyLossFlat())
Code:
img = DataBlock(blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImage)),
get_items = get_image_files,
get_y = labelFunc,
splitter=RandomSplitter(),
item_tfms=Resize(200, 200),
batch_tfms=aug_transforms()
)
input_dls = img.dataloaders(Path("inputPics/"), bs=2, device=torch.device("cpu"))
input_dls.show_batch()
learn_gen = load_learner("gen-pre2")
learn_gen.dls = input_dls
learn_crit = load_learner('/home/luke/Documents/fastai/unet1/GANtrainingFolder/critic-pre2')
switcher = FixedGANSwitcher()
learn = GANLearner.from_learners(learn_gen, learn_crit,switcher=switcher)
learn.fit(1, 1e-3)
Error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-6-5fa11e9d3275> in <module>
----> 1 learn.fit(1, 1e-3)
~/.local/lib/python3.8/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
204 self.opt.set_hypers(lr=self.lr if lr is None else lr)
205 self.n_epoch = n_epoch
--> 206 self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
207
208 def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None
~/.local/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
153
154 def _with_events(self, f, event_type, ex, final=noop):
--> 155 try: self(f'before_{event_type}') ;f()
156 except ex: self(f'after_cancel_{event_type}')
157 finally: self(f'after_{event_type}') ;final()
~/.local/lib/python3.8/site-packages/fastai/learner.py in _do_fit(self)
195 for epoch in range(self.n_epoch):
196 self.epoch=epoch
--> 197 self._with_events(self._do_epoch, 'epoch', CancelEpochException)
198
199 def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):
~/.local/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
153
154 def _with_events(self, f, event_type, ex, final=noop):
--> 155 try: self(f'before_{event_type}') ;f()
156 except ex: self(f'after_cancel_{event_type}')
157 finally: self(f'after_{event_type}') ;final()
~/.local/lib/python3.8/site-packages/fastai/learner.py in _do_epoch(self)
189
190 def _do_epoch(self):
--> 191 self._do_epoch_train()
192 self._do_epoch_validate()
193
~/.local/lib/python3.8/site-packages/fastai/learner.py in _do_epoch_train(self)
181 def _do_epoch_train(self):
182 self.dl = self.dls.train
--> 183 self._with_events(self.all_batches, 'train', CancelTrainException)
184
185 def _do_epoch_validate(self, ds_idx=1, dl=None):
~/.local/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
153
154 def _with_events(self, f, event_type, ex, final=noop):
--> 155 try: self(f'before_{event_type}') ;f()
156 except ex: self(f'after_cancel_{event_type}')
157 finally: self(f'after_{event_type}') ;final()
~/.local/lib/python3.8/site-packages/fastai/learner.py in all_batches(self)
159 def all_batches(self):
160 self.n_iter = len(self.dl)
--> 161 for o in enumerate(self.dl): self.one_batch(*o)
162
163 def _do_one_batch(self):
~/.local/lib/python3.8/site-packages/fastai/learner.py in one_batch(self, i, b)
177 self.iter = i
178 self._split(b)
--> 179 self._with_events(self._do_one_batch, 'batch', CancelBatchException)
180
181 def _do_epoch_train(self):
~/.local/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
153
154 def _with_events(self, f, event_type, ex, final=noop):
--> 155 try: self(f'before_{event_type}') ;f()
156 except ex: self(f'after_cancel_{event_type}')
157 finally: self(f'after_{event_type}') ;final()
~/.local/lib/python3.8/site-packages/fastai/learner.py in _do_one_batch(self)
164 self.pred = self.model(*self.xb)
165 self('after_pred')
--> 166 if len(self.yb): self.loss = self.loss_func(self.pred, *self.yb)
167 self('after_loss')
168 if not self.training or not len(self.yb): return
~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
~/.local/lib/python3.8/site-packages/fastai/vision/gan.py in forward(self, *args)
19
20 def forward(self, *args):
---> 21 return self.generator(*args) if self.gen_mode else self.critic(*args)
22
23 def switch(self, gen_mode=None):
~/.local/lib/python3.8/site-packages/fastai/vision/gan.py in critic(self, real_pred, input)
110 fake = self.gan_model.generator(input).requires_grad_(False)
111 fake_pred = self.gan_model.critic(fake)
--> 112 self.crit_loss = self.crit_loss_func(real_pred, fake_pred)
113 return self.crit_loss
114
~/.local/lib/python3.8/site-packages/fastai/vision/gan.py in _loss_C(real_pred, fake_pred)
287 ones = real_pred.new_ones (real_pred.shape[0])
288 zeros = fake_pred.new_zeros(fake_pred.shape[0])
--> 289 return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2
290
291 return _loss_G, _loss_C
~/.local/lib/python3.8/site-packages/fastai/losses.py in __call__(self, inp, targ, **kwargs)
31 if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
32 if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
---> 33 return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)
34
35 # Cell
~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
~/.local/lib/python3.8/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
959
960 def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 961 return F.cross_entropy(input, target, weight=self.weight,
962 ignore_index=self.ignore_index, reduction=self.reduction)
963
~/.local/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
2460 tens_ops = (input, target)
2461 if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
-> 2462 return handle_torch_function(
2463 cross_entropy, tens_ops, input, target, weight=weight,
2464 size_average=size_average, ignore_index=ignore_index, reduce=reduce,
~/.local/lib/python3.8/site-packages/torch/overrides.py in handle_torch_function(public_api, relevant_args, *args, **kwargs)
1058 # Use `public_api` instead of `implementation` so __torch_function__
1059 # implementations can do equality/identity comparisons.
-> 1060 result = overloaded_arg.__torch_function__(public_api, types, args, kwargs)
1061
1062 if result is not NotImplemented:
~/.local/lib/python3.8/site-packages/fastai/torch_core.py in __torch_function__(self, func, types, args, kwargs)
317 # if func.__name__[0]!='_': print(func, types, args, kwargs)
318 # with torch._C.DisableTorchFunction(): ret = _convert(func(*args, **(kwargs or {})), self.__class__)
--> 319 ret = super().__torch_function__(func, types, args=args, kwargs=kwargs)
320 if isinstance(ret, TensorBase): ret.set_meta(self, as_copy=True)
321 return ret
~/.local/lib/python3.8/site-packages/torch/tensor.py in __torch_function__(cls, func, types, args, kwargs)
993
994 with _C.DisableTorchFunction():
--> 995 ret = func(*args, **kwargs)
996 return _convert(ret, cls)
997
~/.local/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
2466 if size_average is not None or reduce is not None:
2467 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2468 return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
2469
2470
~/.local/lib/python3.8/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
2262 .format(input.size(0), target.size(0)))
2263 if dim == 2:
-> 2264 ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
2265 elif dim == 4:
2266 ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: expected scalar type Long but found Float