Cannot wrap my head around custom loss

I have a multilabel classification task and want to try a custom loss function. What’s the simplest way?

def aeloss(output:Tensor, target:Tensor) -> Tensor:
    output = output>0
    correct_pos = output.logical_and(target).int().sum()
    incorrect = output.logical_xor(target).int().sum()
    return incorrect / (correct_pos + incorrect)
class AELoss(Module):
    def forward(self, inp, targ): return aeloss(inp,targ)
    def activation(self, out): return F.sigmoid(out)
    def decodes(self, out): return out>0.5

learner = vision_learner(dls, 'convnext_tiny_in22k', loss_func=aeloss).to_fp16()
#learner = vision_learner(dls, 'convnext_tiny_in22k', loss_func=AELoss()).to_fp16()

Every time I try to pass the class or function directly, I get

File ~/mambaforge/lib/python3.9/site-packages/fastai/, in Learner._do_one_batch(self)
    210 self('after_loss')
    211 if not or not len(self.yb): return
--> 212 self._with_events(self._backward, 'backward', CancelBackwardException)
    213 self._with_events(self._step, 'step', CancelStepException)
    214 self.opt.zero_grad()

File ~/mambaforge/lib/python3.9/site-packages/fastai/, in Learner._with_events(self, f, event_type, ex, final)
    192 def _with_events(self, f, event_type, ex, final=noop):
--> 193     try: self(f'before_{event_type}');  f()
    194     except ex: self(f'after_cancel_{event_type}')
    195     self(f'after_{event_type}');  final()

File ~/mambaforge/lib/python3.9/site-packages/fastai/, in Learner._backward(self)
--> 201 def _backward(self): self.loss_grad.backward()

File ~/mambaforge/lib/python3.9/site-packages/torch/, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    341 r"""Computes the gradient of current tensor w.r.t. graph leaves.
    343 The graph is differentiated using the chain rule. If the tensor is
    385         used to compute the attr::tensors.
    386 """
    387 if has_torch_function_unary(self):
--> 388     return handle_torch_function(
    389         Tensor.backward,
    390         (self,),
    391         self,
    392         gradient=gradient,
    393         retain_graph=retain_graph,
    394         create_graph=create_graph,
    395         inputs=inputs)
    396 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

File ~/mambaforge/lib/python3.9/site-packages/torch/, in handle_torch_function(public_api, relevant_args, *args, **kwargs)
   1492     warnings.warn("Defining your `__torch_function__ as a plain method is deprecated and "
   1493                   "will be an error in future, please define it as a classmethod.",
   1494                   DeprecationWarning)
   1496 # Use `public_api` instead of `implementation` so __torch_function__
   1497 # implementations can do equality/identity comparisons.
-> 1498 result = torch_func_method(public_api, types, args, kwargs)
   1500 if result is not NotImplemented:
   1501     return result

File ~/mambaforge/lib/python3.9/site-packages/fastai/, in TensorBase.__torch_function__(cls, func, types, args, kwargs)
    374 if cls.debug and func.__name__ not in ('__str__','__repr__'): print(func, types, args, kwargs)
    375 if _torch_handled(args, cls._opt, func): types = (torch.Tensor,)
--> 376 res = super().__torch_function__(func, types, args, ifnone(kwargs, {}))
    377 dict_objs = _find_args(args) if args else _find_args(list(kwargs.values()))
    378 if issubclass(type(res),TensorBase) and dict_objs: res.set_meta(dict_objs[0],as_copy=True)

File ~/mambaforge/lib/python3.9/site-packages/torch/, in Tensor.__torch_function__(cls, func, types, args, kwargs)
   1118     return NotImplemented
   1120 with _C.DisableTorchFunction():
-> 1121     ret = func(*args, **kwargs)
--> 173 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    175     allow_unreachable=True, accumulate_grad=True)

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn