Loading custom Databunch fails with attribute error

Hi,

I have created a custom DataBunch and am trying to load it with load_data, but I get an AttributeError:

  File "/home/views.py", line 641, in get
    path, r"/home/data_save.pkl")
  File "/usr/local/lib/python3.7/site-packages/fastai/basic_data.py", line 281, in load_data
    ll = torch.load(source, map_location='cpu') if defaults.device == torch.device('cpu') else torch.load(source)
  File "/usr/local/lib/python3.7/site-packages/torch/serialization.py", line 529, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/usr/local/lib/python3.7/site-packages/torch/serialization.py", line 702, in _legacy_load
    result = unpickler.load()
AttributeError: Can't get attribute 'RobertaTextList' on <module '__main__' from 'manage.py'>

RobertaTextList is defined in the program, but I still get the error.

Maybe I have to define this class, or import it, in the context where I’m loading the DataBunch, but I don’t know how.

This is the code -

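import torch
from fastai.text import *  # fastai v1: Path, load_data, Learner, accuracy

path = Path()
# Load the saved DataBunch
data = load_data(path, r"data_save.pkl")
roberta_model = CustomRobertaModel()
learn = Learner(data, roberta_model, metrics=[accuracy])
# Load the trained weights on CPU
st2 = torch.load(r"final_model_base.pth", map_location=torch.device('cpu'))
# load_state_dict applies the weights; state_dict() only returns the current ones
learn.model.load_state_dict(st2)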

Can anyone help me with this?

Yes, you do. You can make a utils.py that defines or imports everything you need, then just import it wherever you want to use it, e.g. from utils import *
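A minimal sketch of that layout, using only the standard library so it runs without fastai. Everything here is a stand-in: roberta_utils.py plays the role of utils.py, TextListStub plays the role of RobertaTextList, and a plain pickle file plays the role of the saved DataBunch. The point it tries to show is that a fresh interpreter, with a different __main__, can load the file as long as the shared module is importable.

```python
import os
import pickle
import subprocess
import sys
import tempfile
import textwrap

workdir = tempfile.mkdtemp()

# Hypothetical shared module holding the custom class (stand-in for utils.py).
with open(os.path.join(workdir, "roberta_utils.py"), "w") as f:
    f.write(textwrap.dedent('''
        class TextListStub:
            """Placeholder for the real custom class."""
            def __init__(self, items):
                self.items = items
    '''))

# Saving side: import the class from the shared module, then pickle an instance.
sys.path.insert(0, workdir)
from roberta_utils import TextListStub

pkl_path = os.path.join(workdir, "data_save.pkl")
with open(pkl_path, "wb") as f:
    pickle.dump(TextListStub(["a", "b"]), f)

# Loading side: a fresh interpreter whose __main__ is not the saving script.
# It never imports the class explicitly; pickle imports roberta_utils by name.
loader = "import pickle, sys; print(pickle.load(open(sys.argv[1], 'rb')).items)"
env = dict(os.environ, PYTHONPATH=workdir)
result = subprocess.run(
    [sys.executable, "-c", loader, pkl_path],
    env=env, capture_output=True, text=True,
)
loaded_items = result.stdout.strip()
```

Note that the class was imported from the shared module before saving. A file pickled while the class was defined in the running script itself would record that script's module name instead.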

But does it help that I already have this in the same file?

class RobertaTextList(TextList):
    _bunch = RobertaDataBunch
    _label_cls = TextList

How do I call this in the same file?

No, it does not help. This is not a fastai thing but a basic Python thing with the pickle module. When you save your DataBunch or export your Learner, pickle stores only the names of the classes and functions it needs, not their code. So anything custom you use has to live in a separate module (for instance utils.py, but you can name it as you like) that can be imported when you load it back in a different environment.
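The pickle behaviour described above can be reproduced with the standard library alone. In this sketch, saving_script is a hypothetical in-memory module standing in for the script that did the saving, and RobertaTextListStub is a placeholder class, not the real fastai one. It is only meant to illustrate the mechanism under those assumptions.

```python
import pickle
import sys
import types

# Hypothetical stand-in for the module that defined the class at save time.
saver = types.ModuleType("saving_script")
sys.modules["saving_script"] = saver

# Define the placeholder class inside that module's namespace.
source = '''
class RobertaTextListStub:
    def __init__(self, items):
        self.items = items
'''
exec(source, saver.__dict__)

payload = pickle.dumps(saver.RobertaTextListStub(["a", "b"]))

# The stream carries the module name and class name, not the class body.
names_in_payload = b"saving_script" in payload and b"RobertaTextListStub" in payload

# Simulate loading in an environment where that module lacks the class.
saved_class = saver.RobertaTextListStub
del saver.RobertaTextListStub
try:
    pickle.loads(payload)
    load_error = None
except AttributeError as exc:
    load_error = str(exc)

# Once the recorded name resolves again, the same bytes load fine.
saver.RobertaTextListStub = saved_class
restored = pickle.loads(payload)
```

Because the module name is recorded at save time, a file written while the class lived in __main__ will keep looking for it in whatever __main__ happens to be at load time. Re-saving after importing the class from the shared module is likely the cleaner fix.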

@sgugger @muellerzr

I am still getting an error when I run it with a Celery worker -

AttributeError: Can't get attribute 'RobertaTextList' on <module 'celery.bin.celery' from '/usr/local/lib/python3.7/site-packages/celery/bin/celery.py'>

Do I have to import it in the Celery.py file as well?