How to specify test set in DataBlock?

I’m wondering whether it is possible to do a three-way split of my data into non-overlapping training/validation/test sets.

Using the following example to illustrate:

from fastai2.data.external import URLs, untar_data
from fastai2.data.transforms import get_image_files, parent_label, GrandparentSplitter
from fastai2.data.block import DataBlock, CategoryBlock
from fastai2.vision.data import ImageBlock
from fastai2.torch_core import doc

dest_path = untar_data(URLs.MNIST)

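dblock = DataBlock(
  blocks=(ImageBlock, CategoryBlock),
  get_items=get_image_files,  # collect every image under the source path
  get_y=parent_label,         # the label is the name of the parent folder
  splitter=GrandparentSplitter(train_name='training', valid_name='testing'))
dls = dblock.dataloaders(dest_path)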

In the above code, we use GrandparentSplitter to split the data into training set (the “training” folder) and validation set (the “testing” folder). How do I specify another test set (which is separate from the training and validation sets) that I can use to evaluate the final model?


Assuming you have another folder containing images that were not used for training or validation, you can try something like the following.

test_files = get_image_files(test_path)
test_dl = dls.test_dl(test_files)

I’ll build off Vishnu’s excellent answer and add my own tweaks:

# learn has been fit above
test_dl = learn.dls.test_dl(test_files, with_labels=True)

See the very bottom of this post:


Thank you both @VishnuSubramanian @sut !
I guess my follow-up question is: how do I get a metric on test_dl?

I figure the following code would give me predictions for the test data

test_files = get_image_files('... path to test data...')
test_dl = learner.dls.test_dl(test_files, with_labels=True)
pred_probas, _, pred_classes = learner.get_preds(dl=test_dl, with_decoded=True) 
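  1. Is there a way to get labels from test_dl?
  2. Is there an easy way to get a metric (error rate in this case)?

  1. Get labels? Yes, they are at index 1 of the tuple returned by get_preds. In your code you are assigning them to _.
  2. Capture the labels instead of discarding them, then compute the error rate with scikit-learn:
    from sklearn.metrics import accuracy_score
    pred_probas, actual_classes, pred_classes = learner.get_preds(dl=test_dl, with_decoded=True)
    err_rate = 1 - accuracy_score(actual_classes, pred_classes)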
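
If you would rather not pull in scikit-learn, the same number can be computed by hand. This is only a minimal sketch: error_rate_from_labels is a hypothetical helper name (it is not part of fastai), and it assumes both arguments are 1-D sequences of class indices, such as Python lists or 1-D tensors.

```python
def error_rate_from_labels(pred_classes, actual_classes):
    """Fraction of predictions that differ from the actual labels.

    Hypothetical helper, not a fastai function. Accepts any two
    equal-length 1-D sequences whose elements can be converted to int.
    """
    if len(pred_classes) != len(actual_classes):
        raise ValueError("predictions and labels must have the same length")
    if len(pred_classes) == 0:
        raise ValueError("cannot compute an error rate on an empty set")
    # Count the positions where the predicted class is wrong.
    wrong = sum(int(p) != int(a) for p, a in zip(pred_classes, actual_classes))
    return wrong / len(pred_classes)


# One wrong prediction out of four.
print(error_rate_from_labels([1, 2, 3, 4], [1, 2, 0, 4]))
```

With the get_preds output above, you would pass pred_classes and actual_classes directly.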

If your test files were among the items returned by get_image_files, you can write a custom splitter that returns three lists of indices instead of two (look at the code of GrandparentSplitter and adapt it). That makes your test set part of your dls as the third subset, which you can then access as dls.loaders[2].
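
A rough sketch of such a three-way splitter, written in plain Python so it can be tried without fastai. The function name three_way_grandparent_splitter and the default folder names are assumptions made for illustration (the MNIST download above only has "training" and "testing"), so adapt them to your own layout.

```python
from pathlib import Path


def three_way_grandparent_splitter(train_name="training",
                                   valid_name="validation",
                                   test_name="testing"):
    """Build a splitter that returns three index lists.

    Hypothetical helper modelled on the idea of GrandparentSplitter:
    each item is assigned by the name of its grandparent folder.
    The folder names are assumptions, not fastai defaults.
    """
    def _split(items):
        def _indices(folder_name):
            # Keep the positions of the items whose grandparent folder matches.
            return [i for i, path in enumerate(items)
                    if Path(path).parent.parent.name == folder_name]
        return _indices(train_name), _indices(valid_name), _indices(test_name)
    return _split


# Example with made-up paths in a <split>/<label>/<file> layout.
example_items = [
    Path("data/training/3/a.png"),
    Path("data/validation/3/b.png"),
    Path("data/testing/7/c.png"),
    Path("data/training/7/d.png"),
]
train_idx, valid_idx, test_idx = three_way_grandparent_splitter()(example_items)
print(train_idx, valid_idx, test_idx)
```

The returned function would be passed as the splitter argument of the DataBlock in place of GrandparentSplitter.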

To validate on a new dataloader:

learn.validate(dl=test_dl)

To validate on your third subset:

learn.validate(ds_idx=2)

Thanks @sut!
I was referring to the actual labels for the test data, not predicted labels.
I would expect that actual_classes and pred_classes would be PyTorch tensors, so I was expecting some utility function in the library to calculate the metric.

Nice, thanks for the tips @sgugger!

Is there a set way to record metrics on ds_idx=2 during the training loop?

If not, would this be a useful PR for me to build? It fits a situation I’m dealing with right now.

There is no direct way, but you can add to what is logged (see the metrics section; there is one for any arbitrary quantity you set), and you can add a callback that runs another validation on ds_idx=2 after the normal validation.

I’m stuck, and a little advice would be much appreciated. I am trying to call learn.validate within a callback method, but this gets me stuck in an infinite loop.

I assume this is because .validate triggers other events. How do I get around this? Or is there an example of a callback that calls a method on the learner that I can use as a template?

You can’t call validate in a callback without a recursive loop since validate calls after_epoch, which then calls the callbacks, again and again.
You should use the private method _do_validate after setting self.dl to avoid this.
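
To make the recursion concrete, here is a toy model of the event loop. None of this is fastai code: ToyLearner, RecursingCallback, SafeCallback and run_epoch_end are made-up names, and the only thing the sketch assumes about the real library is what is described above, namely that the public validate fires after_epoch while the private validation step does not.

```python
class ToyLearner:
    """Toy stand-in for a learner; not fastai code."""

    def __init__(self, callbacks):
        self.callbacks = callbacks
        self.validations = 0

    def _fire(self, event):
        # Call the handler for this event on every callback.
        for cb in self.callbacks:
            getattr(cb, event)(self)

    def _do_epoch_validate(self):
        # Validation pass only; fires no epoch-level events.
        self.validations += 1

    def validate(self):
        # Public entry point: validation pass, then the after_epoch event.
        self._do_epoch_validate()
        self._fire("after_epoch")


class RecursingCallback:
    def after_epoch(self, learner):
        # validate fires after_epoch again, which calls validate again...
        learner.validate()


class SafeCallback:
    def after_epoch(self, learner):
        # The private step does not fire after_epoch, so this returns.
        learner._do_epoch_validate()


def run_epoch_end(callback):
    """Fire after_epoch once and report what happened."""
    learner = ToyLearner([callback])
    try:
        learner._fire("after_epoch")
    except RecursionError:
        return "recursion"
    return learner.validations


print(run_epoch_end(RecursingCallback()))
print(run_epoch_end(SafeCallback()))
```

The first callback never returns on its own (Python eventually raises RecursionError), while the second runs exactly one extra validation pass.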

This worked, thank you. Leaving a prototype here for anyone who finds this thread:

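class TestSetRecorder(Callback):
    "Run an extra validation pass on dls[ds_idx] after each epoch and keep the results"
    def __init__(self, ds_idx=2, **kwargs):
        self.values = []      # one entry per epoch
        self.ds_idx = ds_idx  # index of the subset to evaluate (2 = third subset)
    def after_epoch(self):
        # Save the log of the normal train/valid pass so it can be restored
        old_log = self.recorder.log.copy()
        # Private method: validates without re-triggering after_epoch
        self.learn._do_epoch_validate(ds_idx=self.ds_idx, dl=None)
        # The recorder appended the new results to its log; keep only those
        self.values.append(self.recorder.log[len(old_log):])
        self.recorder.log = old_log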

Full notebook

I had to get my values out of recorder.log, which is a little hacky but seems to work.