Cannot wrap my head around custom loss

I have a multilabel classification task and want to try a custom loss function. What’s the simplest way?

from fastai.vision.all import *  # brings in Tensor, Module, F, vision_learner, ...

def aeloss(output: Tensor, target: Tensor) -> Tensor:
    output = output > 0                                    # threshold the raw logits
    correct_pos = output.logical_and(target).int().sum()   # true positives
    incorrect = output.logical_xor(target).int().sum()     # false positives + false negatives
    return incorrect / (correct_pos + incorrect)

class AELoss(Module):
    def forward(self, inp, targ): return aeloss(inp, targ)
    def activation(self, out): return F.sigmoid(out)   # logits -> probabilities for predictions
    def decodes(self, out): return out > 0.5            # binarise probabilities for show_results

learner = vision_learner(dls, 'convnext_tiny_in22k', loss_func=aeloss).to_fp16()
#learner = vision_learner(dls, 'convnext_tiny_in22k', loss_func=AELoss()).to_fp16()
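
I can also reproduce it outside the Learner with made-up tensors (the shapes here are arbitrary), which makes me think the loss itself is the problem:

import torch
preds = torch.randn(8, 20, requires_grad=True)   # fake logits: batch of 8, 20 labels
targs = torch.randint(0, 2, (8, 20))             # fake multi-hot targets
loss  = aeloss(preds, targs)
print(loss, loss.requires_grad)                  # requires_grad comes out False

The returned loss has requires_grad=False, which matches the error below.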

Whether I pass the bare function or an instance of the class, training fails as soon as the backward pass runs:

File ~/mambaforge/lib/python3.9/site-packages/fastai/learner.py:212, in Learner._do_one_batch(self)
    210 self('after_loss')
    211 if not self.training or not len(self.yb): return
--> 212 self._with_events(self._backward, 'backward', CancelBackwardException)
    213 self._with_events(self._step, 'step', CancelStepException)
    214 self.opt.zero_grad()

File ~/mambaforge/lib/python3.9/site-packages/fastai/learner.py:193, in Learner._with_events(self, f, event_type, ex, final)
    192 def _with_events(self, f, event_type, ex, final=noop):
--> 193     try: self(f'before_{event_type}');  f()
    194     except ex: self(f'after_cancel_{event_type}')
    195     self(f'after_{event_type}');  final()

File ~/mambaforge/lib/python3.9/site-packages/fastai/learner.py:201, in Learner._backward(self)
--> 201 def _backward(self): self.loss_grad.backward()

File ~/mambaforge/lib/python3.9/site-packages/torch/_tensor.py:388, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    341 r"""Computes the gradient of current tensor w.r.t. graph leaves.
    342 
    343 The graph is differentiated using the chain rule. If the tensor is
   (...)
    385         used to compute the attr::tensors.
    386 """
    387 if has_torch_function_unary(self):
--> 388     return handle_torch_function(
    389         Tensor.backward,
    390         (self,),
    391         self,
    392         gradient=gradient,
    393         retain_graph=retain_graph,
    394         create_graph=create_graph,
    395         inputs=inputs)
    396 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

File ~/mambaforge/lib/python3.9/site-packages/torch/overrides.py:1498, in handle_torch_function(public_api, relevant_args, *args, **kwargs)
   1492     warnings.warn("Defining your `__torch_function__ as a plain method is deprecated and "
   1493                   "will be an error in future, please define it as a classmethod.",
   1494                   DeprecationWarning)
   1496 # Use `public_api` instead of `implementation` so __torch_function__
   1497 # implementations can do equality/identity comparisons.
-> 1498 result = torch_func_method(public_api, types, args, kwargs)
   1500 if result is not NotImplemented:
   1501     return result

File ~/mambaforge/lib/python3.9/site-packages/fastai/torch_core.py:376, in TensorBase.__torch_function__(cls, func, types, args, kwargs)
    374 if cls.debug and func.__name__ not in ('__str__','__repr__'): print(func, types, args, kwargs)
    375 if _torch_handled(args, cls._opt, func): types = (torch.Tensor,)
--> 376 res = super().__torch_function__(func, types, args, ifnone(kwargs, {}))
    377 dict_objs = _find_args(args) if args else _find_args(list(kwargs.values()))
    378 if issubclass(type(res),TensorBase) and dict_objs: res.set_meta(dict_objs[0],as_copy=True)

File ~/mambaforge/lib/python3.9/site-packages/torch/_tensor.py:1121, in Tensor.__torch_function__(cls, func, types, args, kwargs)
   1118     return NotImplemented
   1120 with _C.DisableTorchFunction():
-> 1121     ret = func(*args, **kwargs)
...
--> 173 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    175     allow_unreachable=True, accumulate_grad=True)

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
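
My guess is that the comparison output > 0 (plus the .int() casts) detaches everything from the autograd graph, so there is nothing for backward() to differentiate. If that's the cause, would a "soft" version along these lines be the right direction? (Just a sketch I haven't trained with yet; the name and the epsilon are mine.)

def soft_aeloss(output: Tensor, target: Tensor) -> Tensor:
    "Differentiable stand-in: sigmoid probabilities instead of hard 0/1 decisions."
    probs = torch.sigmoid(output)
    target = target.float()
    correct_pos = (probs * target).sum()                             # soft true positives
    incorrect = (probs * (1 - target) + (1 - probs) * target).sum()  # soft FP + FN
    return incorrect / (correct_pos + incorrect + 1e-8)              # epsilon avoids 0/0

Or is the usual advice here to keep aeloss as a metric and train with something like BCEWithLogitsLossFlat as the actual loss_func?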