How to extract image transformation code from Learner object?

Hi there,

I am working on productionizing a computer vision model that was trained using FastAI. I was sent a saved model file which I am able to successfully load on my machine using load_learner('filename'). I was able to test the model by predicting on a few images with learner.predict('my_image_filename.jpg'), and found that the predictions matched what I expected.

In our production environment, I will need to perform model inference on a single image, but it is a requirement that the inference uses a pure PyTorch forward pass (the system would require significant rearchitecting to accommodate a fastai Learner.predict inference).

I need to understand how Learner.predict loads and processes a single image so that I can replicate this in pure PyTorch. Is there any way to extract that portion of the code so that I can run it to load/process/transform my images into a tensor that can then be sent through the forward method of learner.model? I’ve been reading through the docs and forums to try and find this, but I haven’t been able to find where that specific processing code exists, though it must exist since it has to be executed for the predict function to work.

I have tried to reverse engineer the methods by reading the docs and the forums and to re-create them in pure PyTorch. I’ve got somewhat close to results that match Learner.predict, but they are still far enough off that model performance would be greatly degraded if I used this in production.

Can anyone point me to any attributes/methods that might help uncover what the predict method is doing to load/process images?

For some more specific context, the dataloader that was used to train the model was defined as

dls = ImageDataLoaders.from_df(train_df_f, 

I’ve attempted to recreate this in pure PyTorch using several different methods, and here is the one that got the most similar output to Learner.predict, despite still being significantly different on many test images:

import torch
import torchvision.transforms as T
from PIL import Image

# Load image
input_image = Image.open("my_image.jpg").convert("RGB")

# Convert to a uint8 torch.Tensor with values in [0, 255]
converter = T.PILToTensor()
tensor_image = converter(input_image)

# Scale to [0, 1], then resize and normalize with ImageNet statistics
tensor_image = (tensor_image.float() / 255.0)
resize = T.Resize(224)
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
inputs = normalize(resize(tensor_image)).unsqueeze(0)

# `learner` is the object returned by load_learner above
model = learner.model
with torch.no_grad():
    predictions = model(inputs)
predicted_probabilities = predictions.softmax(1)

Using this PyTorch code gives directionally correct predictions, but with much less confidence than the fastai predict method.

Please let me know any pointers, really hoping to find a way to get to parity without rewriting my ML inference library to accommodate fastai for this single use case.

I do a full lesson on exactly this in my course (Walk with fastai revisited; please consider buying it, as it explains the code in more depth and the source itself has no explanations in it :hugs:), but you can see the source here, which will get the job done for you. It fully recreates what fastai does during deployment without using the framework:


Hi @muellerzr , thanks so much for sharing the notebook! I see that it shows how you can map the image transformations 1:1 from fastai to torchvision exactly, which is indeed what I am eventually looking for, and I look forward to using this as a reference. However, I am mostly wondering how I can discover which transformations the Learner that I have loaded is actually applying to the images. Even if I know how to implement them exactly using torchvision, I will still need to know exactly what transforms Learner is applying, and in what order, and this is what I am struggling to find / parse from fastai source code.

Could you advise on that front? For example, from my experiments, I have the feeling that at some point, the fastai learner must be applying normalization. I have not been able to find an attribute that shows that normalization process, and the parameters that were used. Same with resize/crop, I’m still not sure what parameters it’s using to do that.

I think it’s possible the subtleties (eg hyperparam choices) of the normalization and resize algorithm are leading to the differences I’m seeing in my torch and fastai models, and I just want to find out how to get the normalize and resize objects from my fastai Learner so I can inspect the parameters they’re using and also figure out what order to apply them etc. I know this should be possible since under the hood these operations with specific parameters are definitely being applied in Learner.

Again, the course talks about this in depth as it’s a complicated question. If possible would recommend that route :slight_smile:

Thanks for the suggestion. I appreciate the link to your course, which appears very relevant and comprehensive, but I do not think it would be very efficient for me to become a fastai user at this point in time (it would be quicker for me to retrain the model entirely using a more PyTorch developer-friendly framework like lightning, or just pure pytorch, than to become an intermediate-level user of fastai). If I decide in the future to learn more about the framework in depth, I will gladly look to your course.

I am just looking for a way to expose the bespoke image transformations that the Learner is applying to my images. This must be just a few simple object attribute accesses, right? I have been looking through the source code as well, but I’m sure you can understand that as someone new to the framework, it’s not particularly straightforward to find where the bespoke transformations are happening.

Maybe it would help to simplify my question – I want to pass one image (saved to my machine as my_image.jpg) through a TfmdDL object (in this case, my Learner.dls.valid) and get back the transformed version of the image. How can I do this?

I’ve been able to find these transforms implemented on my TfmdDL, but still no normalization. Where would I be able to find the normalization? @muellerzr

Check it again after you’ve used vision_learner. If you used that, it will add Normalize.from_stats(*imagenet_stats)

(see learner.dls.after_batch)
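To make that inspection concrete, here is a rough sketch that walks the transform pipelines on a DataLoader and lists each transform in order together with a few of its parameters. The helper describe_transforms is hypothetical, and the attribute names it probes (size, method, pad_mode, div, mean, std) are assumptions about what the transforms store; anything a transform does not have is simply skipped.

```python
def describe_transforms(dl, attrs=("size", "method", "pad_mode", "div", "mean", "std")):
    """List the transforms on a fastai DataLoader, stage by stage, in order.

    `dl` is expected to look like learner.dls.valid: it should expose
    after_item, before_batch and after_batch pipelines, each holding its
    transforms in an `fs` attribute. Missing stages yield an empty list.
    """
    report = {}
    for stage in ("after_item", "before_batch", "after_batch"):
        pipeline = getattr(dl, stage, None)
        transforms = getattr(pipeline, "fs", None) or []
        report[stage] = [
            # Keep only the probed attributes this transform actually has.
            (type(t).__name__, {a: getattr(t, a) for a in attrs if hasattr(t, a)})
            for t in transforms
        ]
    return report


# Usage (requires a loaded fastai Learner):
#   for stage, tfms in describe_transforms(learner.dls.valid).items():
#       print(stage, tfms)
```

The stages are listed in the order they run: after_item on each image, then before_batch, then after_batch on the collated batch, which is where Normalize should appear.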