I’m posting some information about making predictions on a test set, which I gathered from digging a bit into the fastai library. I think it could be useful now that Jeremy has started talking about Kaggle competitions.
Nothing fancy and nothing new, but maybe helpful for somebody.
Here we get an ImageClassifierData object by calling the from_paths method:
data = ImageClassifierData.from_paths(PATH, tfms=tfms_from_model(arch, sz))
By looking at the method’s arguments (Shift+Tab in Jupyter), we see there is a test_name argument, which is the name of the folder that contains the test images. Let’s place the test folder in PATH; then we can call from_paths like this:
data = ImageClassifierData.from_paths(PATH, tfms=tfms_from_model(arch, sz), test_name='test')
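For reference, here is a minimal sketch of the kind of folder layout this implies, assuming the default folder names train and valid (one subfolder per class) plus the unlabeled test folder passed as test_name='test'. The helpers make_layout and tree and the class names cats/dogs are hypothetical, for illustration only; the script just creates empty placeholder files in a temporary directory and prints the resulting tree.

```python
import tempfile
from pathlib import Path

def make_layout(root: Path, classes, n_per_class=2, n_test=3):
    """Create empty placeholder .jpg files in a train/valid/test layout (hypothetical helper)."""
    for split in ("train", "valid"):
        for cls in classes:
            d = root / split / cls              # labeled splits: one subfolder per class
            d.mkdir(parents=True, exist_ok=True)
            for i in range(n_per_class):
                (d / f"{cls}.{i}.jpg").touch()
    test_dir = root / "test"                    # test split: unlabeled, no class subfolders
    test_dir.mkdir(parents=True, exist_ok=True)
    for i in range(n_test):
        (test_dir / f"{i}.jpg").touch()

def tree(root: Path):
    """Return every file path relative to root, sorted."""
    return sorted(p.relative_to(root).as_posix() for p in root.rglob("*") if p.is_file())

with tempfile.TemporaryDirectory() as tmp:
    PATH = Path(tmp)
    make_layout(PATH, ["cats", "dogs"])
    for f in tree(PATH):
        print(f)
```

The key point is that the test folder sits directly under PATH, next to train and valid.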
Now we can explore the data object further and get a list of its valid attributes by using the following function:
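Python’s built-in dir() is one general way to list an object’s attribute names. Below is a minimal sketch on a hypothetical stand-in class (FakeData is not the real ImageClassifierData), filtering out the underscore-prefixed names so that only the "public" attributes and methods remain.

```python
class FakeData:
    """Hypothetical stand-in for an ImageClassifierData object (not the real fastai class)."""
    def __init__(self):
        self.classes = ["cats", "dogs"]
        self.trn_ds = None
        self.val_ds = None
        self.test_ds = None
        self._cache = {}                 # "private" attribute, filtered out below

    def describe(self):                  # methods also show up in dir()
        return f"{len(self.classes)} classes"

def public_attrs(obj):
    """Names returned by dir(obj), minus those starting with an underscore."""
    return [name for name in dir(obj) if not name.startswith("_")]

data = FakeData()
print(public_attrs(data))   # ['classes', 'describe', 'test_ds', 'trn_ds', 'val_ds']
```

dir() returns its names sorted alphabetically, which is why the output order differs from the order of assignment.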
Among them, I focused on these:
classes: class names
val_y: class labels for the validation data
trn_y: class labels for the training data
val_ds: validation dataset
trn_ds: training dataset
test_ds: test dataset
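Assuming val_y and trn_y hold integer indices into classes (my assumption here, not something shown above), turning labels back into class names is a simple lookup. The lists below are made-up stand-ins for data.classes and data.val_y, and label_names is a hypothetical helper.

```python
from collections import Counter

classes = ["cats", "dogs"]          # stand-in for data.classes
val_y = [0, 1, 1, 0, 1]             # stand-in for data.val_y (made-up labels)

def label_names(labels, classes):
    """Map each integer label to its class name."""
    return [classes[i] for i in labels]

names = label_names(val_y, classes)
print(names)                        # ['cats', 'dogs', 'dogs', 'cats', 'dogs']
print(Counter(names))               # how many validation images per class
```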
If we use type() on val_ds, trn_ds and test_ds, we find that they are objects of type ‘FilesIndexArrayDataset’.
Delving a bit into the library, we find the following inheritance path:
FilesIndexArrayDataset --> FilesArrayDataset --> FilesDataset --> BaseDataset --> Dataset
Here, ‘Dataset’ is the PyTorch abstract class representing a dataset (torch.utils.data.Dataset).
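An inheritance path like this can be read off an object’s method resolution order with inspect.getmro(). A minimal sketch, using hypothetical empty stand-in classes that mirror the chain above (the real classes live in the fastai library and in torch.utils.data). On a real dataset object, the printed chain may include extra base classes depending on the library versions.

```python
import inspect

# Hypothetical empty stand-ins mirroring the chain described above.
class Dataset: pass
class BaseDataset(Dataset): pass
class FilesDataset(BaseDataset): pass
class FilesArrayDataset(FilesDataset): pass
class FilesIndexArrayDataset(FilesArrayDataset): pass

def inheritance_path(obj):
    """Class names from type(obj) up to, but excluding, object."""
    return [cls.__name__ for cls in inspect.getmro(type(obj)) if cls is not object]

print(" --> ".join(inheritance_path(FilesIndexArrayDataset())))
# FilesIndexArrayDataset --> FilesArrayDataset --> FilesDataset --> BaseDataset --> Dataset
```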
Then, we can also look at the files in each dataset by printing its ‘fnames’ list, for example just the first 10 files.
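A minimal sketch of that inspection, using a hypothetical stand-in object with a made-up fnames list (on the real data object the list would come from val_ds, trn_ds or test_ds). Slicing gives the first entries, and os.path.basename strips any folder prefix if only the file names are wanted.

```python
import os

class FakeDataset:
    """Hypothetical stand-in exposing an fnames list like the datasets above."""
    def __init__(self, fnames):
        self.fnames = fnames

test_ds = FakeDataset([f"test/{i}.jpg" for i in range(25)])   # made-up file names

first10 = test_ds.fnames[:10]
print(first10)
print([os.path.basename(f) for f in first10])   # '0.jpg', '1.jpg', ...
```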
Now, since we’ve loaded the test dataset too, it’s possible to make predictions on it. After training, take a look at learn.predict(). In the learner.py file of the fastai library, you can see that this method takes an optional ‘is_test’ argument and in turn calls predict_with_targs(), which finally checks whether is_test is True or False: if True, it uses the test dataset; otherwise it uses the validation dataset. (To be more precise, it chooses between two data loaders, one of which is data.val_dl; both are ModelDataLoader objects, but I haven’t dug deeper yet.)
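The is_test dispatch can be sketched in plain Python. This is not fastai’s code: FakeData, its valid_loader and test_loader attributes, and these toy predict_with_targs and predict functions are hypothetical stand-ins, a "loader" is just a list of (input, target) pairs, and the model is a made-up function that doubles its input. The point is only the branch: one flag selects which loader the predictions come from.

```python
class FakeData:
    """Hypothetical stand-in holding a validation and a test 'loader'."""
    def __init__(self, valid_loader, test_loader):
        self.valid_loader = valid_loader
        self.test_loader = test_loader

def predict_with_targs(model, data, is_test=False):
    """Pick the loader based on is_test, then return (predictions, targets)."""
    loader = data.test_loader if is_test else data.valid_loader
    preds = [model(x) for x, _ in loader]
    targs = [y for _, y in loader]
    return preds, targs

def predict(model, data, is_test=False):
    """Like predict_with_targs, but keep only the predictions."""
    return predict_with_targs(model, data, is_test)[0]

model = lambda x: 2 * x                                   # toy "model"
data = FakeData(valid_loader=[(1, 0), (2, 1)],            # labeled validation pairs
                test_loader=[(10, None), (20, None)])     # test set: no real targets
print(predict(model, data))                 # [2, 4]
print(predict(model, data, is_test=True))   # [20, 40]
```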
It’s also possible to call learn.TTA() with the is_test argument set to True if you want to do test-time augmentation.
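In general, test-time augmentation means predicting on several augmented copies of each input and averaging the results. A minimal sketch of that idea, independent of fastai’s learn.TTA(): the "image" is a list of numbers, the two augmentations (identity and a horizontal flip done by reversing the list) and the toy classifier are hypothetical, and the class probabilities are simply averaged. I’m not claiming this mirrors exactly how fastai combines its predictions.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def toy_model(image):
    """Hypothetical 2-class classifier that is deliberately not flip-invariant."""
    logits = [image[0] + image[1], max(image)]
    return softmax(logits)

AUGMENTATIONS = [
    lambda img: list(img),            # original image
    lambda img: list(reversed(img)),  # horizontal flip
]

def tta_predict(model, image, augs=AUGMENTATIONS):
    """Average the class probabilities over all augmented versions of image."""
    probs = [model(aug(image)) for aug in augs]
    n_classes = len(probs[0])
    return [sum(p[c] for p in probs) / len(probs) for c in range(n_classes)]

image = [1.0, 0.0, 0.0, 3.0]
print(toy_model(image))               # prediction on the original only
print(tta_predict(toy_model, image))  # averaged over original + flipped
```

Averaging probabilities (rather than picking one augmented prediction) tends to smooth out errors that only appear for one view of the input.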