Export FastAI to PyTorch

Hi, I hope you are staying safe! I know there are a lot of similar questions on this forum but I haven’t been able to find a conclusive answer. I want to export a FastAI model for use with tf.js, and I plan to do this using MMdnn, Microsoft’s DNN conversion software. It supports PyTorch, but I am unsure if a FastAI PKL (made using learn.export()) or PTH (made with learn.save(path)), will work with PyTorch. If they do, could someone tell me how I could load it into PyTorch and export that? Or, could I directly import the PTH to MMdnn? I know that there are a lot of posts on this, but none answer my question.

Once you export, it’s pretty much regular PyTorch by then. You can simply pull the state_dict out and you’re good to go (check the dict and make sure you pull just the model; IIRC the optimizer state can get exported too). What easily causes confusion is that fastai puts a few more layers on the end of the model, so it’s not the “raw” ResNet18; it’s a ResNet18 body with some extra layers on top, which is something to keep in mind. The export also carries some state for the DataLoaders, hence pulling out just the model’s state_dict.
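To "check the dict", it can help to print the keys and tensor shapes of the state_dict. The sketch below uses a tiny hypothetical body-plus-head model as a stand-in for learn.model (the layer sizes are made up for illustration); with a real Learner you would iterate over learn.model.state_dict() in the same way.

```python
import torch
from torch import nn

# Hypothetical stand-in for learn.model: a small "body" followed by a "head",
# loosely mirroring how a pretrained backbone gets extra layers on top.
body = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1))
head = nn.Sequential(nn.Flatten(), nn.Linear(8, 2))
model = nn.Sequential(body, head)

# Each key is "<module index>.<layer index>.<parameter name>".
# Only layers with parameters or buffers show up.
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
```

Keys that start with the index of the last block belong to the extra layers, so those are the ones to recreate by hand when rebuilding the model outside fastai.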

@muellerzr Thank you for your response. I just want to confirm that I can simply learn.export(), and then load it into PyTorch and run MMdnn. If so, how do I import it into PyTorch?

torch.load(), just as you would with any regular torch file. But again, be careful to pull just the state_dict of your model; I believe a regular export stores a bit more than that. So if you want to get rid of fastai completely, you should look into how to store state_dicts and then recreate the model in PyTorch (all fastai models are in themselves PyTorch models, so the methods are exactly the same).

This resource may be helpful:

https://pytorch.org/tutorials/beginner/saving_loading_models.html

(It’s learn.model.state_dict())

So can I just torch.load(learn.model.state_dict())?

Not quite. The best approach here is to save your model’s state dict to a separate file on your system, which we can later call torch.load() on. Note that this file will contain just the weights, none of the fastai fancy things (which is our goal here).

To do so, we’d do something like so:

torch.save(learn.model.state_dict(), path/'myModel.pth')

Now when you’re ready to move over to your MMdnn, you can simply do

torch.load(path/'myModel.pth')

The reasoning is that the moment we pull out that state_dict (I could even just do dic = learn.model.state_dict(), for instance), it’s completely and entirely raw PyTorch. You’ll need to set up your MMdnn environment to rebuild your model without the library, using plain PyTorch instead (just copy over fastai layers until you have all the ones you need), and you’re good to go. (Not 100% positive on that last point, but fairly certain. Try it without; if it doesn’t work, you’ll need it :) )
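The save, rebuild, and load round trip could look roughly like the sketch below. build_model() is a hypothetical helper that stands in for the fastai architecture; in practice it would have to recreate the same layers as learn.model, otherwise load_state_dict() will complain about missing or unexpected keys.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn


def build_model() -> nn.Module:
    # Hypothetical stand-in for the architecture behind learn.model.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


trained = build_model()  # plays the role of learn.model
path = Path(tempfile.mkdtemp())

# Save only the weights, as described above.
torch.save(trained.state_dict(), path / "myModel.pth")

# Later, without importing fastai: rebuild the architecture, then load weights.
restored = build_model()
state = torch.load(path / "myModel.pth", map_location="cpu")
restored.load_state_dict(state)
restored.eval()

# Both models should now give the same output for the same input.
x = torch.randn(1, 4)
with torch.no_grad():
    same = torch.allclose(trained.eval()(x), restored(x))
print(same)
```

Note that torch.load() on its own only returns the dictionary of weights; it is load_state_dict() on a freshly built model that turns it back into something you can run.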

@muellerzr I am really sorry, I am new to AI. Thank you for bearing with me. In that code, is torch.save() creating a torch model? Will I need to retrain it? Also, is the path/path.pth the exported fastai file created by learn.save()?

No problem! Where fastai ends and torch begins can get real confusing at times :)

Every single fastai model you use is actually a torch model! fastai is specifically a training library (more or less) for PyTorch (which is why it’s considered a wrapper library). So it’s already one, and no retraining is needed. torch.save() simply saves the state dictionary (all of the model’s weights, keyed by layer name) to a file, pickled under the hood, for us to load later.

Re: path: no, this is simply wherever you wish to save your model. path can point to any directory, and the part ending in .pth is your filename (make sure the .pth is in there).

Also note that we’re not using learn.save() here. We’re using torch.save() instead to remove the fastai wrapping.

@muellerzr Thanks so much for your amazing help. Sorry to bug you again. If I understand correctly, I have to pass my learner’s state dict to torch.save(), and then it will output a torch PKL. How do I save a PTH from the PyTorch model? Do I just use the standard torch.save()?

MMdnn Docs on PyTorch:

Extract PyTorch pre-trained models

You can refer to the PyTorch model extractor to extract your PyTorch models.

$ mmdownload -f pytorch -h
Support frameworks: ['alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'inception_v3', 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn']

$ mmdownload -f pytorch -n resnet101 -o ./
Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /home/ruzhang/.torch/models/resnet101-5d3b4d8f.pth
███████████████████| 102502400/102502400 [00:06<00:00, 15858546.50it/s]
PyTorch pretrained model is saved as [./imagenet_resnet101.pth].

Convert Pytorch pre-trained models to IR

You can convert the whole PyTorch model to the IR structure. Please remember that, for generality, we currently only take the whole-model pth, not just the state dict. To be more specific, the model is saved using torch.save(), and torch.load() can load the whole model.

$ mmtoir -f pytorch -d resnet101 --inputShape 3,224,224 -n imagenet_resnet101.pth

Please bear in mind to always add the --inputShape argument. This differs from other frameworks because PyTorch is a dynamic framework.

Then you will get

IR network structure is saved as [resnet101.json].
IR network structure is saved as [resnet101.pb].
IR weights are saved as [resnet101.npy].

Convert models from IR to PyTorch code snippet and weights

You can use the following bash command to convert the IR architecture file [inception_v3.pb] and weights file [inception_v3.npy] to a PyTorch Python code file [pytorch_inception_v3.py] and an IR weights file suited to the PyTorch model [pytorch_inception_v3.npy].

Note: We need to transform the IR weights to PyTorch suitable weights. Use argument -dw to specify the output weight file name.

$ mmtocode -f pytorch -n inception_v3.pb --IRWeightPath inception_v3.npy --dstModelPath pytorch_inception_v3.py -dw pytorch_inception_v3.npy

Parse file [inception_v3.pb] with binary format successfully.
Target network code snippet is saved as [pytorch_inception_v3.py].
Target weights are saved as [pytorch_inception_v3.npy].

Generate PyTorch model from code snippet file and weight file

You can use the following bash command to generate the PyTorch model file [pytorch_inception_v3.pth] from the Python code [pytorch_inception_v3.py] and the weights file [pytorch_inception_v3.npy] for further usage.

$ mmtomodel -f pytorch -in pytorch_inception_v3.py -iw pytorch_inception_v3.npy -o pytorch_inception_v3.pth

PyTorch model file is saved as [pytorch_inception_v3.pth], generated by [pytorch_inception_v3.py] and [pytorch_inception_v3.npy]. Notice that you may need [pytorch_inception_v3.py] to load the model back.
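One detail in the docs quoted above is worth spelling out: they ask for the whole model, not just the state dict. In plain PyTorch the difference is only in what gets passed to torch.save(). The sketch below uses a small hypothetical model in place of learn.model, and assumes a PyTorch release recent enough that torch.load() accepts the weights_only argument.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn

# Hypothetical stand-in for learn.model.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
out_dir = Path(tempfile.mkdtemp())

# Weights only: a dictionary of tensors keyed by layer name.
torch.save(model.state_dict(), out_dir / "weights_only.pth")

# Whole model: the pickled nn.Module object, architecture included.
torch.save(model, out_dir / "whole_model.pth")

weights = torch.load(out_dir / "weights_only.pth")

# weights_only=False is needed on recent PyTorch releases to unpickle a full
# module; only do this with files you trust.
whole = torch.load(out_dir / "whole_model.pth", weights_only=False)

print(type(weights).__name__)
print(isinstance(whole, nn.Module))
```

A whole-model file can only be unpickled where the classes it refers to are importable, so a model containing fastai layers would presumably still need fastai installed at load time.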

Yes. It won’t output a pkl unless you specify that as the output name. It should be a pth. pkl means pickle, and it’s what fastai uses for learn.export(). (Simply put the .pth in the filename you pass to torch.save().)
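To make the point about extensions concrete: the suffix is just part of the filename you choose, and torch.save() writes the same data either way. A small sketch, again with a hypothetical stand-in for learn.model:

```python
import tempfile
from pathlib import Path

import torch
from torch import nn

model = nn.Linear(3, 1)  # hypothetical stand-in for learn.model
out_dir = Path(tempfile.mkdtemp())

# Same state dict, two different file extensions.
for name in ("weights.pth", "weights.pkl"):
    torch.save(model.state_dict(), out_dir / name)

from_pth = torch.load(out_dir / "weights.pth")
from_pkl = torch.load(out_dir / "weights.pkl")

# Every tensor matches, whichever extension was used.
same = all(torch.equal(from_pth[k], from_pkl[k]) for k in from_pth)
print(same)
```

Sticking to .pth for files written with torch.save() and .pkl for learn.export() is still a useful convention, since it tells you at a glance which kind of file you are holding.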

@muellerzr Thanks so much for your help! It really means a lot!

@muellerzr Sorry to keep bugging you. Unfortunately, I could not get MMdnn to work. I think that I can use ONNX for what I was trying to do, but I can’t really understand the online instructions, as I am unfamiliar with PyTorch. Could you give some guidance on doing that? Thanks in advance for any help!

Hi @PotatoHeadz35!

I’m at the same place as you!
I’ve just finished my first model and pulled the state_dict with torch.save() to save it as a .pth file.

Now I need to convert it to a TF.js format to be able to load it on a web application.

Did you manage to do this?

Best regards.

Thanks in advance.

Hi!
This is from chapter 2
While exporting a .pkl file, we did

learn.export()
path = Path()
path.ls(file_exts='.pkl')

which gives
(#1) [Path('export.pkl')]
My question is, where is export.pkl located? I cannot find it anywhere!

learn.export(PATHNAME) will save it wherever you would like. Just fill in the PATHNAME
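As for where the file from the chapter 2 snippet ended up: Path() with no arguments is the relative path '.', so path.ls() was listing the current working directory, and that is where export.pkl was found. A quick way to see the absolute location, using only the standard library:

```python
from pathlib import Path

path = Path()          # relative path '.', i.e. the current working directory
print(path.resolve())  # absolute location of that directory

# Absolute path where a file named export.pkl in this directory would live.
export_file = (path / "export.pkl").resolve()
print(export_file)
```

In a notebook, the working directory is usually the folder the notebook was started from, which may not be the folder you are browsing in your file manager.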

Hey everyone, I have exported my model as a .pth file by calling torch.save() on its state_dict. How can I now use this model for single-image inference? Specifically, I want to run it on .scn images from Aperio ImageScope, which were not in the training data (the model was trained on .tiff images). When I try loading these images, they turn out as black and white patches.