Deployment Platform: AWS Lambda

Hi all,

I have created a new deployment guide showing how to take your trained fastai model and deploy to production using AWS Lambda + Amazon API GW.

Details outlined here:

Would love to get your feedback!


This seems like a really interesting way to handle inference with fastai. I’m looking forward to hearing how folks go if they try this!


Hey Matt, I am working through the tutorial now. I had to install a few extra things, such as SAM, and set up my credentials on my AWS EC2 instance for boto3 to work. Would you like to work together on a serverless framework version? That should make this faster and easier. Feel free to email me:

Couldn’t get it working due to SAM not installing properly, not sure why.


Has anyone created a fastai layer to use in Lambda yet?

Edit: I tried doing this and I couldn’t get the zip file down to the required 50MB size. And unzipped it was going to be over the 250MB limit anyway.
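For anyone hitting the same wall, one quick way to see how far over the limits a layer archive is: compare the zip's on-disk size against the 50MB limit and the sum of its entries' uncompressed sizes against the 250MB unzipped limit. This is a minimal standard-library sketch; the function name `check_layer_zip` is my own, and treating MB as 2^20 bytes is an assumption.

```python
import os
import zipfile

# Limits quoted in this post, assuming MB = 1024 * 1024 bytes.
ZIPPED_LIMIT = 50 * 1024 * 1024
UNZIPPED_LIMIT = 250 * 1024 * 1024


def check_layer_zip(path):
    """Return (zipped_bytes, unzipped_bytes, fits) for a layer archive."""
    zipped = os.path.getsize(path)
    with zipfile.ZipFile(path) as zf:
        # ZipInfo.file_size is the uncompressed size of each entry.
        unzipped = sum(info.file_size for info in zf.infolist())
    fits = zipped <= ZIPPED_LIMIT and unzipped <= UNZIPPED_LIMIT
    return zipped, unzipped, fits
```

Calling `check_layer_zip("layer.zip")` on the built archive shows which of the two limits is the blocker before trying an upload.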

Found this article that describes splitting things up into multiple lambda functions but that seems like a bit more than I’m willing to put up with.

Hi @yeldarb. The layer contains a zipped file (see here) that, when the Lambda runtime context is created, extracts the zipped PyTorch packages into the /tmp folder, because unzipped they are larger than the hard limit of 250MB. You will need to import this script in your Lambda code.
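The extract-to-/tmp pattern described above can be sketched as follows. This is not the code from the linked layer; it is a minimal stand-in assuming the packages ship as a single zip inside the layer (the path `/opt/requirements.zip` and the function name `unzip_requirements` are hypothetical). Layer contents are mounted under `/opt`; the sketch extracts once per runtime context and puts the extracted folder on `sys.path` so later imports resolve from there.

```python
import os
import sys
import zipfile


def unzip_requirements(zip_path="/opt/requirements.zip",
                       target_dir="/tmp/requirements"):
    """Extract zipped packages into /tmp once and make them importable.

    Both default paths are hypothetical; adjust them to your layer layout.
    """
    if not os.path.isdir(target_dir):
        # Only extract on a cold start; warm invocations reuse /tmp.
        with zipfile.ZipFile(zip_path) as zf:
            zf.extractall(target_dir)
    if target_dir not in sys.path:
        # Prepend so the extracted packages take precedence.
        sys.path.insert(0, target_dir)
    return target_dir
```

In the handler module, `unzip_requirements()` would be called at the top, before `import torch`, so the import finds the extracted packages.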

If you want to build your own Lambda layer, check out this GitHub project for instructions.


Nice, that’s an interesting strategy.

For language models I need to be able to use the same tokenizer and numericalizer to preprocess data before feeding it to the model. So PyTorch alone isn't enough, unfortunately.

That’s why I was trying to add fastai (and spacy) into the layer. I may give it another go at some point with your zip trick. But for now I got it deployed as a Google Cloud Function.

@matt.mcclean Is it possible to define two event types in the SAM YAML file, one for an S3 trigger and one for an API trigger?

I got this working for images using the deployment walkthrough @matt.mcclean describes and it worked like a charm! Anyhow, I ran into a problem when trying to do the exact same thing with a tabular learner.

trace_input = torch.ones(1,3,224,224).cuda()
jit_model = torch.jit.trace(learn.model.float(), trace_input)

My tabular learner's input is simply three float values, and it returns one float value.
So I tried changing this line:

trace_input = torch.ones(1,3,224,224).cuda()

to this:

trace_input = torch.ones(1,3).cuda()

and got this back:

TypeError                                 Traceback (most recent call last)
<ipython-input-27-cb886bac3aea> in <module>()
      1 trace_input = torch.ones(1,3).cuda()
----> 2 jit_model = torch.jit.trace(learn.model.float(), trace_input)
      3 model_file='resnet50_view_classification_83_classes_jit.pth'
      4 output_path = str(path/f'models/{model_file}')
      5, output_path)

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/jit/ in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, _force_outplace)
    634     var_lookup_fn = _create_interpreter_name_lookup_fn(0)
    635     module._create_method_from_trace('forward', func, example_inputs,
--> 636                                      var_lookup_fn, _force_outplace)
    638     # Check the trace against new traces created from user-specified inputs

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/ in __call__(self, *input, **kwargs)
    485             hook(self, input)
    486         if torch._C._get_tracing_state():
--> 487             result = self._slow_forward(*input, **kwargs)
    488         else:
    489             result = self.forward(*input, **kwargs)

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/ in _slow_forward(self, *input, **kwargs)
    475         tracing_state._traced_module_stack.append(self)
    476         try:
--> 477             result = self.forward(*input, **kwargs)
    478         finally:
    479             tracing_state.pop_scope()

TypeError: forward() missing 1 required positional argument: 'x_cont'
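The `TypeError` says the model's `forward` expects two tensors, not one. As far as I can tell, the fastai v1 tabular model's `forward` takes `x_cat` (categorical codes) and `x_cont` (continuous values), and `torch.jit.trace` accepts a tuple of example inputs for a multi-argument `forward`. The sketch below uses a hypothetical stand-in module (`ToyTabular`) rather than the real fastai class, and assumes three continuous features and no categorical ones, so `x_cat` is an empty `LongTensor`. It traces on CPU; add `.cuda()` to the model and both inputs if you trace on GPU as in the original snippet. Whether the real fastai model traces cleanly this way is untested here.

```python
import torch
import torch.nn as nn


class ToyTabular(nn.Module):
    """Hypothetical stand-in with the same two-argument forward signature."""

    def __init__(self, n_cont=3):
        super().__init__()
        self.bn = nn.BatchNorm1d(n_cont)
        self.layers = nn.Sequential(nn.Linear(n_cont, 8), nn.ReLU(), nn.Linear(8, 1))

    def forward(self, x_cat, x_cont):
        # x_cat is ignored because this toy has no categorical columns.
        return self.layers(self.bn(x_cont))


model = ToyTabular().eval()                   # eval(): BatchNorm uses running stats
x_cat = torch.zeros(1, 0, dtype=torch.long)   # no categorical features
x_cont = torch.ones(1, 3)                     # three continuous features

# One tuple entry per forward() argument, in the same order.
jit_model = torch.jit.trace(model.float(), (x_cat, x_cont))
```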

I hope to share all the cool stuff I was able to do with medical data, and fastai deployed on AWS lambdas soon!



This seems to be a problem with tracing your model rather than something Lambda-specific, as I assume you are running this code in your notebook to save the model weights using the PyTorch JIT tracing feature. You could try reaching out on the PyTorch forums to see if someone there can help, as I don't have a lot of experience with the PyTorch JIT tracing features.

Do you mean to process an image? You typically have a slightly different payload between an S3 event and an API trigger if you are submitting the raw image file. For APIGW you need to return the HTTP status code (e.g. 200) on the Lambda return payload. You could always have 2 separate Lambda functions for each event source (S3 or APIGW) but use a common library or script. This script/library could be bundled in a Lambda layer so you get code reuse.
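A minimal sketch of the shared-code idea, written as a single handler that inspects the event shape (two separate functions importing the same `predict` from a layer would work the same way). It assumes S3 notifications arrive with a `Records` list whose entries carry an `s3` key, and that API Gateway uses a proxy integration that passes a `body` (base64-encoded for binary uploads) and expects a `statusCode` in the response. `predict`, `fetch_from_s3`, and the `fetch_object` parameter are hypothetical names; `predict` is a placeholder for the real inference code.

```python
import base64
import json
from urllib.parse import unquote_plus


def predict(image_bytes):
    """Hypothetical placeholder for the shared inference code."""
    return {"class": "placeholder", "num_bytes": len(image_bytes)}


def fetch_from_s3(bucket, key):
    import boto3  # imported lazily so the sketch runs without boto3 installed
    return boto3.client("s3").get_object(Bucket=bucket, Key=key)["Body"].read()


def handle_s3(event, fetch_object):
    results = []
    for record in event["Records"]:
        bucket = record["s3"]["bucket"]["name"]
        # Object keys in S3 notifications are URL-encoded (spaces become '+').
        key = unquote_plus(record["s3"]["object"]["key"])
        results.append(predict(fetch_object(bucket, key)))
    return results


def handle_api(event):
    body = event.get("body") or ""
    data = base64.b64decode(body) if event.get("isBase64Encoded") else body.encode("utf-8")
    # API Gateway needs the HTTP status code in the returned payload.
    return {"statusCode": 200, "body": json.dumps(predict(data))}


def lambda_handler(event, context, fetch_object=fetch_from_s3):
    records = event.get("Records") or []
    if records and "s3" in records[0]:
        return handle_s3(event, fetch_object)
    return handle_api(event)
```

With this layout, the SAM template can attach both event sources to the same function, and the branching lives in code rather than configuration.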


No, the image stuff works like a charm. It's the tabular learner model that is failing to JIT.

BTW, the example code fails when you try to run it on a bunch of images simultaneously, for example from a Step Functions workflow. The memory fills up quickly and it throws an error. I fixed that by checking whether the model already exists locally and not downloading it if it does. There appears to be at least one more bug related to simultaneous usage that I am tracking down now. How should I submit my fixes? Great work, by the way!
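The fix described above (download only when the model is not already on local disk) might look something like this. `MODEL_PATH`, `get_model`, and the injected `download_model`/`load_model` callables are hypothetical names, and the loader is passed in so the sketch runs without S3 or PyTorch. Caching the loaded model in a module-level variable also lets warm invocations skip both the download and the load.

```python
import os

MODEL_PATH = "/tmp/model.pth"   # hypothetical local path in Lambda's /tmp
_model = None                   # module-level cache survives warm invocations


def get_model(download_model, load_model, path=MODEL_PATH):
    """Return the cached model, downloading only if the file is missing."""
    global _model
    if _model is None:
        if not os.path.exists(path):
            download_model(path)      # e.g. copy the model file from S3 to `path`
        _model = load_model(path)     # e.g. torch.jit.load(path)
    return _model
```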

Yes, I meant to process the image. I wanted to process images in bulk through an upload event to S3 and also wanted to trigger the same lambda function using an API call for some demo purposes.

I am processing a couple of hundred images simultaneously through an S3 trigger and it works fine for me. Are you sure the memory error is due to the multiple model downloads? I believe multiple downloads would cause a storage error, provided the downloaded files are not being overwritten or removed. But the download should overwrite the existing files, and this script also removes the tar files after extracting the model from them. I think checking whether the file already exists should speed up the processing to some extent.

As Matt suggested, you should reach out on the PyTorch forums. I tried JIT with UNET earlier and it didn't work with UNET either, because UNET uses hooks and, at that time, models with hooks couldn't be traced.

Why doesn't the SAM application allow using the normal PyTorch format in addition to the TorchScript format?

Quote from your tutorial:
The SAM application expects a PyTorch model in TorchScript format to be saved to S3 along with a classes text file with the output class names.

TorchScript should run faster than normal PyTorch, so it is better for model inference. There is no reason why you couldn't run it in PyTorch format, however.
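To make the trade-off concrete: a TorchScript file carries the model's code with it, so the Lambda function can load it with `torch.jit.load` without importing fastai or the model class, whereas the regular PyTorch format (a `state_dict` saved with `torch.save`) needs the model class definition available at load time. A minimal sketch with a hypothetical `Net` class, using in-memory buffers instead of S3:

```python
import io

import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical model; the plain-PyTorch path needs this class at load time."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


model = Net().eval()
example = torch.randn(1, 4)

# Option 1: TorchScript. The saved file contains the traced graph and weights.
ts_buf = io.BytesIO()
torch.jit.save(torch.jit.trace(model, example), ts_buf)
ts_buf.seek(0)
ts_model = torch.jit.load(ts_buf)              # no Net class needed here

# Option 2: regular PyTorch. Only the weights are saved.
sd_buf = io.BytesIO()
torch.save(model.state_dict(), sd_buf)
sd_buf.seek(0)
plain_model = Net()                            # class definition required
plain_model.load_state_dict(torch.load(sd_buf))
plain_model.eval()
```

For a fastai model, option 2 would mean shipping the fastai model-building code (and its dependencies) inside the Lambda package, which is exactly what the size limits discussed earlier make hard.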


Here is a snippet of text I found in the following book that seems to provide the answer:

PyTorch Deep Learning Hands-On
By Sherin Thomas, Sudhanshu Passi
April 2019

PyTorch allows you to make a TorchScript IR through two methods. The easiest is by tracing, just like ONNX. You can pass the model (even a function) to torch.jit.trace with a dummy input. PyTorch runs the dummy input through the model/function and traces the operations while it runs the input.

The traced functions (PyTorch operations) then can be converted to the optimized IR, which is also called a static single assignment IR. Like an ONNX graph, instructions in this graph also have primitive operators that A TENsor library (ATen, the backend of PyTorch) would understand.

This is really easy but comes with a cost. Tracing-based inference has the basic problem ONNX had: it can’t handle the model structure changes that are dependent on the data, that is, an if / else condition check or a loop (sequence data). For handling such cases, PyTorch introduced scripting mode.

Scripting mode can be enabled by using the torch.jit.script decorator for normal functions and torch.jit.script_method for methods on the PyTorch model. By this decorator, the content inside a function/method will be directly converted to TorchScript. Another important thing to remember while using torch.jit.script_method for model classes is about the parent class. Normally, we inherit from torch.nn.Module, but for making TorchScript, we inherit from torch.jit.ScriptModule. This helps PyTorch to avoid using pure Python methods, which can't be converted to TorchScript. Right now, TorchScript doesn't support all Python features, but it has all the necessary features to support data-dependent tensor operations.
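The tracing limitation described in the quote is easy to reproduce. In this minimal sketch (the function name `double_if_positive` is mine), tracing records only the branch taken by the dummy input, while `torch.jit.script` compiles the `if` itself. Note that in more recent PyTorch releases `torch.jit.script` can be applied directly to functions and `nn.Module` instances; the `torch.jit.ScriptModule` / `script_method` style described in the book is the older API.

```python
import warnings

import torch


def double_if_positive(x):
    # Data-dependent control flow: the branch depends on the values in x.
    if x.sum() > 0:
        return x * 2
    return -x


with warnings.catch_warnings():
    # Tracing warns that the tensor-to-bool conversion in the `if` is not recorded.
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(double_if_positive, torch.ones(3))

scripted = torch.jit.script(double_if_positive)

neg = -torch.ones(3)
print(traced(neg))    # replays the branch recorded while tracing (x * 2)
print(scripted(neg))  # evaluates the `if` at run time and returns -x
```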


Thanks for this Bob & Matt.

I’m also running into JIT snags trying to deploy the IMDB text-classification example from Lesson 3 – because the conversion doesn’t support parameter sharing between modules.

Would welcome anyone’s experience and/or examples successfully deploying a fastai text-classification model to lambda.

We spent some time (about a week) trying to deploy a text classification model to Lambda and, at least as of a month or two ago, we had determined it wasn't supported. During that same experimentation phase we had no problem getting an image classification model running in Lambda.


Dear @matt.mcclean,

It seems like the base container from your tutorial is no longer available. Could you please share how it was built so that others can follow your steps?

MacBook-Pro-2:aws artem$ sam local invoke PyTorchFunction -n env.json -e event.json
2019-06-17 02:51:07 Invoking app.lambda_handler (python3.6)
2019-06-17 02:51:07 Found credentials in shared credentials file: ~/.aws/credentials
2019-06-17 02:51:07 Found credentials in shared credentials file: ~/.aws/credentials
Downloading arn:aws:lambda:eu-west-1:934676248949:layer:pytorchv1-py36  [####################################]  108042875/108042875
2019-06-17 02:52:11 Image was not found.
2019-06-17 02:52:11 Building image...
2019-06-17 02:52:34 Requested to skip pulling images ...

2019-06-17 02:52:34 Mounting /Users/artem/Downloads/aws/pytorch as /var/task:ro,delegated inside runtime container
2019-06-17 02:53:05 Function 'PyTorchFunction' timed out after 30 seconds