I have a multilabel classification task and want to try a custom loss function. What’s the simplest way?
def aeloss(output:Tensor, target:Tensor) -> Tensor:
output = output>0
correct_pos = output.logical_and(target).int().sum()
incorrect = output.logical_xor(target).int().sum()
return incorrect / (correct_pos + incorrect)
class AELoss(Module):
def forward(self, inp, targ): return aeloss(inp,targ)
def activation(self, out): return F.sigmoid(out)
def decodes(self, out): return out>0.5
learner = vision_learner(dls, 'convnext_tiny_in22k', loss_func=aeloss).to_fp16()
#learner = vision_learner(dls, 'convnext_tiny_in22k', loss_func=AELoss()).to_fp16()
Every time I try to pass the class or function directly, I get
File ~/mambaforge/lib/python3.9/site-packages/fastai/learner.py:212, in Learner._do_one_batch(self)
210 self('after_loss')
211 if not self.training or not len(self.yb): return
--> 212 self._with_events(self._backward, 'backward', CancelBackwardException)
213 self._with_events(self._step, 'step', CancelStepException)
214 self.opt.zero_grad()
File ~/mambaforge/lib/python3.9/site-packages/fastai/learner.py:193, in Learner._with_events(self, f, event_type, ex, final)
192 def _with_events(self, f, event_type, ex, final=noop):
--> 193 try: self(f'before_{event_type}'); f()
194 except ex: self(f'after_cancel_{event_type}')
195 self(f'after_{event_type}'); final()
File ~/mambaforge/lib/python3.9/site-packages/fastai/learner.py:201, in Learner._backward(self)
--> 201 def _backward(self): self.loss_grad.backward()
File ~/mambaforge/lib/python3.9/site-packages/torch/_tensor.py:388, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
341 r"""Computes the gradient of current tensor w.r.t. graph leaves.
342
343 The graph is differentiated using the chain rule. If the tensor is
(...)
385 used to compute the attr::tensors.
386 """
387 if has_torch_function_unary(self):
--> 388 return handle_torch_function(
389 Tensor.backward,
390 (self,),
391 self,
392 gradient=gradient,
393 retain_graph=retain_graph,
394 create_graph=create_graph,
395 inputs=inputs)
396 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File ~/mambaforge/lib/python3.9/site-packages/torch/overrides.py:1498, in handle_torch_function(public_api, relevant_args, *args, **kwargs)
1492 warnings.warn("Defining your `__torch_function__ as a plain method is deprecated and "
1493 "will be an error in future, please define it as a classmethod.",
1494 DeprecationWarning)
1496 # Use `public_api` instead of `implementation` so __torch_function__
1497 # implementations can do equality/identity comparisons.
-> 1498 result = torch_func_method(public_api, types, args, kwargs)
1500 if result is not NotImplemented:
1501 return result
File ~/mambaforge/lib/python3.9/site-packages/fastai/torch_core.py:376, in TensorBase.__torch_function__(cls, func, types, args, kwargs)
374 if cls.debug and func.__name__ not in ('__str__','__repr__'): print(func, types, args, kwargs)
375 if _torch_handled(args, cls._opt, func): types = (torch.Tensor,)
--> 376 res = super().__torch_function__(func, types, args, ifnone(kwargs, {}))
377 dict_objs = _find_args(args) if args else _find_args(list(kwargs.values()))
378 if issubclass(type(res),TensorBase) and dict_objs: res.set_meta(dict_objs[0],as_copy=True)
File ~/mambaforge/lib/python3.9/site-packages/torch/_tensor.py:1121, in Tensor.__torch_function__(cls, func, types, args, kwargs)
1118 return NotImplemented
1120 with _C.DisableTorchFunction():
-> 1121 ret = func(*args, **kwargs)
...
--> 173 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
174 tensors, grad_tensors_, retain_graph, create_graph, inputs,
175 allow_unreachable=True, accumulate_grad=True)
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn