Callback for a multilabel classifier - Can I get the predictions instead of the outputs?

Hello!

I am writing a callback that computes metrics for an NLP multilabel classifier. To get the model's output for every batch I use:

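```python
from dataclasses import dataclass

import torch
from fastai.basic_train import Learner  # fastai v1
from fastai.callback import Callback    # fastai v1

@dataclass
class CallbackMetric(Callback):
    learn: Learner

    def on_batch_end(self, **kwargs) -> None:
        # Only compute metrics on validation batches
        if not kwargs["train"]:
            last_output = kwargs["last_output"]  # raw logits, shape (batch, n_labels)
            last_target = kwargs["last_target"]  # 0/1 targets, same shape
            # Sigmoid maps each logit to an independent per-label probability
            sigmoid_output = torch.sigmoid(last_output)
            # Threshold at 0.5 to get 0/1 predictions; .float() keeps the tensor
            # on its current device (.type(torch.FloatTensor) would move it to CPU)
            greater_output = torch.gt(sigmoid_output, 0.5).float()
```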
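Once the outputs are thresholded to 0/1, the metrics themselves are simple counts. The sketch below is a hypothetical helper (`multilabel_batch_metrics` is not part of fastai) written in plain Python so it runs without a GPU; it assumes predictions and targets are nested lists of shape (batch, n_labels), e.g. obtained with `greater_output.tolist()` and `last_target.tolist()` inside the callback.

```python
from typing import Dict, Sequence

def multilabel_batch_metrics(preds: Sequence[Sequence[float]],
                             targets: Sequence[Sequence[float]]) -> Dict[str, float]:
    """Element-wise accuracy and micro-averaged F1 for 0/1 multilabel data.

    Hypothetical helper: `preds` and `targets` are assumed to be
    (batch, n_labels) nested sequences containing only 0s and 1s.
    """
    tp = fp = fn = correct = total = 0
    for pred_row, target_row in zip(preds, targets):
        for p, t in zip(pred_row, target_row):
            correct += int(p == t)  # label-level agreement
            total += 1
            if p == 1 and t == 1:
                tp += 1             # true positive
            elif p == 1 and t == 0:
                fp += 1             # false positive
            elif p == 0 and t == 1:
                fn += 1             # false negative
    accuracy = correct / total if total else 0.0
    denom = 2 * tp + fp + fn
    micro_f1 = 2 * tp / denom if denom else 0.0
    return {"accuracy": accuracy, "micro_f1": micro_f1}

preds = [[1, 0, 1], [0, 0, 1]]
targets = [[1, 1, 0], [0, 0, 1]]
print(multilabel_batch_metrics(preds, targets))
```

For epoch-level metrics you would probably accumulate `tp`, `fp`, `fn` across batches and compute F1 once at the end of the epoch, since averaging per-batch F1 scores is not the same as the F1 over the whole validation set.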

My problem is that I am not sure whether the model applies the same activation when I call it via:

```python
learn.predict("example")
```

Is there a way to get the predictions in the callback (i.e. tensors containing only ones and zeroes) instead of the raw outputs (i.e. tensors of arbitrary real values)?

Many thanks!

What you did is exactly what `predict` does for multilabel problems (sigmoid, then a 0.5 threshold), so you should be good.
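For illustration, here is a plain-Python sketch of that sigmoid-then-threshold step (not fastai's actual source; `multilabel_predict` is a hypothetical name, and the 0.5 default is the threshold mentioned above). Each label is treated independently, which is what distinguishes multilabel from single-label (softmax) classification.

```python
import math
from typing import List, Sequence

def sigmoid(x: float) -> float:
    # Numerically stable logistic function: maps a logit to (0, 1)
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def multilabel_predict(logits: Sequence[float], thresh: float = 0.5) -> List[int]:
    # Apply sigmoid to each logit independently, then threshold to 0/1
    return [int(sigmoid(z) > thresh) for z in logits]

print(multilabel_predict([2.3, -0.7, 0.0, 0.4]))  # [1, 0, 0, 1]
```

Because the sigmoid is monotonic and equals 0.5 at zero, thresholding at 0.5 is the same as checking whether the logit is positive, so in the callback `(last_output > 0).float()` gives the same predictions without computing the sigmoid.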


Perfect! Thank you very much!