Tabular with multiple dependent variables

I attempted to adapt a tabular learning example to use two dependent variables:

from fastai.imports import *
from fastai.tabular.all import *
np.set_printoptions(linewidth=130)

path = untar_data(URLs.ADULT_SAMPLE)
    
dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, y_names=['salary','marital-status'],
        cat_names = ['workclass', 'education', 'occupation',
                     'relationship', 'race'],
        cont_names = ['age', 'fnlwgt', 'education-num'],
        procs = [Categorify, FillMissing, Normalize])
    
learn = tabular_learner(dls, metrics=accuracy)
learn.lr_find()

This fails with a batch size mismatch in the loss function:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[2], line 14
      7 dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, y_names=['salary','marital-status'],
      8         cat_names = ['workclass', 'education', 'occupation',
      9                      'relationship', 'race'],
     10         cont_names = ['age', 'fnlwgt', 'education-num'],
     11         procs = [Categorify, FillMissing, Normalize])
     13 learn = tabular_learner(dls, metrics=accuracy)
---> 14 learn.lr_find()

File /opt/conda/lib/python3.10/site-packages/fastai/callback/schedule.py:293, in lr_find(self, start_lr, end_lr, num_it, stop_div, show_plot, suggest_funcs)
    291 n_epoch = num_it//len(self.dls.train) + 1
    292 cb=LRFinder(start_lr=start_lr, end_lr=end_lr, num_it=num_it, stop_div=stop_div)
--> 293 with self.no_logging(): self.fit(n_epoch, cbs=cb)
    294 if suggest_funcs is not None:
    295     lrs, losses = tensor(self.recorder.lrs[num_it//10:-5]), tensor(self.recorder.losses[num_it//10:-5])

File /opt/conda/lib/python3.10/site-packages/fastai/learner.py:264, in Learner.fit(self, n_epoch, lr, wd, cbs, reset_opt, start_epoch)
    262 self.opt.set_hypers(lr=self.lr if lr is None else lr)
    263 self.n_epoch = n_epoch
--> 264 self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)

File /opt/conda/lib/python3.10/site-packages/fastai/learner.py:199, in Learner._with_events(self, f, event_type, ex, final)
    198 def _with_events(self, f, event_type, ex, final=noop):
--> 199     try: self(f'before_{event_type}');  f()
    200     except ex: self(f'after_cancel_{event_type}')
    201     self(f'after_{event_type}');  final()

File /opt/conda/lib/python3.10/site-packages/fastai/learner.py:253, in Learner._do_fit(self)
    251 for epoch in range(self.n_epoch):
    252     self.epoch=epoch
--> 253     self._with_events(self._do_epoch, 'epoch', CancelEpochException)

File /opt/conda/lib/python3.10/site-packages/fastai/learner.py:199, in Learner._with_events(self, f, event_type, ex, final)
    198 def _with_events(self, f, event_type, ex, final=noop):
--> 199     try: self(f'before_{event_type}');  f()
    200     except ex: self(f'after_cancel_{event_type}')
    201     self(f'after_{event_type}');  final()

File /opt/conda/lib/python3.10/site-packages/fastai/learner.py:247, in Learner._do_epoch(self)
    246 def _do_epoch(self):
--> 247     self._do_epoch_train()
    248     self._do_epoch_validate()

File /opt/conda/lib/python3.10/site-packages/fastai/learner.py:239, in Learner._do_epoch_train(self)
    237 def _do_epoch_train(self):
    238     self.dl = self.dls.train
--> 239     self._with_events(self.all_batches, 'train', CancelTrainException)

File /opt/conda/lib/python3.10/site-packages/fastai/learner.py:199, in Learner._with_events(self, f, event_type, ex, final)
    198 def _with_events(self, f, event_type, ex, final=noop):
--> 199     try: self(f'before_{event_type}');  f()
    200     except ex: self(f'after_cancel_{event_type}')
    201     self(f'after_{event_type}');  final()

File /opt/conda/lib/python3.10/site-packages/fastai/learner.py:205, in Learner.all_batches(self)
    203 def all_batches(self):
    204     self.n_iter = len(self.dl)
--> 205     for o in enumerate(self.dl): self.one_batch(*o)

File /opt/conda/lib/python3.10/site-packages/fastai/learner.py:235, in Learner.one_batch(self, i, b)
    233 b = self._set_device(b)
    234 self._split(b)
--> 235 self._with_events(self._do_one_batch, 'batch', CancelBatchException)

File /opt/conda/lib/python3.10/site-packages/fastai/learner.py:199, in Learner._with_events(self, f, event_type, ex, final)
    198 def _with_events(self, f, event_type, ex, final=noop):
--> 199     try: self(f'before_{event_type}');  f()
    200     except ex: self(f'after_cancel_{event_type}')
    201     self(f'after_{event_type}');  final()

File /opt/conda/lib/python3.10/site-packages/fastai/learner.py:219, in Learner._do_one_batch(self)
    217 self('after_pred')
    218 if len(self.yb):
--> 219     self.loss_grad = self.loss_func(self.pred, *self.yb)
    220     self.loss = self.loss_grad.clone()
    221 self('after_loss')

File /opt/conda/lib/python3.10/site-packages/fastai/losses.py:54, in BaseLoss.__call__(self, inp, targ, **kwargs)
     52 if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
     53 if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
---> 54 return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/loss.py:1174, in CrossEntropyLoss.forward(self, input, target)
   1173 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1174     return F.cross_entropy(input, target, weight=self.weight,
   1175                            ignore_index=self.ignore_index, reduction=self.reduction,
   1176                            label_smoothing=self.label_smoothing)

File /opt/conda/lib/python3.10/site-packages/torch/nn/functional.py:3015, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   2949 r"""This criterion computes the cross entropy loss between input logits and target.
   2950 
   2951 See :class:`~torch.nn.CrossEntropyLoss` for details.
   (...)
   3012     >>> loss.backward()
   3013 """
   3014 if has_torch_function_variadic(input, target, weight):
-> 3015     return handle_torch_function(
   3016         cross_entropy,
   3017         (input, target, weight),
   3018         input,
   3019         target,
   3020         weight=weight,
   3021         size_average=size_average,
   3022         ignore_index=ignore_index,
   3023         reduce=reduce,
   3024         reduction=reduction,
   3025         label_smoothing=label_smoothing,
   3026     )
   3027 if size_average is not None or reduce is not None:
   3028     reduction = _Reduction.legacy_get_string(size_average, reduce)

File /opt/conda/lib/python3.10/site-packages/torch/overrides.py:1551, in handle_torch_function(public_api, relevant_args, *args, **kwargs)
   1545     warnings.warn("Defining your `__torch_function__ as a plain method is deprecated and "
   1546                   "will be an error in future, please define it as a classmethod.",
   1547                   DeprecationWarning)
   1549 # Use `public_api` instead of `implementation` so __torch_function__
   1550 # implementations can do equality/identity comparisons.
-> 1551 result = torch_func_method(public_api, types, args, kwargs)
   1553 if result is not NotImplemented:
   1554     return result

File /opt/conda/lib/python3.10/site-packages/fastai/torch_core.py:382, in TensorBase.__torch_function__(cls, func, types, args, kwargs)
    380 if cls.debug and func.__name__ not in ('__str__','__repr__'): print(func, types, args, kwargs)
    381 if _torch_handled(args, cls._opt, func): types = (torch.Tensor,)
--> 382 res = super().__torch_function__(func, types, args, ifnone(kwargs, {}))
    383 dict_objs = _find_args(args) if args else _find_args(list(kwargs.values()))
    384 if issubclass(type(res),TensorBase) and dict_objs: res.set_meta(dict_objs[0],as_copy=True)

File /opt/conda/lib/python3.10/site-packages/torch/_tensor.py:1295, in Tensor.__torch_function__(cls, func, types, args, kwargs)
   1292     return NotImplemented
   1294 with _C.DisableTorchFunctionSubclass():
-> 1295     ret = func(*args, **kwargs)
   1296     if func in get_default_nowrap_functions():
   1297         return ret

File /opt/conda/lib/python3.10/site-packages/torch/nn/functional.py:3029, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3027 if size_average is not None or reduce is not None:
   3028     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3029 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)

ValueError: Expected input batch_size (64) to match target batch_size (128).

Presumably the default CrossEntropyLossFlat is flattening both target columns into a single vector (64 rows × 2 targets = 128) while the predictions still have only 64 rows, though I may be misreading it. Is this user error or a bug? If the former, can someone give an example of how to do this properly?

Thanks

If I replace loss_func with a custom loss, I instead get a failure in flatten_check() (I still have metrics=accuracy set, which I suspect is where that comes from).
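
For reference, the kind of replacement I had in mind looks roughly like this (just a sketch: multi_target_loss is my own name, and the class counts n1/n2 are hard-coded from the adult dataset rather than read from dls):

import torch.nn.functional as F

n1, n2 = 2, 7  # class counts for 'salary' and 'marital-status' (hard-coded assumption)

def multi_target_loss(pred, targ):
    # pred: (bs, n1+n2) logits; targ: (bs, 2) with one integer label column per target
    return (F.cross_entropy(pred[:, :n1], targ[:, 0].long()) +
            F.cross_entropy(pred[:, n1:], targ[:, 1].long()))

learn = tabular_learner(dls, n_out=n1+n2, loss_func=multi_target_loss)
learn.lr_find()

I'm not confident that setting n_out and hard-coding the class counts like this is the intended way to wire it up, hence the question.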