Trying to match fastai and pytorch inference

I need to run an inference pipeline which uses a fastai (v2) resnet18 binary classifier model along with a few pure pytorch models. The pytorch models were trained with different pre-processing steps. I am trying to use a single dataloader to read the files just once and prepare different tensors for the various models.

I have the following code for batch inference with the fastai model:

from fastai.vision.all import *

learn = load_learner('export.pkl')
learn.model.to('cuda:0')

# sample_files is a list of file paths
test_dl = learn.dls.test_dl(sample_files)
preds, _ = learn.get_preds(dl=test_dl)
print(preds)

and this outputs:

tensor([[0.8155, 0.1845],
        [0.9797, 0.0203],
        [0.9990, 0.0010],
        [0.7240, 0.2760],
        [0.1400, 0.8600],
        [0.9983, 0.0017],
        [0.4284, 0.5716],
        [0.8841, 0.1159]])

I am trying to now extract the pytorch model out of this fastai learner and write my own data loader.

The fastai model was built with the following data block:

dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   get_items=get_image_files,
                   get_y=parent_label,
                   splitter=GrandparentSplitter(),
                   item_tfms=Resize(img_size))

The only preprocessing done during training is the resize.

I wrote the following custom data loader to do the same pre-processing done during training i.e. just resize:

import cv2
from torch.utils.data import Dataset, DataLoader

class ImageDataset(Dataset):
    def __init__(self, file_paths, img_size):
        self.file_paths = file_paths
        self.img_size = img_size
    
    def __getitem__(self, index):
        fpath = self.file_paths[index]
        img = cv2.imread(str(fpath))
        img = cv2.resize(img, (self.img_size, self.img_size), 
                         interpolation=cv2.INTER_AREA)
        # transpose (height, width, channels) to (channels, height, width)
        img = img.transpose(2,0,1)
        return img
    
    def __len__(self):
        return len(self.file_paths)

And then this to do the inference:

import torch.nn.functional as F

dataset = ImageDataset(sample_files, img_size=480)
loader = DataLoader(dataset, batch_size=32, num_workers=3)

with torch.no_grad():
    for batch in loader:
        batch = batch.float().to('cuda:0')
        preds = learn.model(batch)
        print(preds)
        preds = F.softmax(preds, dim=-1)
        print(preds)

However, this gives the following output:

tensor([[  128.9455, -1014.5818],
        [  159.4461,  -971.7758],
        [  212.9348,  -918.6669],
        [   99.4558, -1047.9894],
        [  108.8564, -1011.8091],
        [  109.2117, -1040.3489],
        [  134.9720, -1035.7109],
        [  127.0243,  -997.6915]], device='cuda:0')
tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.]], device='cuda:0')

I was hoping to see the same probabilities as above. I am not sure why I am getting a mismatch between the two. Appreciate your feedback. Thank you.

Hi Arun,

This is a great question. The issue might be in relation to this assumption: “custom data loader to do the same pre-processing” – while you might be getting the same sized tensor for each image, the images themselves may not come out identical between the two runs.

This is because fastai uses PIL (Python Imaging Library) for its item-transform resize, with some default (best-practice) parameters set, while your PyTorch run uses cv2, whose default parameters and operations are not the same: Data augmentation in computer vision | fastai

For example, the fastai version uses bilinear interpolation, while your PyTorch version uses “INTER_AREA”.

Another (key) difference could be the default padding mode; fastai uses reflection padding by default. You can see what reflection padding looks like here: vision.transform | fastai
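
To illustrate the interpolation point, here is a minimal sketch, assuming Pillow and NumPy are installed, that downsamples the same synthetic image with two different resampling filters and compares the results. The random test image and the 64-to-16 resize are made up for illustration, and Pillow's `BOX` filter is only a rough stand-in for an area-averaging resize, not an exact equivalent of OpenCV's `INTER_AREA`.

```python
import numpy as np
from PIL import Image

# Synthetic 64x64 RGB image (hypothetical test data, fixed seed for repeatability)
rng = np.random.default_rng(0)
img = Image.fromarray(rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8))

# Same target size, two different resampling filters
bilinear = np.asarray(img.resize((16, 16), Image.BILINEAR), dtype=np.float32)
box = np.asarray(img.resize((16, 16), Image.BOX), dtype=np.float32)

# Both outputs have the same shape, but the pixel values differ
max_abs_diff = float(np.abs(bilinear - box).max())
print(bilinear.shape, box.shape, max_abs_diff)
```

So two pipelines can produce tensors of identical size and still hand the model different inputs.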

Hi Ali,

Thanks for your reply.

I rewrote the dataloader with fastai Resize itself like this:

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class ImageDataset(Dataset):
    def __init__(self, file_paths, img_size):
        self.file_paths = file_paths
        self.img_size = img_size
        self.transform = transforms.Compose([transforms.ToTensor()])
    
    def __getitem__(self, index):
        fpath = self.file_paths[index]
        img = PILImage.create(fpath)
        rsz = Resize(self.img_size)
        img = rsz(img)
        img = self.transform(img)
        return img
    
    def __len__(self):
        return len(self.file_paths)

The results look better now, but still do not match what get_preds gives:

tensor([[-0.0205, -0.0428],
        [ 0.4740,  0.6269],
        [ 0.4004, -0.3127],
        [ 0.2432, -0.6500],
        [ 0.2652,  0.9056],
        [ 1.1341, -3.0740],
        [ 0.2427,  0.0118],
        [ 0.4215, -0.8222]], device='cuda:0')
tensor([[0.5056, 0.4944],
        [0.4619, 0.5381],
        [0.6711, 0.3289],
        [0.7096, 0.2904],
        [0.3452, 0.6548],
        [0.9853, 0.0147],
        [0.5575, 0.4425],
        [0.7762, 0.2238]], device='cuda:0')
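
A quick way to put a number on the mismatch is to compare the two probability tensors directly. This is only a sketch using NumPy, with the values copied by hand from the get_preds output and from the softmax output just above; the variable names are made up for illustration.

```python
import numpy as np

# Probabilities copied from the get_preds output earlier in the thread
fastai_probs = np.array([[0.8155, 0.1845], [0.9797, 0.0203],
                         [0.9990, 0.0010], [0.7240, 0.2760],
                         [0.1400, 0.8600], [0.9983, 0.0017],
                         [0.4284, 0.5716], [0.8841, 0.1159]])

# Probabilities copied from the custom-loader softmax output above
custom_probs = np.array([[0.5056, 0.4944], [0.4619, 0.5381],
                         [0.6711, 0.3289], [0.7096, 0.2904],
                         [0.3452, 0.6548], [0.9853, 0.0147],
                         [0.5575, 0.4425], [0.7762, 0.2238]])

max_abs_diff = float(np.abs(fastai_probs - custom_probs).max())
same_class = fastai_probs.argmax(axis=1) == custom_probs.argmax(axis=1)

print(np.allclose(fastai_probs, custom_probs, atol=1e-3))  # False
print(round(max_abs_diff, 4))                               # 0.5178
print(int(same_class.sum()), "of", len(same_class))         # 6 of 8
```

So even the predicted class flips for two of the eight images, which suggests the preprocessing still differs in more than rounding noise.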

Glad I finally found the solution from this blog post:

I trained the model again, but this time with:

dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   get_items=get_image_files,
                   get_y=parent_label,
                   splitter=GrandparentSplitter(),
                   item_tfms=Resize(img_size, 
                                    method=ResizeMethod.Squish,
                                    pad_mode=PadMode.Reflection))

and then wrote the custom dataloader like this:

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class ImageDataset(Dataset):
    def __init__(self, file_paths, img_size):
        self.file_paths = file_paths
        self.img_size = img_size
        self.preprocess = Pipeline([Transform(PILImage.create),
            Resize(self.img_size, 
                   method=ResizeMethod.Squish,
                   pad_mode=PadMode.Reflection),
            ToTensor,
            IntToFloatTensor,
            Normalize.from_stats(*imagenet_stats, cuda=False)
        ])
    
    def __getitem__(self, index):
        fpath = self.file_paths[index]
        img = self.preprocess(fpath)
        img = img.squeeze(0)
        return img
    
    def __len__(self):
        return len(self.file_paths)

This matches what get_preds gives.
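
For anyone wondering why the last two steps of the pipeline matter so much: as far as I understand it, `IntToFloatTensor` scales the 0-255 pixel values to 0-1, and `Normalize.from_stats(*imagenet_stats)` then subtracts the ImageNet per-channel mean and divides by the per-channel std. Below is a rough NumPy sketch of that arithmetic. The helper name `scale_and_normalize` is made up for illustration and is not a fastai function.

```python
import numpy as np

# ImageNet channel statistics in RGB order (the values behind imagenet_stats)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def scale_and_normalize(img_chw_uint8):
    """Hypothetical helper: (3, H, W) uint8 image -> normalized float32 array."""
    x = img_chw_uint8.astype(np.float32) / 255.0  # 0-255 -> 0-1
    # Broadcast the per-channel stats over the height and width axes
    return (x - IMAGENET_MEAN[:, None, None]) / IMAGENET_STD[:, None, None]

# Extreme pixel values show the input range the model expects
black = np.zeros((3, 2, 2), dtype=np.uint8)
white = np.full((3, 2, 2), 255, dtype=np.uint8)
lowest = float(scale_and_normalize(black).min())
highest = float(scale_and_normalize(white).max())
print(lowest)   # roughly -2.1
print(highest)  # roughly 2.6
```

Feeding raw 0-255 values, as in my first attempt, puts the inputs far outside that range, which would be consistent with the huge logits I was seeing.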

Would like to add a surprising finding here.

The pytorch custom data loader with 3 data loader workers runs 2.6x faster than fastai get_preds for the same batch size! I explicitly set test_dl num_workers like this:

test_dl = learn.dls.test_dl(test_files, bs=32, num_workers=3)

but this didn’t speed things up.

So here your ImageDataset is not fully written in native torch, since ResizeMethod.Squish, Normalize.from_stats(), etc. come from fastai, so one still needs to import fastai.
I am wondering whether I can have the equivalent inference part without depending on fastai?

@Omayma Good question. I don’t have the code in pure pytorch as of now, but will attempt and post if I get it working.
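
In the meantime, here is a rough starting point that uses only Pillow and NumPy for the preprocessing. It is a sketch under a few assumptions that I have not verified against the real model: that the squish resize amounts to a plain bilinear `resize` to a square, that images should be forced to RGB, and that the ImageNet mean/std are the right normalization stats. The function name `preprocess` is made up for illustration.

```python
import numpy as np
from PIL import Image

# ImageNet channel statistics in RGB order (assumed to match the training setup)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess(source, img_size):
    """Hypothetical fastai-free preprocessing: path or PIL image -> (3, H, W) float32."""
    img = source if isinstance(source, Image.Image) else Image.open(source)
    img = img.convert('RGB')                                 # force 3-channel RGB
    img = img.resize((img_size, img_size), Image.BILINEAR)   # squish: ignore aspect ratio
    x = np.asarray(img, dtype=np.float32) / 255.0            # HWC, values in 0-1
    x = (x - IMAGENET_MEAN) / IMAGENET_STD                   # per-channel normalization
    return np.ascontiguousarray(x.transpose(2, 0, 1))        # HWC -> CHW

# Quick check on a synthetic solid-red image (made-up test data)
demo = preprocess(Image.new('RGB', (64, 48), (255, 0, 0)), img_size=32)
print(demo.shape, demo.dtype)
```

Something like `preprocess(fpath, self.img_size)` could then be returned from `__getitem__` in the ImageDataset above, since the default DataLoader collate function converts NumPy arrays to tensors. Whether the outputs match get_preds exactly would still need to be checked against the real model.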

This gave me similar results. My example is about multilabel classification; I just haven’t written a dataloader in torch yet.

import json

import torchvision.transforms as T
from fastai.vision.all import *

HP_PATH = 'train_job_image/model_hyperparams.json'
training_images_path = 'data/imgs/'
training_dataset_path = 'data/labeled_data_train_valid.csv'
test_data_path = 'data/imgs_test/'

# read json including parameters values
with open(HP_PATH, 'r') as model_hparams_json:
    model_hparams_dict = json.load(model_hparams_json)

img_size = model_hparams_dict['RESIZE_VALUE']

# read train data
df_train = pd.read_csv(training_dataset_path)

# TRAINING -----------------------------
# datablock for multilabel classification
dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),
                   get_x=ColReader('image_id',
                                   pref=training_images_path,
                                   suff=''),
                   get_y=ColReader('tag', label_delim=','),
                   splitter=ColSplitter('is_valid'),
                   item_tfms=Resize(img_size,
                                    method=ResizeMethod.Squish,
                                    pad_mode=PadMode.Reflection
                                    ))

# load and train
img_dls_01 = dblock.dataloaders(df_train)

learn_multi = cnn_learner(dls=img_dls_01, arch=resnet34)

learn_multi.fine_tune(
    model_hparams_dict['EPOCHS'],
    base_lr=model_hparams_dict['BASE_LR'],
    freeze_epochs=model_hparams_dict['FREEZE_EPOCHS']
)

# INFERENCE: FASTAI ------------
tst_files = get_image_files(test_data_path)
tst_dl = learn_multi.dls.test_dl(tst_files)
preds_f, _ = learn_multi.get_preds(dl=tst_dl)

# INFERENCE: TORCH --------------
test_files = glob.glob(f'{test_data_path}*')

# a quick loop to check (transforms built once, single no_grad context)
resize = T.Resize([img_size, img_size])
to_tensor = T.ToTensor()
with torch.no_grad():
    for ff in test_files:
        pil_image = Image.open(ff)
        res_pil_image = resize(pil_image)
        timg = to_tensor(res_pil_image)
        pred = learn_multi.model(timg.unsqueeze(0)).sigmoid()
        print(pred)

I added an example here with more details.