Status of TorchScript in FastAI (v2)?

What is the state of TorchScript support in FastAI?

We’re using PyTorch Lightning, and so far it has worked well for longer research projects. However, I want to speed up building PoCs (proofs of concept), and I think FastAI is better suited for that.
Our deployment infrastructure, however, relies on TorchScript (because C++ is also involved, and we only have to hand this file to the deployment team to get the model running).
I know that FastAI’s position is that Julia (or was it Swift?) is better suited for this than constraining Python’s structure to remove the GIL, but that would be too big a project for us to undertake, and I’m not familiar with Julia or Swift.

In PyTorch Lightning, you can get a TorchScript version of the model in one line of code:
torch.jit.save(model.to_torchscript(), "model.pt")
What would I need to do in FastAI?

P.S. I searched this forum, but only a few topics touch on TorchScript, and none discuss it in detail.


Since fastai models are ordinary PyTorch modules (learn.model is an nn.Module), you can export them to TorchScript directly. fastai isn’t PyTorch Lightning, though, and plain PyTorch modules don’t have Lightning’s model.to_torchscript() convenience method, so we need to do it ourselves. Below is a full end-to-end example:

from fastai.vision.all import *
path = untar_data(URLs.PETS)/'images'
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2,
    label_func=lambda x: x[0].isupper(), item_tfms=Resize(224))
learn = cnn_learner(dls, resnet18)  # newer fastai releases rename this to vision_learner

We need a batch of data to trace with, and the model has to be in eval mode:

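x,_ = dls.one_batch()        # x is a fastai TensorImage batch, already on dls.device
learn.model.to(x.device)     # put the model on the same device as the batch (the GPU, if available)
learn.model.eval()           # freeze dropout and batch-norm statistics before tracing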
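Why eval mode matters: tracing records the operations that actually run on the example batch, so a model traced in training mode would typically bake training-time behavior (random dropout, batch-norm batch statistics) into the exported graph. A minimal sketch of the eval-mode case, using a small stand-in model with nn.Dropout rather than the fastai ResNet:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in model whose output differs between train and eval mode because of dropout
model = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.5), nn.Linear(16, 1))
x = torch.randn(8, 16)

model.eval()  # dropout becomes a no-op, so the recorded graph is deterministic
traced = torch.jit.trace(model, x)

# On a fresh input, the traced module should reproduce the eager eval-mode output
x_new = torch.randn(8, 16)
with torch.no_grad():
    same = torch.allclose(traced(x_new), model(x_new))
print("traced matches eager eval-mode output:", same)
```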

Finally, at the time of writing, fastai changes some of PyTorch’s tensor behavior (through its TensorBase subclass) so that tensor subclasses work better. As a result, fastai has its own version of requires_grad_, and for it to work with the JIT tracer it needs to accept a requires_grad argument. Below is the monkey-patch that adds it:

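# fastcore's @patch (re-exported by fastai) attaches this method to TensorBase,
# the class named in the `self` annotation, replacing fastai's version
@patch
def requires_grad_(self:TensorBase, requires_grad=True):
    self.requires_grad = requires_grad
    return self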
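An alternative to the monkey-patch that may be worth trying (not verified against every fastai version, so treat it as an assumption): strip the fastai subclass from the example batch before tracing, e.g. x = x.as_subclass(torch.Tensor), so the tracer only sees a plain torch.Tensor. The sketch below shows the mechanism with a hypothetical TaggedTensor subclass standing in for fastai's TensorImage:

```python
import torch
import torch.nn as nn


class TaggedTensor(torch.Tensor):
    """Hypothetical stand-in for a tensor subclass such as fastai's TensorImage."""
    pass


# Small stand-in model, not the fastai ResNet
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.AdaptiveAvgPool2d(1), nn.Flatten())
model.eval()

x = torch.randn(2, 3, 8, 8).as_subclass(TaggedTensor)
# as_subclass returns a view of the same data typed as a plain torch.Tensor
x_plain = x.as_subclass(torch.Tensor)
traced = torch.jit.trace(model, x_plain)
print(type(x).__name__, "->", type(x_plain).__name__)
```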

Now you can trace the model with torch.jit.trace and save the result with torch.jit.save:

torch.jit.save(torch.jit.trace(learn.model, x), 'model.pt')
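To make the export reusable and sanity-check the result, the trace-and-save step can be wrapped in a small helper and the saved file loaded back with torch.jit.load. A minimal sketch, assuming a stand-in nn.Sequential in place of learn.model (no fastai needed to run it); export_torchscript is a hypothetical name, not a fastai or PyTorch API, and unlike Lightning’s to_torchscript (which can also script the model) it only traces:

```python
import os
import tempfile

import torch
import torch.nn as nn


def export_torchscript(model: nn.Module, example_input: torch.Tensor, path: str) -> None:
    """Hypothetical helper: trace `model` on `example_input` and save it to `path`.

    Tracing only, so data-dependent control flow in `model` is not captured.
    """
    model.eval()  # freeze dropout / batch-norm behavior before tracing
    with torch.no_grad():  # no autograd graph is needed for export
        traced = torch.jit.trace(model, example_input)
    torch.jit.save(traced, path)


# Stand-in for learn.model and a batch from dls.one_batch()
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU(),
                      nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2))
x = torch.randn(4, 3, 32, 32)

path = os.path.join(tempfile.mkdtemp(), "model.pt")
export_torchscript(model, x, path)

# Reload the file, as the deployment side would, and compare with the eager model
loaded = torch.jit.load(path, map_location="cpu")
with torch.no_grad():
    max_diff = (loaded(x) - model(x)).abs().max().item()
print(f"max abs difference between eager and TorchScript: {max_diff:.2e}")
```

On the C++ side, LibTorch’s torch::jit::load reads the same file, so a check like this in Python is a cheap way to catch export problems before handing model.pt to the deployment team.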