Error loading saved model with custom loss function

My model has a custom loss function as nn.Module subclass. When running learn.export() I first get the warning

/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py:250: UserWarning: Couldn't retrieve source code for container of type FocalLoss. It won't be checked for correctness upon loading.
"type " + obj.__name__ + ". It won't be checked "

then when invoking load_learner I get

AttributeError: Can't get attribute 'FocalLoss' on <module '__main__' from 'app/server.py'>

Apart from redefining all the custom methods and classes on the production server, any thoughts on how to save and load everything with one action?

Normally it should work if you redefine your custom methods and classes (note that pickle wants to find them in __main__). My advice would be to package everything you need in a module.py file in your working directory; that way pickle will know how to handle it better.
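To see why pickle insists on __main__: it does not store a class's code, only the module path and class name it was found under at save time (torch.save, which learn.export uses, pickles the learner the same way). A minimal sketch with plain pickle, assuming a plain FocalLoss class instead of an nn.Module and a hypothetical module name training_script standing in for the exporting script:

```python
import pickle
import sys
import types

# Hypothetical module standing in for the script the model was exported from
training_script = types.ModuleType("training_script")
sys.modules["training_script"] = training_script


class FocalLoss:
    """Plain stand-in for the custom nn.Module loss from the question."""

    def __init__(self, gamma=2.0):
        self.gamma = gamma


# Pretend FocalLoss was defined in training_script when the model was exported
FocalLoss.__module__ = "training_script"
training_script.FocalLoss = FocalLoss

payload = pickle.dumps(FocalLoss(gamma=2.0))
# Only the lookup path is stored, never the class body
print(b"training_script" in payload, b"FocalLoss" in payload)  # True True

# Serving process: the recorded module exists but does not define FocalLoss
del training_script.FocalLoss
error_message = ""
try:
    pickle.loads(payload)
except AttributeError as err:
    error_message = str(err)
print(error_message)

# Making the class reachable under the same path again is enough to load it
training_script.FocalLoss = FocalLoss
restored = pickle.loads(payload)
print(type(restored).__name__, restored.gamma)  # FocalLoss 2.0
```

When the class was defined in the training script itself, the recorded module is __main__, so whatever script is __main__ at load time (app/server.py here) must provide it.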

Thanks, @sgugger. I first tried to define some dummy classes in the main module and failed, but then I moved them into a separate module and made an explicit import and it worked.

@sgugger a follow up inquiry to this.

I found that if I define my custom functions and load my learner in one script (script2.py), I can't import and run it from another script (script1.py), i.e. I get this error running script1.py:

AttributeError: Can't get attribute 'custom_loss' on <module '__main__' from 'script1.py'>

The workaround I have been using is to always load the learner in the main script I am running. However, I am trying to do something with fastai and Celery, and the way Celery works seems to make that impossible.

Am I missing something about how to define my custom functions? Why does the learner need to be loaded in the main script, and is there a workaround?

You should define all your custom layers/loss functions in a separate module that you import. That way you can load your learner easily.
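A minimal sketch of that layout, assuming a hypothetical custom_losses.py holding a plain FocalLoss class (it is written to a temporary directory here only so the example runs on its own). Because the class lives in an importable module, pickle records custom_losses.FocalLoss rather than __main__.FocalLoss, and pickle imports that module by name when loading, so the loading code does not have to be the __main__ script:

```python
import importlib
import pickle
import sys
import tempfile
import textwrap
from pathlib import Path

# Write a hypothetical custom_losses.py into a temporary directory on sys.path
module_dir = Path(tempfile.mkdtemp())
(module_dir / "custom_losses.py").write_text(textwrap.dedent("""
    class FocalLoss:
        def __init__(self, gamma=2.0):
            self.gamma = gamma
"""))
sys.path.insert(0, str(module_dir))
importlib.invalidate_caches()

# Training side: use the loss from the module instead of defining it in __main__
custom_losses = importlib.import_module("custom_losses")
payload = pickle.dumps(custom_losses.FocalLoss(gamma=2.0))
print(custom_losses.FocalLoss.__module__)  # custom_losses

# Serving side (e.g. a Celery worker) that has not imported custom_losses yet
del sys.modules["custom_losses"]
restored = pickle.loads(payload)  # pickle imports custom_losses by name itself
print(type(restored).__module__, restored.gamma)  # custom_losses 2.0
```

In a real project you would keep custom_losses.py next to server.py (or the Celery tasks module) and import it normally. The module path is baked in at export time, so a learner exported while the classes still lived in __main__ has to be exported again after the move.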

@sgugger Yes, I’ve found organizing it that way helps. But I still need to run load_learner in the __main__ script, is that correct? Or am I going crazy? haha

I have found a workaround for my Celery scheduler that involves redefining the Unpickler in the pickle module and then passing it to torch.load in a custom load_learner. This is obviously a lot of overhead if I am missing something simple.

@ecatkins Were you able to find a solution for this? I am getting the same error when I run with a Celery worker:

AttributeError: Can't get attribute 'FastAiRobertaTokenizer' on <module 'celery.bin.celery' from '/usr/local/lib/python3.7/site-packages/celery/bin/celery.py'>

Where should I define or import the class?

@kanz.2890

I used the following to create a patched copy of the pickle module and pass it to torch.load inside a custom load_learner:

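import imp  # deprecated since Python 3.4 and removed in 3.12
import torch
from fastai.torch_core import distrib_barrier

# Load a separate copy of pickle so patching it leaves the global module untouched
pickle2 = imp.load_module('pickle2', *imp.find_module('pickle'))

class CustomUnpickler(pickle2.Unpickler):

    def find_class(self, module, name):
        # Look the name up in this module first (where the custom classes are
        # imported), then fall back to the module recorded in the pickle
        try:
            return super().find_class(__name__, name)
        except AttributeError:
            return super().find_class(module, name)


def load_learner(fname, cpu=True):
    "Load a `Learner` object in `fname`, optionally putting it on the `cpu`"
    distrib_barrier()
    pickle2.Unpickler = CustomUnpickler  # torch.load unpickles via pickle_module.Unpickler
    res = torch.load(fname, map_location='cpu' if cpu else None, pickle_module=pickle2)
    if hasattr(res, 'to_fp32'): res = res.to_fp32()
    if cpu: res.dls.cpu()
    return res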


It seems imp is being deprecated and we are asked to use importlib instead. Is there any way to translate the same code to use importlib?

Have you tried just replacing import imp with import importlib?

We cannot do that since importlib doesn’t have the load_module attribute.

Checking the documentation, apparently exec_module is the new function to use? Have you looked into this?
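For reference, a minimal sketch of the importlib pattern that, as far as I know, comes closest to imp.load_module: find the module's spec, create an empty module from it, and let the spec's loader run exec_module on it. The result is a fresh module object that is not registered in sys.modules; note that its __name__ is still "pickle", not "pickle2".

```python
import importlib.util
import pickle

# importlib counterpart of imp.load_module('pickle2', *imp.find_module('pickle'))
spec = importlib.util.find_spec("pickle")
pickle2 = importlib.util.module_from_spec(spec)  # new, empty module object
spec.loader.exec_module(pickle2)                 # execute pickle's code inside it

print(pickle2 is pickle)                        # False
print(pickle2.loads(pickle2.dumps([1, 2, 3])))  # [1, 2, 3]

# Patching the copy leaves the real pickle module untouched
pickle2.Unpickler = type("PatchedUnpickler", (pickle2.Unpickler,), {})
print(pickle.Unpickler is pickle2.Unpickler)    # False
```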

I tried that but couldn't get it to work, so I came up with the following instead, but it doesn't work either:

pickle2 = importlib.import_module('pickle2', *importlib.util.find_spec('pickle'))

class CustomUnpickler(pickle2.Unpickler):

    def find_class(self, module, name):
        try:
            return super().find_class(__name__, name)
        except AttributeError:
            return super().find_class(module, name)

def load_learner(fname, cpu=True):
    distrib_barrier()
    pickle2.Unpickler = CustomUnpickler
    pickle2 = importlib.import_module('pickle2', *importlib.util.find_spec('pickle'))
    res = torch.load(fname, map_location='cpu', pickle_module=pickle2)
    return res

Could you please help me with this?
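Three things break the attempt above. importlib.import_module(name, package) only imports an existing module by name, so it cannot create a copy called pickle2, and a ModuleSpec cannot be unpacked with *. find_class has to be indented inside CustomUnpickler. And because pickle2 is assigned inside load_learner, Python treats it as a local variable there, so the pickle2.Unpickler line before the assignment raises UnboundLocalError. Below is a sketch of an importlib-based version that avoids all three. To keep it runnable without torch or fastai, it unpickles an in-memory payload; FocalLoss and the old_main module are hypothetical stand-ins for the exported learner and the script it was exported from.

```python
import importlib.util
import io
import pickle
import sys
import types


def private_pickle_copy():
    """Load a separate copy of pickle (importlib version of imp.load_module)."""
    spec = importlib.util.find_spec("pickle")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


pickle2 = private_pickle_copy()


class CustomUnpickler(pickle2.Unpickler):
    def find_class(self, module, name):
        # Resolve names in this module first, then fall back to the recorded module
        try:
            return super().find_class(__name__, name)
        except AttributeError:
            return super().find_class(module, name)


# Patch at module level; assigning pickle2 inside a function would make it local
pickle2.Unpickler = CustomUnpickler


class FocalLoss:
    """Stand-in for a custom loss that this (worker) module imports."""

    def __init__(self, gamma=2.0):
        self.gamma = gamma


# Simulate a file saved while FocalLoss lived in another script's namespace
old_main = types.ModuleType("old_main")  # hypothetical module name
sys.modules["old_main"] = old_main
FocalLoss.__module__ = "old_main"
old_main.FocalLoss = FocalLoss
payload = pickle.dumps(FocalLoss(gamma=2.0))
FocalLoss.__module__ = __name__
del old_main.FocalLoss

# Instantiate pickle2.Unpickler explicitly: pickle2.loads would ignore the patch
restored = pickle2.Unpickler(io.BytesIO(payload)).load()
print(type(restored).__name__, restored.gamma)  # FocalLoss 2.0
```

In a custom load_learner, the torch.load(fname, map_location='cpu' if cpu else None, pickle_module=pickle2) call from the imp version can stay as it is; as far as I can tell, torch.load builds its unpickler as a subclass of pickle_module.Unpickler, which is why patching that attribute is enough.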