Hello!
I am writing a callback that computes metrics for an NLP multi-label classifier. To get the model's output in every batch, I use:
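```python
from dataclasses import dataclass
import torch
from fastai.basic_train import Learner
from fastai.callback import Callback

@dataclass
class CallbackMetric(Callback):
    learn: Learner

    def on_batch_end(self, **kwargs) -> None:
        train = kwargs["train"]
        if not train:
            # Raw model outputs (logits) and targets for the current validation batch
            last_output = kwargs["last_output"]
            last_target = kwargs["last_target"]
            sigmoid_output = torch.sigmoid(last_output)  # per-label probabilities
            # Threshold at 0.5; .float() keeps the tensor on its current device,
            # whereas .type(torch.FloatTensor) would copy it to the CPU
            greater_output = torch.gt(sigmoid_output, 0.5).float()
```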
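Because the thresholded tensors are only available one batch at a time, a metric such as F1 has to be built from counts accumulated across batches; averaging a per-batch F1 does not give the same number. A minimal sketch, assuming logits of shape `(batch, n_labels)` and 0/1 targets of the same shape; `multilabel_counts` and `micro_f1` are hypothetical helpers, not fastai functions:

```python
import torch

def multilabel_counts(logits: torch.Tensor, targets: torch.Tensor, thresh: float = 0.5):
    """Hypothetical helper: (true positives, predicted positives, actual positives) for one batch."""
    preds = (torch.sigmoid(logits) > thresh).float()  # 0/1 predictions on the same device
    targets = targets.float()
    tp = (preds * targets).sum().item()
    return tp, preds.sum().item(), targets.sum().item()

def micro_f1(tp: float, pred_pos: float, actual_pos: float, eps: float = 1e-9) -> float:
    """Hypothetical helper: micro-averaged F1 from counts summed over all batches."""
    precision = tp / (pred_pos + eps)
    recall = tp / (actual_pos + eps)
    return 2 * precision * recall / (precision + recall + eps)

# Usage: accumulate counts over two fake validation batches, then compute F1 once.
batches = [
    (torch.tensor([[2.0, -1.0, 0.5]]), torch.tensor([[1, 0, 0]])),
    (torch.tensor([[-3.0, 1.5, 1.0]]), torch.tensor([[0, 1, 1]])),
]
totals = [0.0, 0.0, 0.0]
for logits, targets in batches:
    for i, count in enumerate(multilabel_counts(logits, targets)):
        totals[i] += count
f1 = micro_f1(*totals)
```

In the callback, the counts could be summed in `on_batch_end` and `micro_f1` computed once per epoch, for example in `on_epoch_end`.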
My problem is that I am not sure whether the model applies the same activation (a sigmoid followed by a 0.5 threshold) when I call it via:
```python
learn.predict("example")
```
Is there a way to get the predictions (i.e., tensors containing only ones and zeros) in the callback, instead of the raw outputs (i.e., tensors of arbitrary real values)?
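As far as I understand fastai v1 (worth confirming in the source of `Learner.predict` and `MultiCategoryList.analyze_pred`), `last_output` in a callback holds raw logits, because the default multi-label loss applies the sigmoid internally, while `learn.predict` applies the sigmoid itself and then thresholds the probabilities with `>=` against a threshold that defaults to 0.5. If that assumption holds, a sketch like the following reproduces the same 0/1 tensor inside the callback; `logits_to_predictions` is a hypothetical name:

```python
import torch

def logits_to_predictions(logits: torch.Tensor, thresh: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: raw logits -> 0/1 float tensor.

    Sigmoid, then threshold with >=; this is assumed (not verified here) to
    match how fastai v1 post-processes multi-label predictions.
    """
    probs = torch.sigmoid(logits)
    return (probs >= thresh).float()

logits = torch.tensor([[3.0, -2.0, 0.4], [-0.1, 0.7, -5.0]])
preds = logits_to_predictions(logits)

# The sigmoid is monotonic and sigmoid(0) == 0.5, so thresholding probabilities
# at 0.5 is mathematically the same as thresholding logits at 0 (float rounding
# very close to zero aside).
same_as_logit_threshold = torch.equal(preds, (logits >= 0).float())
```

Compared with `torch.gt` in the callback above, `>=` only gives a different result for a probability of exactly 0.5.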
Many thanks!