Need access to `__main__` to load model

Despite there being lots and lots of threads on this issue, I still can’t figure it out.

I am attempting to push my model to a production service called UbiOps, which, like AWS Lambda, asks for a .zip archive containing certain requirements. The only objective on UbiOps is to run `model.predict` and nothing else: I give them a Deployment class, and they import it into their driver and build the project. The issue is that, under the current way I save and load my model, the load step needs a custom function called `get_y_fn` that was used in training and isn’t even relevant to running `model.predict`. I have tried a whole assortment of hacks to get that function instantiated inside the main context, but can’t do it. So now I am looking for ways to do surgery on the model pickle file: I am looking at state dictionaries and loading directly with torch. I still want some of the information saved with my model, such as the transformation spec, just not `get_y_fn`.

Should I create an empty data loader with the same shape as my training data loader? Can I somehow just remove some but not all parts of the model file?
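For anyone hitting the same wall, the root cause is easy to reproduce with nothing but the standard library: pickle stores plain functions by reference (module plus qualified name), not by value, so the unpickling process must have the same name defined. A minimal sketch, with a stand-in `get_y_fn`:

```python
import pickle

# Minimal reproduction of the "Can't get attribute 'get_y_fn' on
# <module '__main__'>" failure: pickle stores functions by *reference*
# (module + qualified name), so unpickling needs the same name defined.
def get_y_fn(x):          # stand-in for the custom label function
    return x

blob = pickle.dumps({'label_func': get_y_fn})

del get_y_fn              # simulate loading in a fresh process without it

try:
    pickle.loads(blob)
    failed = False
except AttributeError:    # "Can't get attribute 'get_y_fn' ..."
    failed = True

print(failed)  # True: the pickle is unloadable without get_y_fn
```

This is exactly why the Deployment-only environment chokes: the model pickle references a name that only ever existed in the training script’s `__main__`.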

Please help!

Thanks in advance :slight_smile:

OK, so here’s what I ended up doing, which enabled me to port my model (saved with the training data loader and a custom `label_func`) to a model that can be loaded without my custom `label_func` defined in `__main__`.

I created a test data loader that consisted of a single image and its segmentation label file. I made sure to name the files so that a simple mapping could be applied using a Python lambda function for `label_func`.

  • proddls
    • label.tif
    • base.png
dls_prod = SegmentationDataLoaders.from_label_func(
    path='proddls',
    fnames=[Path('proddls/base.png')],
    label_func=lambda x: str(x).replace('base.png', 'label.tif'),
    # remaining arguments (codes, item_tfms, batch_tfms, ...) matched my training dataloader
)
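The reason `dill` (mentioned below) is needed at all: the standard pickle module refuses to serialize a lambda, while dill serializes it by value, so the saved object can be loaded in a process that never defined the function. A quick sketch of just that behaviour:

```python
import dill

# The same mapping used for the prod dataloader above.
label_func = lambda x: str(x).replace('base.png', 'label.tif')

# pickle.dumps(label_func) would raise a PicklingError;
# dill serializes the lambda by value instead of by name.
blob = dill.dumps(label_func)

restored = dill.loads(blob)
print(restored('proddls/base.png'))  # proddls/label.tif
```

The same trick is what the `pickle_module=dill` arguments to `torch.load` and `export` below rely on.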

I then created a new learner with the same architecture as my existing model, but using the prod data loader. Note: I had to use `dill`, a drop-in replacement for the pickle module, to enable me to save the lambda function in the pickle.

import os
import dill
import torch
from fastai.vision.all import *

# 'dls_prod.pkl' is the file the prod dataloader was saved to (name assumed)
dls = torch.load('dls_prod.pkl', pickle_module=dill)
learn = unet_learner(dls, resnet34)
learn.path = Path(os.getcwd())

Finally, I loaded my old model and loaded its state_dict into the prod model. (The old model was persisted using `export`, not `save`.)

old_learn = load_learner('old_model', with_opt=False)
learn.model.load_state_dict(old_learn.model.state_dict())
learn.export(fname='model1.pkl', pickle_module=dill)
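For reference, the state_dict transplant itself only requires that the two models share an architecture; how each learner was constructed doesn’t matter. A toy sketch with plain torch (stand-in modules, not the actual unet):

```python
import torch
import torch.nn as nn

# Two independently initialised modules with identical architecture.
old = nn.Linear(3, 1)
new = nn.Linear(3, 1)

# Copy every parameter tensor from old into new.
new.load_state_dict(old.state_dict())

x = torch.ones(1, 3)
print(torch.equal(old(x), new(x)))  # True: identical weights, identical output
```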

Then I was able to load 'model1.pkl' using `load_learner` with `pickle_module=dill`, and it worked.

However, it’s about 5x slower to execute a prediction than the old model was. Does anyone know why the new model would be slower than the original? Is it something in the procedure I followed?
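Not an answer, but one thing worth checking first (an assumption on my part, not a diagnosis): `load_learner` defaults to `cpu=True`, so the re-exported model may be running predictions on CPU while the original ran on GPU, which can easily account for a several-fold slowdown. A quick way to see which device a model’s parameters live on, with a toy module standing in for `learn.model`:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for learn.model
device = next(model.parameters()).device
print(device)  # a freshly constructed module lives on the CPU
```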