I am creating a VAE based on Lecture 17 of stat453-deep-learning-ss21.
Note that the linked reference handles the batch dimension explicitly, whereas here it is handled by the losses' `reduction` parameter.
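Many fastai classes expect the model's forward function to return a single tensor. In the call stack below, the returned tuple raises an exception when it is treated like a tensor; specifically, the failure occurs in the `rmse` metric's `flatten_check`, called from `Recorder.after_batch`, rather than in the loss function.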
In this case a 4-tuple of tensors is returned, and my custom loss class, passed to the Learner, unpacks the input argument:
learner = Learner(dls, VAE(), cbs=BertLoss(0.3), metrics=rmse, loss_func=VAELoss(flatten=False))
Is this a viable formulation, or how should I refactor it to work with my VAE?
The return statement of `model.forward()`:
return encoded, z_mean, z_log_var, decoded
Thanks, and apologies if this is a duplicate.
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<proj_path>\10_dataloader.ipynb Cell 29 in <cell line: 95>()
92 learner = Learner(dls, VAE(), cbs=BertLoss(0.3), metrics=rmse, loss_func=VAELoss(flatten=False))
93 # learner = Learner(dls, VAE(), cbs=BertLoss(0.3), metrics=rmse)
94 # learner.fit_one_cycle(3, cbs=[ShowGraphCallback(), ])
---> 95 learner.fit_one_cycle(20, cbs=[ShowGraphCallback(), EarlyStoppingCallback(patience=3, min_delta=1)])
File <conda_path>\envs\tse\lib\site-packages\fastai\callback\schedule.py:119, in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt, start_epoch)
116 lr_max = np.array([h['lr'] for h in self.opt.hypers])
117 scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
118 'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 119 self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd, start_epoch=start_epoch)
File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:256, in Learner.fit(self, n_epoch, lr, wd, cbs, reset_opt, start_epoch)
254 self.opt.set_hypers(lr=self.lr if lr is None else lr)
255 self.n_epoch = n_epoch
--> 256 self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:193, in Learner._with_events(self, f, event_type, ex, final)
192 def _with_events(self, f, event_type, ex, final=noop):
--> 193 try: self(f'before_{event_type}'); f()
194 except ex: self(f'after_cancel_{event_type}')
195 self(f'after_{event_type}'); final()
File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:245, in Learner._do_fit(self)
243 for epoch in range(self.n_epoch):
244 self.epoch=epoch
--> 245 self._with_events(self._do_epoch, 'epoch', CancelEpochException)
File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:193, in Learner._with_events(self, f, event_type, ex, final)
192 def _with_events(self, f, event_type, ex, final=noop):
--> 193 try: self(f'before_{event_type}'); f()
194 except ex: self(f'after_cancel_{event_type}')
195 self(f'after_{event_type}'); final()
File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:240, in Learner._do_epoch(self)
238 def _do_epoch(self):
239 self._do_epoch_train()
--> 240 self._do_epoch_validate()
File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:236, in Learner._do_epoch_validate(self, ds_idx, dl)
234 if dl is None: dl = self.dls[ds_idx]
235 self.dl = dl
--> 236 with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)
File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:193, in Learner._with_events(self, f, event_type, ex, final)
192 def _with_events(self, f, event_type, ex, final=noop):
--> 193 try: self(f'before_{event_type}'); f()
194 except ex: self(f'after_cancel_{event_type}')
195 self(f'after_{event_type}'); final()
File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:199, in Learner.all_batches(self)
197 def all_batches(self):
198 self.n_iter = len(self.dl)
--> 199 for o in enumerate(self.dl): self.one_batch(*o)
File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:227, in Learner.one_batch(self, i, b)
225 b = self._set_device(b)
226 self._split(b)
--> 227 self._with_events(self._do_one_batch, 'batch', CancelBatchException)
File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:195, in Learner._with_events(self, f, event_type, ex, final)
193 try: self(f'before_{event_type}'); f()
194 except ex: self(f'after_cancel_{event_type}')
--> 195 self(f'after_{event_type}'); final()
File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:171, in Learner.__call__(self, event_name)
--> 171 def __call__(self, event_name): L(event_name).map(self._call_one)
File <conda_path>\envs\tse\lib\site-packages\fastcore\foundation.py:156, in L.map(self, f, gen, *args, **kwargs)
--> 156 def map(self, f, *args, gen=False, **kwargs): return self._new(map_ex(self, f, *args, gen=gen, **kwargs))
File <conda_path>\envs\tse\lib\site-packages\fastcore\basics.py:835, in map_ex(iterable, f, gen, *args, **kwargs)
833 res = map(g, iterable)
834 if gen: return res
--> 835 return list(res)
File <conda_path>\envs\tse\lib\site-packages\fastcore\basics.py:820, in bind.__call__(self, *args, **kwargs)
818 if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
819 fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 820 return self.func(*fargs, **kwargs)
File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:175, in Learner._call_one(self, event_name)
173 def _call_one(self, event_name):
174 if not hasattr(event, event_name): raise Exception(f'missing {event_name}')
--> 175 for cb in self.cbs.sorted('order'): cb(event_name)
File <conda_path>\envs\tse\lib\site-packages\fastai\callback\core.py:62, in Callback.__call__(self, event_name)
60 try: res = getcallable(self, event_name)()
61 except (CancelBatchException, CancelBackwardException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
---> 62 except Exception as e: raise modify_exception(e, f'Exception occured in `{self.__class__.__name__}` when calling event `{event_name}`:\n\t{e.args[0]}', replace=True)
63 if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
64 return res
File <conda_path>\envs\tse\lib\site-packages\fastai\callback\core.py:60, in Callback.__call__(self, event_name)
58 res = None
59 if self.run and _run:
---> 60 try: res = getcallable(self, event_name)()
61 except (CancelBatchException, CancelBackwardException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
62 except Exception as e: raise modify_exception(e, f'Exception occured in `{self.__class__.__name__}` when calling event `{event_name}`:\n\t{e.args[0]}', replace=True)
File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:541, in Recorder.after_batch(self)
539 if len(self.yb) == 0: return
540 mets = self._train_mets if self.training else self._valid_mets
--> 541 for met in mets: met.accumulate(self.learn)
542 if not self.training: return
543 self.lrs.append(self.opt.hypers[-1]['lr'])
File <conda_path>\envs\tse\lib\site-packages\fastai\metrics.py:48, in AccumMetric.accumulate(self, learn)
46 elif self.dim_argmax: pred = pred.argmax(dim=self.dim_argmax)
47 if self.thresh: pred = (pred >= self.thresh)
---> 48 self.accum_values(pred,learn.y,learn)
File <conda_path>\envs\tse\lib\site-packages\fastai\metrics.py:54, in AccumMetric.accum_values(self, preds, targs, learn)
52 to_d = learn.to_detach if learn is not None else to_detach
53 preds,targs = to_d(preds),to_d(targs)
---> 54 if self.flatten: preds,targs = flatten_check(preds,targs)
55 self.preds.append(preds)
56 self.targs.append(targs)
File <conda_path>\envs\tse\lib\site-packages\fastai\torch_core.py:759, in flatten_check(inp, targ)
757 def flatten_check(inp, targ):
758 "Check that `out` and `targ` have the same number of elements and flatten them."
--> 759 inp,targ = TensorBase(inp.contiguous()).view(-1),TensorBase(targ.contiguous()).view(-1)
760 test_eq(len(inp), len(targ))
761 return inp,targ
AttributeError: Exception occured in `Recorder` when calling event `after_batch`:
'tuple' object has no attribute 'contiguous'