I have trained a 5-class classifier. I want to take the raw logits from learn.predict() and feed them into a regression model. How do I access the raw logits?
learn.model(input_tensor)?
Will that give logits during prediction?
Also, if it is an image classifier, you need to apply the same transformations and normalization that were used during training. To make a prediction for a single image, use:
a = ItemBase(input_tensor) # 3d tensor (ch, height, width)
learn.predict(a)[2]
This will give raw predictions. ItemBase is just a wrapper that will help you with transformations and normalization.
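For reference, the normalization step that such a wrapper takes care of can be sketched in plain PyTorch. This is a minimal sketch, assuming a float image tensor scaled to [0, 1] and the ImageNet channel statistics commonly used with pretrained models; the constants and the `prepare_image` helper are illustrative assumptions, so substitute the statistics your own data pipeline used. Resizing and cropping transforms are not covered.

```python
import torch

# Assumption: ImageNet channel statistics, as commonly used with pretrained
# vision models. Replace them with the statistics your training data used.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])

def prepare_image(img: torch.Tensor) -> torch.Tensor:
    """Normalize a (ch, height, width) tensor in [0, 1] and add a batch dimension."""
    mean = IMAGENET_MEAN.view(-1, 1, 1)  # reshape so it broadcasts over height and width
    std = IMAGENET_STD.view(-1, 1, 1)
    normalized = (img - mean) / std
    return normalized.unsqueeze(0)  # -> (1, ch, height, width)

img = torch.rand(3, 224, 224)  # stand-in for a real image tensor
batch = prepare_image(img)
print(batch.shape)  # torch.Size([1, 3, 224, 224])
```

The batch dimension matters because the model expects a batch of inputs, even when that batch holds a single image.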
learn.predict(a)[2] gives the class probabilities, which are computed after the softmax. I want the logits, i.e. the values before the softmax.
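The difference between the two can be seen in plain PyTorch: softmax maps logits to probabilities, and since log(softmax(z)) = z - logsumexp(z), taking the log of the probabilities gives back the logits only up to a constant shift per sample. A small sketch, using randomly generated stand-in logits rather than real model outputs:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(2, 5)            # stand-in raw outputs: 2 samples, 5 classes
probs = torch.softmax(logits, dim=1)  # probabilities, as returned after the softmax

# log(softmax(z)) = z - logsumexp(z): the log-probabilities differ from the
# logits by one value per sample, namely that sample's logsumexp.
recovered = torch.log(probs)
shift = logits - recovered
print(shift)  # every entry in a row holds the same value
```

So the probabilities are not interchangeable with the logits; to get the exact logits you have to call the model directly.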
Then I guess this is the only option:
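import torch

learn.model.eval()  # inference mode: dropout off, batch-norm uses running statistics
input_tensor = ...  # normalized, with a batch dimension: (1, ch, height, width)
with torch.no_grad():
    logits = learn.model(input_tensor)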
Make sure you don’t have any softmax layer in the model.
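A minimal, self-contained sketch of this inference pattern, using a tiny hypothetical `nn.Sequential` as a stand-in for `learn.model` and a hypothetical `raw_logits` helper. The softmax check only inspects the last registered submodule, so it would not catch a softmax applied inside a custom `forward` method.

```python
import torch
from torch import nn

# Hypothetical stand-in for learn.model: a tiny 5-class classifier.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))

def raw_logits(model: nn.Module, batch: torch.Tensor) -> torch.Tensor:
    """Return the model's raw (pre-softmax) outputs for a batch."""
    # Only checks the last registered submodule, not softmax calls inside forward().
    last_module = list(model.modules())[-1]
    if isinstance(last_module, (nn.Softmax, nn.LogSoftmax)):
        raise ValueError("model ends in a softmax layer; outputs are not raw logits")
    model.eval()           # dropout off, batch-norm uses running statistics
    with torch.no_grad():  # no autograd graph is needed for inference
        return model(batch)

batch = torch.rand(1, 3, 8, 8)  # one synthetic input with a batch dimension
logits = raw_logits(model, batch)
print(logits.shape)  # torch.Size([1, 5])
```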
Thanks @rohit_gr
Just to add one small thing: I recommend creating the input tensor with learn.data.one_item to make sure all the transformations are applied correctly.
Here is an example for a text classifier:
import torch

s = "Example string to classify"
# create a batch with one x item
batch = learn.data.one_item(s)[0]
learn.model.eval()
with torch.no_grad():
    logits = learn.model(batch)  # fastai text models may return a tuple; if so, use logits[0]
Unfortunately, the one_item method and the data attribute are not available in v2. I was able to get the logits by replacing the batch = learn.data.one_item(s)[0] line with:
batch = learn.dls.test_dl([s]).one_batch()[0]
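Once the logits can be extracted, they can serve as features for the regression model mentioned in the original question. A minimal sketch, assuming a hypothetical toy classifier in place of `learn.model`, synthetic inputs, and synthetic targets built as an exact linear function of the logits (so the fit should recover the chosen coefficients); ordinary least squares via `torch.linalg.lstsq` stands in for whatever regression model you actually use.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for learn.model: a 5-class classifier.
classifier = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))
classifier.eval()

inputs = torch.rand(20, 3, 8, 8)  # synthetic batch of 20 inputs
with torch.no_grad():
    features = classifier(inputs)  # (20, 5) raw logits used as regression features

# Append a column of ones so the regression also fits an intercept.
X = torch.cat([features, torch.ones(len(features), 1)], dim=1).double()  # (20, 6)

# Synthetic targets: an exact linear function of the logits (last entry = intercept).
true_coef = torch.tensor([[1.0], [-2.0], [0.5], [3.0], [0.0], [4.0]], dtype=torch.float64)
targets = X @ true_coef  # (20, 1)

# Ordinary least squares fit of targets on the logits.
solution = torch.linalg.lstsq(X, targets).solution  # (6, 1)
weights, intercept = solution[:-1], solution[-1]
```

The cast to float64 is a design choice for this sketch: it keeps the least-squares solve numerically tight so the recovered coefficients match the synthetic ones closely.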