Accessing raw logits from learn.predict()

I have trained a 5-class classifier. I want to take the raw logits from learn.predict() and feed them into a regression model. How do I access the raw logits?



Will it give logits during prediction?

Also, if it is an image classifier, you need to apply the same transformations and normalization. If you want to make a prediction for a single image, use:

a = ItemBase(input_tensor) # 3d tensor (ch, height, width)

This will give raw predictions. ItemBase is just a wrapper that will help you with transformations and normalization.

learn.predict(a)[2] gives the class probabilities, which are computed after the softmax. I want the logits, which come before the softmax.
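The difference is easy to check in plain PyTorch: the probabilities are softmax(logits), and taking their log only gives back the logits minus a per-example constant (their log-sum-exp). The logit values below are made up for a 5-class problem; this is a sketch of the math, not of fastai internals.

```python
import torch

# Made-up logits for one example of a 5-class problem (hypothetical values)
logits = torch.tensor([[2.0, -1.0, 0.5, 0.0, 3.0]])

# Class probabilities, analogous to what learn.predict(...)[2] returns
probs = torch.softmax(logits, dim=1)

# log(softmax(z)) = z - logsumexp(z): the logits shifted by one constant per row
log_probs = torch.log(probs)
shift = logits - log_probs  # every column holds the same value, logsumexp(z)

print(probs.sum(dim=1))
print(shift)
```

So probabilities alone lose the per-example offset; if the regression model needs the actual logits, they have to be taken from the model before the softmax.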

Then I guess this is the only option:

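import torch

input_tensor = ...  # normalized, with a batch dimension: (bs, ch, height, width)
learn.model.eval()  # disable dropout and use BatchNorm running statistics
with torch.no_grad():
    logits = learn.model(input_tensor)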

Make sure you don’t have any softmax layer in the model.
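A minimal sketch of the whole path (classifier logits into a regression model), assuming the model is an nn.Sequential. The tiny classifier below is a hypothetical stand-in for learn.model, strip_softmax is an illustrative helper (not a fastai function), and the regression model is a placeholder nn.Linear:

```python
import torch
from torch import nn

def strip_softmax(model: nn.Module) -> nn.Module:
    """Drop a trailing Softmax/LogSoftmax from an nn.Sequential, if there is one."""
    if isinstance(model, nn.Sequential) and isinstance(model[-1], (nn.Softmax, nn.LogSoftmax)):
        return model[:-1]  # slicing keeps (shares) the trained layers
    return model

# Hypothetical stand-in for learn.model: 10 features -> 5 classes,
# deliberately ending in a softmax so the helper has something to remove
classifier = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Dropout(0.5),
                           nn.Linear(32, 5), nn.Softmax(dim=1))
logit_model = strip_softmax(classifier)

# Hypothetical downstream regression model that takes the 5 logits as input
regressor = nn.Linear(5, 1)

x = torch.randn(4, 10)           # a batch of 4 already-normalized inputs
logit_model.eval()               # dropout off, as during prediction
with torch.no_grad():
    logits = logit_model(x)      # (4, 5) raw scores, no softmax applied
    target = regressor(logits)   # (4, 1) regression output
```

Because the sliced model shares its layers with the original, no weights are copied; applying a softmax to these logits reproduces the original model's probabilities.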

Thanks @rohit_gr
Just to add one small thing: I recommend creating the input tensor with the learner's own data object, to make sure all the transformations are applied correctly.
Here is an example for a text classifier:

s = "Example string to classify"

# create a batch with one x item
batch =[0]

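learn.model.eval()  # dropout off, as during prediction
with torch.no_grad():
    # AWD-LSTM text classifiers return a tuple; the logits are its first element
    logits = learn.model(batch)[0]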

Unfortunately, the one_item method and the data attribute are not available in fastai v2. I was able to get the logits by replacing the second line of code with:

batch = learn.dls.test_dl([s]).one_batch()[0]
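one_batch() only returns the first batch. For many strings, one option is to loop over every batch of the test DataLoader and concatenate the outputs. Below, collect_logits is a hypothetical helper (not part of fastai); it accepts any iterable of tuples whose first element is the model input, and it takes the first element when the model returns a tuple. It is demonstrated on a toy linear model:

```python
import torch
from torch import nn

def collect_logits(model: nn.Module, batches) -> torch.Tensor:
    """Run `model` on every batch and stack the raw outputs (no softmax)."""
    model.eval()
    outs = []
    with torch.no_grad():
        for b in batches:
            out = model(b[0])           # first tuple element is the input
            if isinstance(out, tuple):  # some classifiers return (logits, ...)
                out = out[0]
            outs.append(out)
    return torch.cat(outs, dim=0)

# Toy stand-in: three batches of 8-feature inputs for a 5-class linear model
toy_model = nn.Linear(8, 5)
toy_batches = [(torch.randn(4, 8),), (torch.randn(4, 8),), (torch.randn(2, 8),)]
logits = collect_logits(toy_model, toy_batches)
print(logits.shape)  # torch.Size([10, 5])
```

With fastai v2 this would be called roughly as collect_logits(learn.model, learn.dls.test_dl(texts)). Text DataLoaders in v2 may sort items by length, so the rows might not follow the order of texts; this sketch does not undo that reordering.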