How to get predictions for all images in a folder?

How can we get predictions for all images in a folder efficiently?

I have trained a model on MNIST and it is working pretty well.
I can use learn.predict to get a prediction on a single image.

I tried looping through the images in the folder and running learn.predict but it was way too slow:

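import os
from tqdm import tqdm

files = sorted(os.listdir("mnist_data/test"))
preds = []
for file in tqdm(files):
    # learn.predict returns (decoded label, label index, class probabilities)
    number, n_th, probs = learn.predict(f"mnist_data/test/{file}")
    preds.append(number)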

So instead I decided to gather the images into a numpy array:

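import os
import numpy as np
from PIL import Image
from tqdm import tqdm

files = sorted(os.listdir("mnist_data/test"))
imgs = []
for file in tqdm(files):
    with Image.open(f"mnist_data/test/{file}") as img:
        imgs.append(np.array(img))

imgs = np.array(imgs)  # shape: (28000, 28, 28)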

The next step would be to make batches from imgs and get predictions for them, but I am not sure how to get the predictions for a batch while applying the same normalization stats, transforms, etc.

learn.predict is definitely not the right function for this according to the docs.
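For illustration only, here is a minimal NumPy sketch of the batching step described above: slice the (N, 28, 28) array into batches, scale the pixels to [0, 1], and normalize with the training statistics. The mean/std values (0.1307, 0.3081) are the commonly quoted MNIST statistics and are only placeholders here; your learner may have been trained with different normalization (or none), which is exactly why reusing the training pipeline is preferable. The helper name make_batches is hypothetical.

```python
import numpy as np

def make_batches(imgs, batch_size, mean, std):
    """Yield normalized float32 batches of shape (bs, 1, H, W) from an (N, H, W) uint8 array."""
    for start in range(0, len(imgs), batch_size):
        batch = imgs[start:start + batch_size].astype(np.float32) / 255.0  # scale pixels to [0, 1]
        batch = (batch - mean) / std  # normalize with the *training* statistics (placeholders here)
        yield batch[:, None, :, :]  # add a channel axis: (bs, 1, H, W)

# Placeholder data standing in for the (28000, 28, 28) array built above
imgs = np.zeros((10, 28, 28), dtype=np.uint8)
batches = list(make_batches(imgs, batch_size=4, mean=0.1307, std=0.3081))
print([b.shape for b in batches])  # → [(4, 1, 28, 28), (4, 1, 28, 28), (2, 1, 28, 28)]
```

The last batch is simply smaller; each batch would then be fed to the model in one call instead of one call per image.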

get_preds I think is what you’re wanting :slight_smile:


Get the predictions and targets on the ds_idx-th dbunchset or dl, optionally with_input and with_loss.

Then the question becomes:
How do we make a DataLoader with the same stats as the one we trained on?

There should be an easy way to do this…

I make the DataLoaders like so:

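from fastai.vision.all import *

mnist = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    # (the remaining arguments, e.g. get_items/get_y, were lost from the original post)
)

# (the start of this line was lost from the original post; Zoom and Rotate are batch augmentations)
mnist = mnist.new(batch_tfms=[Zoom(), Rotate()])
dls = mnist.dataloaders("mnist_data/train")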

We have a test_dl method on DataLoaders where you pass in the items you want to use as your test set.

For instance, here you would do:

test_dl = dls.test_dl(get_image_files('mnist_data/test'))
(where dls is your original DataLoaders)

and you can then just call learn.get_preds(dl=test_dl).


That is exactly what I was looking for. Thanks so much :smiley:

Here is my solution:

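import os
from types import SimpleNamespace
from fastai.vision.all import *

def load_imgs(path):
    image_files = []
    for file in os.listdir(path):
        if file.endswith('.jpg'):
            image_files.append(os.path.join(path, file))
    return image_files

test_path_lily = 'Data/test/1'

test_img_lily = load_imgs(test_path_lily)

uploader_lily = SimpleNamespace(data=test_img_lily)

def do_test(model, uploader, num):
    # one learn.predict call per image: simple, but slow on large folders
    for i in range(num):
        img = PILImage.create(uploader.data[i])
        predict = model.predict(img)
        # predict = (decoded label, label index, class probabilities)
        print(f"loop {i}, {uploader.data[i]}, {predict[0]}, {predict[2][0]}, {predict[2][1]}")

do_test(learn, uploader_lily, len(test_img_lily))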
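A slightly more robust sketch of the file-listing helper, using only the standard library: it matches extensions case-insensitively, skips subdirectories, and returns sorted paths so the predictions line up with a stable file order. The extension set is an assumption about the data. fastai's own get_image_files does a recursive version of this, and its result can be passed straight to dls.test_dl.

```python
import tempfile
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}  # assumed set of extensions; adjust to your data

def load_imgs(path):
    """Return sorted paths of image files directly inside `path` (not recursive)."""
    return sorted(
        p for p in Path(path).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTS
    )

# Quick check on a throwaway folder
with tempfile.TemporaryDirectory() as d:
    for name in ["b.png", "a.JPG", "notes.txt"]:
        (Path(d) / name).touch()
    (Path(d) / "sub.jpg").mkdir()  # a directory, so it must be skipped
    found = [p.name for p in load_imgs(d)]
print(found)  # → ['a.JPG', 'b.png']
```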

This is so elegant!

The problem with this is that it takes more than a day to run on my data, whereas the approach @muellerzr pointed out takes less than a minute.
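Most of that gap comes from calling the model once per image: learn.predict builds a one-item DataLoader and runs the whole pipeline for every single file, while get_preds processes whole batches. The toy sketch below uses a hypothetical NumPy "model" (a single linear layer, not fastai) to show that a per-item loop and a single batched call produce the same outputs; only the number of calls differs.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(784, 10)).astype(np.float32)  # toy "model": one linear layer
b = np.zeros(10, dtype=np.float32)

def model(x):
    """Map a (bs, 784) array of flattened images to (bs, 10) logits."""
    return x @ W + b

x = rng.random((256, 784), dtype=np.float32)  # 256 flattened 28x28 "images"

# Per-item: 256 separate calls, each on a batch of one
per_item = np.concatenate([model(x[i:i + 1]) for i in range(len(x))])
# Batched: a single call on the whole batch
batched = model(x)

print(per_item.shape, np.allclose(per_item, batched, rtol=1e-4, atol=1e-3))
```

On a real network the batched call amortizes data loading, transforms and GPU launch overhead across the batch, which is where the day-versus-minute difference comes from.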

Yes, my way takes a very long time.

learn.validate() outputs a list with two elements. I have not yet understood what those two numbers mean or when to use the function.
Can anyone explain this in more detail? :slight_smile:

In the docs:

Return the calculated loss and the metrics of the current model on the given data loader dl. The default data loader dl is the validation dataloader.

So the first value is the loss and the remaining values are your metrics (here there is only one).
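To make the returned list easier to read, you can pair it with names. A minimal sketch, assuming the layout the docs describe (loss first, then one value per metric, in the order of learn.metrics); the helper label_validate is hypothetical and the numbers in the example are made up.

```python
def label_validate(results, metric_names):
    """Pair the values returned by learn.validate() with names.

    Assumes the loss comes first, followed by one value per metric
    in the same order as learn.metrics.
    """
    names = ["valid_loss"] + list(metric_names)
    if len(names) != len(results):
        raise ValueError(f"expected {len(names)} values, got {len(results)}")
    return dict(zip(names, results))

# Made-up numbers for a learner trained with metrics=accuracy
print(label_validate([0.0512, 0.9843], ["accuracy"]))
```

With a real learner this would be something like label_validate(learn.validate(), [m.name for m in learn.metrics]), assuming each metric object exposes a name attribute.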

You can check which metrics you have with:

learn.metrics

Ah, I was looking at the new docs, which didn’t give me much information, thanks!

The only issue for me with learn.get_preds(dl=test_dl) is that it crops the images to the same square format as the training and validation sets. That is no big deal for classification, but it is not good for segmentation, where you want the prediction output at the full image size.
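To see how much a square crop throws away, here is a small sketch that computes the largest centered square in a width x height image (a center crop is, I believe, what fastai's default crop-style Resize does at inference time). Anything outside this box is missing from the prediction, which is why segmentation masks come back cropped. The function name is hypothetical.

```python
def center_crop_box(width, height):
    """Return (left, top, right, bottom) of the largest centered square.

    Only min(width, height) pixels survive along the longer side.
    """
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return left, top, left + side, top + side

print(center_crop_box(640, 480))  # → (80, 0, 560, 480)
```

For a 640x480 image, 80 pixel columns are lost on each side, so a quarter of the image never reaches the model.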

This was very helpful, thanks.

Do you know how to get around this for segmentation?

You can do the following to change the resize transform for prediction (thanks to @muellerzr for helping me figure this out):


from fastai.vision.all import *

test_dl = dls.test_dl(get_image_files('mnist_data/test'))
# change the transforms; height and width are placeholders for the output size you want
test_dl.after_item = Pipeline([Resize((height, width)), ToTensor()])
test_dl.after_batch = Pipeline([IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)])

The output of get_preds is a tensor of probabilities. Is there a built-in function that converts them to classes, or does it need to be implemented by hand? learn.predict on a single sample gives the predicted class, but learn.get_preds on a test set doesn't, which seems strange. Is something missing?
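One way to get classes is to take the argmax over the class dimension and look each index up in the vocabulary. Below is a minimal NumPy sketch; with fastai you would use the dls.vocab of your DataLoaders and, on the torch tensor returned by get_preds, preds.argmax(dim=1). The helper name and the example probabilities are made up. As far as I know, get_preds also accepts with_decoded=True, which adds the decoded predictions to the returned tuple; check the docs for your fastai version.

```python
import numpy as np

def probs_to_labels(probs, vocab):
    """Turn an (N, n_classes) array of probabilities into class labels."""
    idxs = np.asarray(probs).argmax(axis=1)  # index of the most likely class in each row
    return [vocab[i] for i in idxs]

vocab = [str(d) for d in range(10)]  # stand-in for dls.vocab on MNIST
probs = np.array([[0.1, 0.7, 0.2] + [0.0] * 7,
                  [0.0] * 9 + [1.0]])
print(probs_to_labels(probs, vocab))  # → ['1', '9']
```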
