Predictions of learner.get_preds() and learner.predict() differ tremendously

Hello,

I have trained a CNN for a binary classification task.
The data loader creation looks like this:

final_size = 512
bs = 35

data = fastai.data.block.DataBlock(
    blocks=(TileImageBlock, fastai.data.block.MultiCategoryBlock),
    get_x=lambda x: x, 
    get_y=lambda x: x.get_labels(),
    splitter=fastai.data.transforms.FuncSplitter(lambda x: x.get_dataset_type() == shared.enums.DatasetType.validation),
    item_tfms=fastai.vision.augment.Resize(size=final_size, method = 'squish'),
    batch_tfms=fastai.vision.augment.aug_transforms(flip_vert=True))

# Training and validation tiles are passed together; the FuncSplitter above separates them
dls = data.dataloaders(patient_manager.get_tiles(dataset_type=DatasetType.train)
                       + patient_manager.get_tiles(dataset_type=DatasetType.validation),
                       bs=bs,
                       verbose=False)

(TileImageBlock and the Tile objects returned by patient_manager.get_tiles are custom classes. A Tile represents one tile of a whole-slide image and makes it possible to extract PNG images from the whole-slide image on the fly.)

The learner is created like this:

learner = cnn_learner(dls=dls, 
                     arch=arch, 
                     metrics=[fastai.metrics.accuracy_multi],
                     pretrained=True,
                     path=PATH/'models'/f'{n}-{arch.__name__}')

As mentioned above, there are two classes, and each tile belongs to exactly one of them.
I nevertheless treat this as a multilabel classification, so the learner outputs two probabilities, one per class.
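
To make clear what I mean by two probabilities: in a multilabel setup each class typically gets its own sigmoid, so the two values are independent and do not have to sum to 1, unlike the softmax output of a single-label setup. A minimal sketch in plain Python, assuming made-up logits (the numbers and helper names are purely illustrative, not taken from my model):

```python
import math

def sigmoid(z):
    """Independent per-class probability, as used for multilabel targets."""
    return 1.0 / (1.0 + math.exp(-z))

def softmax(zs):
    """Mutually exclusive class probabilities, as used for single-label targets."""
    m = max(zs)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in zs]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical raw outputs (logits) for class A and class B
logits = [2.0, -1.0]

multilabel_probs = [sigmoid(z) for z in logits]  # need not sum to 1
singlelabel_probs = softmax(logits)              # always sums to 1

print("multilabel  :", multilabel_probs, "sum =", sum(multilabel_probs))
print("single-label:", singlelabel_probs, "sum =", sum(singlelabel_probs))
```

So a tile can in principle get a high (or low) probability for both classes at once, even though every tile has exactly one true label.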

At inference time, when I use learner.get_preds() to get predictions for the validation set, the predictions are essentially useless: for almost every tile the predicted probability is very high for class B and very low for class A.
When I instead call learner.predict() on each tile of the validation set, the predictions are very accurate and correct for almost all tiles.
I dug deep into my code but could not find an explanation for this behavior.
As far as I can tell, learner.predict also calls learner.get_preds behind the scenes.
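
To quantify the mismatch, the two sets of predictions can be compared tile by tile. A rough sketch, assuming both have already been collected as lists of [p_A, p_B] in the same tile order (for example from the predictions returned by learner.get_preds() and the probabilities returned by learner.predict()); the function name, the 0.5 threshold, and the numbers are hypothetical placeholders:

```python
def compare_predictions(batch_probs, single_probs, threshold=0.5):
    """Return the indices of tiles whose thresholded labels differ,
    plus the largest absolute probability difference seen."""
    if len(batch_probs) != len(single_probs):
        raise ValueError("both prediction lists must cover the same tiles")

    disagreeing = []
    max_abs_diff = 0.0
    for idx, (b, s) in enumerate(zip(batch_probs, single_probs)):
        # Largest per-class difference for this tile
        diff = max(abs(pb - ps) for pb, ps in zip(b, s))
        max_abs_diff = max(max_abs_diff, diff)
        # Compare the thresholded multilabel decisions
        if [p >= threshold for p in b] != [p >= threshold for p in s]:
            disagreeing.append(idx)
    return disagreeing, max_abs_diff

# Made-up values that mimic the pattern described above:
# the batch path favors class B almost everywhere.
batch_probs = [[0.05, 0.97], [0.10, 0.97], [0.02, 0.99]]
single_probs = [[0.93, 0.04], [0.08, 0.95], [0.90, 0.12]]

disagreeing, max_abs_diff = compare_predictions(batch_probs, single_probs)
print("tiles with different labels:", disagreeing)
print("largest probability gap:", max_abs_diff)
```

This only measures the disagreement; it does not explain where it comes from.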

Does anyone have a hint or an idea what might cause this strange behavior?
The only problem I have with learner.predict is that it is very slow compared to learner.get_preds, since it does not predict batchwise. I have a lot of tiles to predict, so it is not a practical alternative.

Thanks in advance

Christoph
