Returning Tuples from .forward(): yes or no?

I am creating a VAE based on Lecture 17 of stat453-deep-learning-ss21.
Note that the linked reference handles the batch dimension explicitly, whereas here it is handled by the losses' `reduction` parameter.

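Many classes used by fastai expect the forward function to return a single tensor. In the call stack below, the returned tuple raises an exception when it is treated like a tensor. Specifically, it fails in `flatten_check`, which is called by the `rmse` metric (`AccumMetric.accumulate`) during `Recorder.after_batch`, not in the loss function.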

In this case, a 4-tuple of tensors is returned, and my custom loss class (passed to the learner) unpacks it from its input argument:
learner = Learner(dls, VAE(), cbs=BertLoss(0.3), metrics=rmse, loss_func=VAELoss(flatten=False))

Is this a viable formulation, or how can I refactor it to work with my VAE?
The return statement of `model.forward()`:
return encoded, z_mean, z_log_var, decoded

Thanks, and apologies if this is a duplicate.

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<proj_path>\10_dataloader.ipynb Cell 29 in <cell line: 95>()
     92 learner = Learner(dls, VAE(), cbs=BertLoss(0.3), metrics=rmse, loss_func=VAELoss(flatten=False))
     93 # learner = Learner(dls, VAE(), cbs=BertLoss(0.3), metrics=rmse)
     94 # learner.fit_one_cycle(3, cbs=[ShowGraphCallback(), ])
---> 95 learner.fit_one_cycle(20, cbs=[ShowGraphCallback(), EarlyStoppingCallback(patience=3, min_delta=1)])

File <conda_path>\envs\tse\lib\site-packages\fastai\callback\schedule.py:119, in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt, start_epoch)
    116 lr_max = np.array([h['lr'] for h in self.opt.hypers])
    117 scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
    118           'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 119 self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd, start_epoch=start_epoch)

File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:256, in Learner.fit(self, n_epoch, lr, wd, cbs, reset_opt, start_epoch)
    254 self.opt.set_hypers(lr=self.lr if lr is None else lr)
    255 self.n_epoch = n_epoch
--> 256 self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)

File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:193, in Learner._with_events(self, f, event_type, ex, final)
    192 def _with_events(self, f, event_type, ex, final=noop):
--> 193     try: self(f'before_{event_type}');  f()
    194     except ex: self(f'after_cancel_{event_type}')
    195     self(f'after_{event_type}');  final()

File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:245, in Learner._do_fit(self)
    243 for epoch in range(self.n_epoch):
    244     self.epoch=epoch
--> 245     self._with_events(self._do_epoch, 'epoch', CancelEpochException)

File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:193, in Learner._with_events(self, f, event_type, ex, final)
    192 def _with_events(self, f, event_type, ex, final=noop):
--> 193     try: self(f'before_{event_type}');  f()
    194     except ex: self(f'after_cancel_{event_type}')
    195     self(f'after_{event_type}');  final()

File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:240, in Learner._do_epoch(self)
    238 def _do_epoch(self):
    239     self._do_epoch_train()
--> 240     self._do_epoch_validate()

File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:236, in Learner._do_epoch_validate(self, ds_idx, dl)
    234 if dl is None: dl = self.dls[ds_idx]
    235 self.dl = dl
--> 236 with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)

File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:193, in Learner._with_events(self, f, event_type, ex, final)
    192 def _with_events(self, f, event_type, ex, final=noop):
--> 193     try: self(f'before_{event_type}');  f()
    194     except ex: self(f'after_cancel_{event_type}')
    195     self(f'after_{event_type}');  final()

File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:199, in Learner.all_batches(self)
    197 def all_batches(self):
    198     self.n_iter = len(self.dl)
--> 199     for o in enumerate(self.dl): self.one_batch(*o)

File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:227, in Learner.one_batch(self, i, b)
    225 b = self._set_device(b)
    226 self._split(b)
--> 227 self._with_events(self._do_one_batch, 'batch', CancelBatchException)

File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:195, in Learner._with_events(self, f, event_type, ex, final)
    193 try: self(f'before_{event_type}');  f()
    194 except ex: self(f'after_cancel_{event_type}')
--> 195 self(f'after_{event_type}');  final()

File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:171, in Learner.__call__(self, event_name)
--> 171 def __call__(self, event_name): L(event_name).map(self._call_one)

File <conda_path>\envs\tse\lib\site-packages\fastcore\foundation.py:156, in L.map(self, f, gen, *args, **kwargs)
--> 156 def map(self, f, *args, gen=False, **kwargs): return self._new(map_ex(self, f, *args, gen=gen, **kwargs))

File <conda_path>\envs\tse\lib\site-packages\fastcore\basics.py:835, in map_ex(iterable, f, gen, *args, **kwargs)
    833 res = map(g, iterable)
    834 if gen: return res
--> 835 return list(res)

File <conda_path>\envs\tse\lib\site-packages\fastcore\basics.py:820, in bind.__call__(self, *args, **kwargs)
    818     if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    819 fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 820 return self.func(*fargs, **kwargs)

File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:175, in Learner._call_one(self, event_name)
    173 def _call_one(self, event_name):
    174     if not hasattr(event, event_name): raise Exception(f'missing {event_name}')
--> 175     for cb in self.cbs.sorted('order'): cb(event_name)

File <conda_path>\envs\tse\lib\site-packages\fastai\callback\core.py:62, in Callback.__call__(self, event_name)
     60     try: res = getcallable(self, event_name)()
     61     except (CancelBatchException, CancelBackwardException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
---> 62     except Exception as e: raise modify_exception(e, f'Exception occured in `{self.__class__.__name__}` when calling event `{event_name}`:\n\t{e.args[0]}', replace=True)
     63 if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
     64 return res

File <conda_path>\envs\tse\lib\site-packages\fastai\callback\core.py:60, in Callback.__call__(self, event_name)
     58 res = None
     59 if self.run and _run: 
---> 60     try: res = getcallable(self, event_name)()
     61     except (CancelBatchException, CancelBackwardException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
     62     except Exception as e: raise modify_exception(e, f'Exception occured in `{self.__class__.__name__}` when calling event `{event_name}`:\n\t{e.args[0]}', replace=True)

File <conda_path>\envs\tse\lib\site-packages\fastai\learner.py:541, in Recorder.after_batch(self)
    539 if len(self.yb) == 0: return
    540 mets = self._train_mets if self.training else self._valid_mets
--> 541 for met in mets: met.accumulate(self.learn)
    542 if not self.training: return
    543 self.lrs.append(self.opt.hypers[-1]['lr'])

File <conda_path>\envs\tse\lib\site-packages\fastai\metrics.py:48, in AccumMetric.accumulate(self, learn)
     46 elif self.dim_argmax: pred = pred.argmax(dim=self.dim_argmax)
     47 if self.thresh:  pred = (pred >= self.thresh)
---> 48 self.accum_values(pred,learn.y,learn)

File <conda_path>\envs\tse\lib\site-packages\fastai\metrics.py:54, in AccumMetric.accum_values(self, preds, targs, learn)
     52 to_d = learn.to_detach if learn is not None else to_detach
     53 preds,targs = to_d(preds),to_d(targs)
---> 54 if self.flatten: preds,targs = flatten_check(preds,targs)
     55 self.preds.append(preds)
     56 self.targs.append(targs)

File <conda_path>\envs\tse\lib\site-packages\fastai\torch_core.py:759, in flatten_check(inp, targ)
    757 def flatten_check(inp, targ):
    758     "Check that `out` and `targ` have the same number of elements and flatten them."
--> 759     inp,targ = TensorBase(inp.contiguous()).view(-1),TensorBase(targ.contiguous()).view(-1)
    760     test_eq(len(inp), len(targ))
    761     return inp,targ

AttributeError: Exception occured in `Recorder` when calling event `after_batch`:
	'tuple' object has no attribute 'contiguous'