Getting predictions with get_preds() for test sets


I have posted this elsewhere on the forum. Perhaps the answer is too obvious, but I can't find any answer to this trivial question.

If I run val_preds, val_targets = learn.get_preds(), I get the predictions and targets for the validation set. How do I run it for the test set?

So far, all the docs show learn.get_preds(is_test=True), but is_test has been changed to something else. Please advise!

Cheers, Hud


I don't know of a one-liner, but this works for me:

preds = []
for i in range(len(data.test_ds)):
    # predict one test image at a time and keep each result
    p = learn.predict(data.test_ds.x[i])
    preds.append(p)

Hello, with v1 you can pass in the DatasetType to get_preds().

preds, y = learn.get_preds(DatasetType.Test)


Thanks, this sort of worked, but now new problems arise.

I followed the vision solutions in the inference tutorials and can predict different classes of boats from single images in the test set. However, predicting on the images in a test_set folder returns only one class.

Have a look at the inference section in my gist notebook.

Has anyone had this problem?

I have the same problem. I’m just a beginner so I don’t know why, but my workaround is:

preds, y, losses = learn.get_preds(ds_type=DatasetType.Test, with_loss=True)
y = torch.argmax(preds, dim=1)

Thanks for this workaround! Works great for me too.

I also am seeing the y predictions returning all zeros, in my case using a text learner (in fastai v1.0.48):

preds, y = learn.get_preds(ds_type=DatasetType.Test)
print(y)

tensor([0, 0, 0, ..., 0, 0, 0])

This happens even though taking the argmax shows that the highest predictions are not all at label position 0.

@sgugger Is this a known issue?

That's not an issue; it's because the test set in fastai is unlabelled, so the y (your targets) are all set to zero.


I see. That explains it. I thought y was the predicted y, but it’s actually the true y (which we don’t have in a test set). Thanks.
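To make the distinction concrete: get_preds returns the model's probabilities first and the targets second, and for an unlabelled test set the targets are only placeholders. A minimal, framework-free sketch, with made-up numpy arrays standing in for what get_preds returns, showing that the predicted class has to be derived from the probabilities with argmax:

```python
import numpy as np

# Made-up stand-in for the first value get_preds returns:
# one row of class probabilities per test item.
preds = np.array([
    [0.1, 0.7, 0.2],
    [0.8, 0.1, 0.1],
    [0.2, 0.3, 0.5],
])
# Made-up stand-in for the second value: placeholder targets,
# all zero because the test set has no labels.
y = np.zeros(len(preds), dtype=int)

# The predicted class index is the position of the highest probability in each row.
predicted = preds.argmax(axis=1)
print(y)          # [0 0 0]
print(predicted)  # [1 0 2]
```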


Are there plans to change this? It always feels kind of hacky to me to artificially add dummy (empty) labels to the test data loader just to make it work.

Not in the medium term, no; we have a lot on our plates with the ongoing course right now.

I am facing the same issue: I am getting image class 0 using learn.get_preds(), but my classes range from 1 to 5. Please suggest a solution for the incorrect class predictions.

@sariabod For test-set prediction, it was suggested on this forum that we create a new DataBunch, passing the test folder in place of the valid folder in the argument, and then make predictions on that data using your code. But it's still not working. Predicting class labels for the test set is not supported directly in fastai.

Hi Abhi,
Try this :

predictions, *_ = learner.get_preds(DatasetType.Test)
labels = np.argmax(predictions, 1)

You can also check this thread for additional info.

Welcome to the forums!

I tried your code, but I am still getting labels that are not in my class list.

@abhi891, you’re only getting zeros, right?
I came across the same problem with tabular.

Give this a try. This worked for me.

preds,_ = learn.get_preds(ds_type=DatasetType.Test)
result = preds.numpy()[:, 0]
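Note that preds.numpy()[:, 0] selects the first column, i.e. the predicted probability of the first class for every row, rather than a class label. A small sketch with made-up probabilities for a two-class problem, contrasting that column selection with argmax:

```python
import numpy as np

# Made-up probabilities for a two-class problem:
# column 0 is the first class, column 1 the second.
preds = np.array([
    [0.9, 0.1],
    [0.3, 0.7],
    [0.6, 0.4],
])

prob_class0 = preds[:, 0]      # probability of the first class, one value per row
labels = preds.argmax(axis=1)  # index of the most likely class per row

print(prob_class0)  # [0.9 0.3 0.6]
print(labels)       # [0 1 0]
```

Which of the two you want depends on whether your submission expects probabilities or class labels.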

For my image classification problem with 5 classes, the code below worked, with the help of your link:
preds,_ = learn.get_preds(ds_type=DatasetType.Test)
labels = np.argmax(preds, 1)
test_predictions_direct = [data.classes[int(x)] for x in labels]

Thanks a lot
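The last line of that snippet is what resolves the "labels not in my class list" confusion: argmax returns a position from 0 to n-1 in data.classes, not the class name itself, so with classes named 1 to 5 a raw index of 0 means class "1". A framework-free sketch, where the hypothetical `classes` list stands in for data.classes and the probabilities are made up:

```python
import numpy as np

# Hypothetical stand-in for data.classes: the class names,
# in the same order as the model's output columns.
classes = ['1', '2', '3', '4', '5']

# Made-up probabilities for three test images (one column per class).
preds = np.array([
    [0.70, 0.10, 0.10, 0.05, 0.05],
    [0.05, 0.05, 0.10, 0.20, 0.60],
    [0.10, 0.60, 0.10, 0.10, 0.10],
])

indices = np.argmax(preds, 1)               # positions into classes: 0, 4, 1
names = [classes[int(i)] for i in indices]  # map positions to the actual class names
print(names)  # ['1', '5', '2']
```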


Hi @abhi891,

I downloaded your code from the AV competition and ran the code above, but I am stuck at these lines:

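all_test_preds = []
for i in range(1, 3+1):
    # load each saved checkpoint (stage-1, stage-2, stage-3) in turn
    learn.load('stage-' + str(i))
    probs, y = learn.get_preds(ds_type=DatasetType.Test)
    # collect this checkpoint's test-set probabilities
    all_test_preds.append(probs.numpy())

# average the probabilities over the checkpoints, then pick the most likely class
final = [data.classes[i] for i in np.argmax(np.mean(all_test_preds, 0), axis=1)]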

This is the error I am getting:

Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/", line 240, in _feed
  File "/usr/lib/python3.6/multiprocessing/", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.6/multiprocessing/", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.6/multiprocessing/", line 368, in _send
    n = write(self._handle, buf)
OSError: [Errno 9] Bad file descriptor
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/", line 230, in _feed
  File "/usr/lib/python3.6/multiprocessing/", line 177, in close
  File "/usr/lib/python3.6/multiprocessing/", line 361, in _close
OSError: [Errno 9] Bad file descriptor…

Any idea why this happens? Does this take a lot of time to run?
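Separately from the multiprocessing error, the intent of that snippet is to average the test-set probabilities from several saved checkpoints and take the argmax of the average. A framework-free sketch of just the averaging step, with made-up arrays standing in for each checkpoint's get_preds output and a hypothetical `classes` list standing in for data.classes:

```python
import numpy as np

classes = ['cat', 'dog']  # hypothetical stand-in for data.classes

# Made-up test-set probabilities from three checkpoints (stage-1 to stage-3),
# each of shape (n_test_items, n_classes).
all_test_preds = [
    np.array([[0.6, 0.4], [0.2, 0.8]]),
    np.array([[0.4, 0.6], [0.3, 0.7]]),
    np.array([[0.7, 0.3], [0.4, 0.6]]),
]

# np.mean over axis 0 treats the list as (n_models, n_items, n_classes)
# and averages over the models.
mean_probs = np.mean(all_test_preds, 0)
final = [classes[i] for i in np.argmax(mean_probs, axis=1)]
print(final)  # ['cat', 'dog']
```

Averaging probabilities before the argmax, rather than voting on each checkpoint's labels, lets a confident checkpoint outweigh an uncertain one.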

I tried this and I'm getting correct predictions, but the order is somehow different. How do I fix the order of the predictions?
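One likely cause (an assumption, since the code is not shown): get_preds returns rows in the order of the items in the test set, which need not match the row order your submission file expects. A sketch that pairs each prediction with its file name and then re-orders by the submission's own list. `test_items` is a hypothetical stand-in for the test file paths (in fastai v1 these can usually be read from data.test_ds.items, but check this against your version), and `submission_order` is a hypothetical stand-in for the IDs in the sample submission:

```python
from pathlib import Path

# Hypothetical test-set file paths, in the order the test set was built
# (assumed to be the order get_preds returns its rows in).
test_items = [Path('test/img_3.jpg'), Path('test/img_1.jpg'), Path('test/img_2.jpg')]
# Made-up predicted labels, one per test item, in that same order.
predicted_labels = ['boat', 'ship', 'boat']

# Pair each prediction with its file name so the original order no longer matters.
pred_by_name = {p.name: label for p, label in zip(test_items, predicted_labels)}

# Hypothetical order required by the submission file.
submission_order = ['img_1.jpg', 'img_2.jpg', 'img_3.jpg']
ordered = [(name, pred_by_name[name]) for name in submission_order]
print(ordered)  # [('img_1.jpg', 'ship'), ('img_2.jpg', 'boat'), ('img_3.jpg', 'boat')]
```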