How do you get predictions on labeled test data?

This is definitely a rookie question. I've tried to find the answer but have been unsuccessful, hence the post. Thanks in advance for any help!

I am on lesson 1, trying out my own dataset, which is ImageNet-style labeled data. I have training data and test data but no validation data, so I used the valid_pct parameter of ImageDataBunch.from_folder().

As it happens, my test data is also in labeled folders. I have a decent model that I would now like to try on the test data to see how it does. But I cannot find the API calls to make this happen.

FYI:
My data is organized in test and train folders, each with its own label subfolders. I used ImageDataBunch.from_folder() with the parameter test="test" so that it would load the test data.

Issues:

  1. I see that the ImageDataBunch object does have a test_ds property (which I presume is the loaded test data), but this object only has the images (x) loaded; the labels (y) are empty. I need these labels to measure the accuracy of the model.

  2. When I try learn.get_preds(ds_type="test"), I thought I would get back predictions for the test data, but based on the number of results returned, it is returning predictions for the training data.

Once again, assistance would be most appreciated.

  1. Since your test data is organised like validation data, you can load it as the validation set and use it with the trained model to get predictions, labels included. I do just that in this notebook. For example,

      from fastai.vision import *  # fastai v1: ImageDataBunch, imagenet_stats
      # Load the labeled test folder as the validation set
      data = ImageDataBunch.from_folder(path, valid='test', ds_tfms=tfms, bs=64, size=256)
      data.normalize(imagenet_stats)
      learn.data = data  # the trained learner now validates against the test folder

  2. By default get_preds() gives predictions for the validation data. To predict on test data (not test data pretending to be validation data, as in 1.) you use get_preds(ds_type=DatasetType.Test), passing the DatasetType member rather than the string "test".
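Once the labeled test folder is loaded as the validation set, get_preds() hands back a pair: the per-class probabilities and the target labels. Below is a minimal sketch of turning such a pair into an accuracy figure. It uses plain Python lists with made-up values in place of the real tensors, and `accuracy` is an illustrative helper, not a fastai function.

```python
def accuracy(probs, targets):
    """Fraction of rows whose highest-scoring class equals the target label."""
    if len(probs) != len(targets):
        raise ValueError("probs and targets must have the same length")
    if not targets:
        raise ValueError("need at least one prediction")
    correct = 0
    for row, target in zip(probs, targets):
        # Index of the largest probability in this row (the predicted class)
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += int(predicted == target)
    return correct / len(targets)


# Made-up values standing in for the (probabilities, targets) pair
probs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]]
targets = [0, 1, 1, 1]
print(accuracy(probs, targets))  # 0.75
```

With the actual PyTorch tensors, say `preds, y = learn.get_preds()`, the same quantity should be `(preds.argmax(dim=1) == y).float().mean()`.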

Thank you very much - your reply and example were very helpful in understanding the APIs. I was able to get your approach #1 above working. I am still not totally sure what I am doing wrong in my invocation of get_preds, but since you have helped me solve my original problem, I do not need to figure that out right now. Thanks again!
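On the get_preds question left open here, one plausible explanation is that a string such as "test" never compares equal to an enum member, so a lookup built from equality checks falls through to whatever its fallback is. The sketch below only illustrates that mechanism. It assumes DatasetType behaves like a plain Python Enum; the `DatasetType` defined here is a stand-in, and `pick_split` is a hypothetical selector, not fastai code.

```python
from enum import Enum

# Stand-in for fastai's DatasetType; the real member list may differ.
DatasetType = Enum("DatasetType", "Train Valid Test")


def pick_split(ds_type):
    """Hypothetical selector: equality checks with a fallback branch."""
    if ds_type == DatasetType.Valid:
        return "valid"
    if ds_type == DatasetType.Test:
        return "test"
    return "train"  # reached whenever nothing above matched


print(DatasetType.Test == "test")    # False: a str is never equal to an Enum member
print(pick_split("test"))            # train
print(pick_split(DatasetType.Test))  # test
```

If the library's lookup works along these lines, passing ds_type=DatasetType.Test instead of the string would select the test set as intended.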
