Train model with fastai, predict with PyTorch (Lightning)

Hi guys,

I have a PyTorch Lightning application that I use to train a CNN for image classification from scratch, using NNI. I later use the trained model in production for predictions.

As transfer learning seems to be much simpler and faster with fastai, I would like to migrate the training component to fastai while keeping the existing PyTorch Lightning prediction code as it is for now. I already used a vision_learner for this and saved the resulting model both as a .pkl file and as a .pth file.

In theory, it should be straightforward to load the model in PyTorch:

import torch
import pytorch_lightning as pl

# DataModule and PredictionCallbacks are application-specific classes (not shown)
data_module = DataModule('/path/to/prediction/data')
trained_model_path = '/path/to/model'

model = pl.LightningModule()
model.load_state_dict(
    state_dict=torch.load(trained_model_path),
    strict=False
)
model.eval()

trainer = pl.Trainer(
    callbacks=[PredictionCallbacks('/path/to/prediction/data', 50)]
)
trainer.predict(
    model=model,
    datamodule=data_module
)

However, loading the pickled version of the model (.pkl) gives me the following error message:

TypeError: Expected state_dict to be dict-like, got <class 'fastai.learner.Learner'>.
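For context on this first error: assuming the .pkl file was written with fastai's `Learner.export`, it contains the whole pickled `Learner`, so `torch.load` returns a `Learner` object (and needs fastai importable to unpickle it) rather than a dictionary of tensors, while `load_state_dict` only accepts a mapping. The minimal sketch below reproduces the same kind of TypeError in plain PyTorch, using a second `nn.Linear` as a placeholder for the unpickled object; the variable names are illustrative only.

```python
import torch.nn as nn

# Stand-ins: `target` plays the model being loaded, `not_a_state_dict` plays
# the unpickled fastai Learner (any object that is not a mapping of tensors).
target = nn.Linear(4, 2)
not_a_state_dict = nn.Linear(4, 2)

try:
    target.load_state_dict(not_a_state_dict)
    error_message = None
except TypeError as exc:
    # Recent PyTorch versions report that the state_dict must be dict-like.
    error_message = str(exc)
print(error_message)

# Passing a real state dict (a mapping of parameter names to tensors) works,
# which corresponds to using learn.model.state_dict() on the fastai side.
result = target.load_state_dict(not_a_state_dict.state_dict())
print(result)
```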

When I use the .pth version of the model, the error message is:

pytorch_lightning.utilities.exceptions.MisconfigurationException: `Trainer.predict` requires `forward` method to run.

Does anyone have experience using fastai models in PyTorch Lightning? I am probably missing something here; I'm not very familiar with fastai yet.

Any help appreciated.

Best, Tobias

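Your LightningModule is completely empty: there are no layers to load the state_dict into, and it has no forward method. The model architecture cannot be inferred from the state dict alone, because a state dict only contains the layers' weights (parameter and buffer tensors keyed by name). Since strict=False ignores keys that don't match, load_state_dict doesn't complain, but nothing actually gets loaded.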
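A minimal illustration, with a small `nn.Sequential` as a placeholder for the fastai network: the state dict is just an ordered mapping from names to tensors. Loading it into an empty module with `strict=False` raises nothing, but every key lands in `unexpected_keys` and no weights are loaded; with the default `strict=True`, PyTorch raises a `RuntimeError` instead, which surfaces the problem immediately.

```python
import torch
import torch.nn as nn

# Placeholder network standing in for the trained fastai model.
trained = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
state_dict = trained.state_dict()

# The state dict holds only named tensors, no information about layer types.
print(list(state_dict.keys()))
print(all(isinstance(v, torch.Tensor) for v in state_dict.values()))

# An "empty" module, like a bare pl.LightningModule(), has nothing to load into.
empty = nn.Module()
report = empty.load_state_dict(state_dict, strict=False)
print(report.unexpected_keys)         # every key was ignored
print(len(list(empty.parameters())))  # still no parameters

# With strict=True (the default) the mismatch becomes an error.
try:
    empty.load_state_dict(state_dict)
    strict_failed = False
except RuntimeError:
    strict_failed = True
print(strict_failed)
```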

What network architecture did you use in your fastai model? I would create a LightningModule that contains that architecture and defines a forward method; then your second approach, with the .pth file, should work.

OK, I suspected as much, but I don't know how to set up the model architecture and the forward method when I use a pretrained model.

I used a ResNet-50 CNN in my learner:

from fastai.vision.all import *  # provides vision_learner, resnet50, error_rate
learn = vision_learner(dls, resnet50, metrics=error_rate)

and later save the model like this:

torch.save(
    learn.model.state_dict(),
    '/path/to/model.pth'
)
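As an alternative to the state-dict route discussed in this thread, the trained network could also be exported with TorchScript, which stores the architecture together with the weights, so the prediction side can load it with `torch.jit.load` without re-declaring the model class or importing fastai. The sketch below uses a small placeholder network in place of `learn.model`; in practice you would trace `learn.model` in eval mode on CPU with an example batch shaped like the training inputs (the 1 x 3 x 64 x 64 example here is an assumption).

```python
import io

import torch
import torch.nn as nn

# Placeholder for learn.model: a small CNN with a body/head split.
model = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.BatchNorm2d(8), nn.ReLU()),
    nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2)),
)
model.eval()  # trace in eval mode so BatchNorm uses its running statistics

example = torch.rand(1, 3, 64, 64)
traced = torch.jit.trace(model, example)

# Save to an in-memory buffer here; in practice use a file path such as 'model.pt'.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)

# Prediction side: no model class definition needed.
loaded = torch.jit.load(buffer)
batch = torch.rand(4, 3, 64, 64)
with torch.no_grad():
    same = torch.allclose(model(batch), loaded(batch))
print(same)
```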

I found help in the PyTorch Lightning documentation:
https://pytorch-lightning.readthedocs.io/en/stable/advanced/transfer_learning.html#use-a-pretrained-lightningmodule

I created a LightningModule like this:

import pytorch_lightning as pl
import torchvision.models as models
import torch
import torch.nn as nn

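class Resnet50(pl.LightningModule):
    def __init__(self, num_target_classes: int = 2):
        super().__init__()

        # ImageNet-pretrained ResNet-50 without its final fc layer
        backbone = models.resnet50(weights="DEFAULT")
        num_filters = backbone.fc.in_features  # 2048 for ResNet-50
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # classification layer mapping the pooled features to the target classes
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        # frozen backbone: eval mode and no gradients
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        return x

    def predict_step(self, batch, _batch_idx):
        return self(batch)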

And used it in my code:

model = Resnet50()
model.load_state_dict(
    state_dict=torch.load(trained_model_path),
    strict=False
)
model.eval()
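One thing worth checking at this point: by default, `vision_learner` builds `learn.model` as an `nn.Sequential(body, head)`, so the keys saved with `torch.save(learn.model.state_dict(), ...)` start with `0.` and `1.`, while the `Resnet50` module above names its parts `feature_extractor` and `classifier`. With `strict=False` such mismatched keys are skipped without an error, leaving the ImageNet backbone weights and an untrained classifier in place. The sketch below, with small placeholder layers instead of ResNet-50, inspects the report that `load_state_dict` returns (discarded in the code above) and shows that loading into a module with the same `Sequential(body, head)` layout succeeds with `strict=True`.

```python
import torch
import torch.nn as nn


def make_body():
    # Placeholder body: conv + pooling, standing in for the ResNet feature extractor.
    return nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU(), nn.AdaptiveAvgPool2d(1))


# Placeholder for the fastai model: Sequential(body, head), keys "0.*" and "1.*".
fastai_like = nn.Sequential(make_body(), nn.Sequential(nn.Flatten(), nn.Linear(8, 2)))
saved = fastai_like.state_dict()


class Renamed(nn.Module):
    """Same layers, different attribute names (like the Resnet50 module above)."""

    def __init__(self):
        super().__init__()
        self.feature_extractor = make_body()
        self.classifier = nn.Linear(8, 2)


renamed = Renamed()
report = renamed.load_state_dict(saved, strict=False)
print(report.missing_keys)     # keys of `renamed` that received no weights
print(report.unexpected_keys)  # keys of `saved` that were ignored

# Rebuilding the exact layout lets strict=True succeed and reproduces the outputs.
rebuilt = nn.Sequential(make_body(), nn.Sequential(nn.Flatten(), nn.Linear(8, 2)))
rebuilt.load_state_dict(saved)  # strict=True by default
x = torch.rand(2, 3, 16, 16)
with torch.no_grad():
    match = torch.allclose(rebuilt(x), fastai_like(x))
print(match)
```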

The first problem seems to be solved now, but I still get an error:

File "/../lib/python3.10/site-packages/torch/nn/modules/conv.py", line 453, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: expected scalar type Byte but found Float

I guess I need to dig a bit deeper into the forward method.


OK, I had to call .float() on my input tensor to fix the error above:

with torch.no_grad():
    representations = self.feature_extractor(x.float()).flatten(1)
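Note that `.float()` only changes the dtype; the pixel values stay in the 0 to 255 range. By default, fastai's image pipeline converts images to floats in [0, 1] (`IntToFloatTensor`) and, for a pretrained `vision_learner`, also normalizes them with the ImageNet statistics, so the prediction input probably needs the same steps to match what the model saw during training. A sketch, assuming uint8 batches shaped (N, 3, H, W) and the ImageNet mean/std that fastai exposes as `imagenet_stats`; the helper name `to_model_input` is hypothetical.

```python
import torch
import torch.nn.functional as F

# ImageNet channel statistics (the values fastai calls `imagenet_stats`).
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def to_model_input(batch: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: uint8 (N, 3, H, W) batch -> normalized float32 batch."""
    x = batch.float() / 255.0  # same scaling as fastai's IntToFloatTensor
    return (x - IMAGENET_MEAN) / IMAGENET_STD


raw = torch.randint(0, 256, (2, 3, 32, 32), dtype=torch.uint8)
weight = torch.randn(8, 3, 3, 3)

# Convolving the raw uint8 tensor with float weights fails with a dtype error.
try:
    F.conv2d(raw, weight)
    dtype_error = False
except RuntimeError:
    dtype_error = True

out = F.conv2d(to_model_input(raw), weight)
print(dtype_error, out.dtype, tuple(out.shape))
```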

Now that this is working, I noticed that ResNet-50 was pretrained on images of 224 x 224 pixels. While I can fine_tune the model on images of other sizes in fastai, that does not seem to work at the prediction stage, if I understand the error message correctly:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (262144x1 and 2048x2)

It seems my initial assumption was wrong: after I updated the batch size, the error above went away.
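On the image-size question: torchvision's ResNet ends its convolutional part with `nn.AdaptiveAvgPool2d((1, 1))`, so the flattened features have shape (N, 2048) for any reasonably sized input, not only 224 x 224. In a `mat1 and mat2` error from the classifier, mat1 is the input (it should be N x 2048) and mat2 is the transposed weight (2048 x 2), so such an error more likely points to the layout of the batch reaching `forward` than to the image resolution. The sketch below uses a small placeholder feature extractor and a hypothetical `check_batch` guard that asserts the expected (N, 3, H, W) layout before the forward pass.

```python
import torch
import torch.nn as nn

# Placeholder feature extractor ending in adaptive pooling, like ResNet's body + avgpool.
features = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2), nn.ReLU(), nn.AdaptiveAvgPool2d((1, 1))
)
classifier = nn.Linear(16, 2)


def check_batch(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical guard: fail early unless x is an (N, 3, H, W) batch."""
    if x.dim() != 4 or x.shape[1] != 3:
        raise ValueError(f"expected (N, 3, H, W), got {tuple(x.shape)}")
    return x.float()


results = []
for size in (224, 320, 97):
    x = check_batch(torch.rand(5, 3, size, size))
    flat = features(x).flatten(1)  # (5, 16) regardless of the input size
    results.append((size, tuple(flat.shape), tuple(classifier(flat).shape)))
print(results)
```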