Loading the learner with custom model - Transformers

I have successfully exported a model built with transformers and fastai.
I know that when loading the model I need to redefine the custom module, like this:

import torch.nn as nn
from transformers import PreTrainedModel

class CustomTransformerModel(nn.Module):
    def __init__(self, transformer_model: PreTrainedModel):
        super().__init__()
        self.transformer = transformer_model

    def forward(self, input_ids):
        # Return only the logits from the transformer
        logits = self.transformer(input_ids)[0]
        return logits

Everything works fine, and I can get predictions from the learner as follows:

learner = load_learner('')
outputs = learner.predict("the movie was good")

But when I try to do the same inside a Flask app, load_learner can't find the CustomTransformerModel class, even though the app defines it.
My Flask code is as follows:

from flask import Flask, render_template, request
from transformers import BertTokenizer
from transformers import PreTrainedModel

from fastai.text import *

app = Flask(__name__)

# Display all things
@app.route('/')
def showMain():
    return render_template('classifier.html')

@app.route('/predict',methods=['POST'])
def handler():

    #defaults.device = torch.device('cpu')

    learner = load_learner('')
    #outputs = learner.predict("i like it ,it is so healthy product")

    review=request.args['text']
    pred_class,pred_idx,outputs = learner.predict(review)
    return render_template('classifier.html', prediction_text='Your Input : {} Prediction : {} '.format(review,outputs))

class FastAiBertTokenizer(BaseTokenizer):
    """Wrapper around BertTokenizer to be compatible with fast.ai"""
    def __init__(self, tokenizer: BertTokenizer, max_seq_len: int=128, **kwargs):
        self._pretrained_tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def __call__(self, *args, **kwargs):
        return self

    def tokenizer(self, t:str) -> List[str]:
        """Limits the maximum sequence length"""
        return ["[CLS]"] + self._pretrained_tokenizer.tokenize(t)[:self.max_seq_len - 2] + ["[SEP]"]

class CustomTransformerModel(nn.Module):
    def __init__(self, transformer_model: PreTrainedModel):
        super(CustomTransformerModel,self).__init__()
        self.transformer = transformer_model

    def forward(self, input_ids):
        # Return only the logits from the transformer
        logits = self.transformer(input_ids)[0]
        return logits


if __name__ == '__main__':
    app.secret_key = 'super_secret_key'
    app.debug = True
    app.run(host='0.0.0.0', port=8000)

I get the following error:

  File "C:\flask-starter-project\app.py", line 50, in handler
    learner = load_learner('')
  File "c:\programdata\anaconda3\envs\fastai_v1\lib\site-packages\fastai\basic_train.py", line 618, in load_learner
    state = torch.load(source, map_location='cpu') if defaults.device == torch.device('cpu') else torch.load(source)
  File "c:\programdata\anaconda3\envs\fastai_v1\lib\site-packages\torch\serialization.py", line 529, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "c:\programdata\anaconda3\envs\fastai_v1\lib\site-packages\torch\serialization.py", line 702, in _legacy_load
    result = unpickler.load()
AttributeError: Can't get attribute 'CustomTransformerModel' on <module '__main__' from 'C:\ProgramData\Anaconda3\envs\fastai_v1\Scripts\flask.exe\__main__.py'>

I think this happens because the class was pickled as __main__.CustomTransformerModel (it was defined in my training notebook), while in the Flask app it is app.CustomTransformerModel. When Flask runs, __main__ is Flask's own launcher script, so load_learner() can't find the class there and throws that error.
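The mechanism behind this traceback can be reproduced with the standard library alone. This is a minimal sketch that assumes only that torch.load unpickles the exported file: a hypothetical in-memory module named notebook_main stands in for the training notebook's __main__, and an empty placeholder class stands in for the real model. Pickle stores only the module name and class name, so once that module no longer defines the class, loading fails with an AttributeError like the one above.

```python
import pickle
import sys
import types

# Hypothetical module "notebook_main" plays the role of the notebook's __main__.
notebook_main = types.ModuleType("notebook_main")
sys.modules["notebook_main"] = notebook_main


class CustomTransformerModel:
    """Placeholder for the real nn.Module; only its import path matters to pickle."""


# Make the class look as if it had been defined inside notebook_main.
CustomTransformerModel.__module__ = "notebook_main"
notebook_main.CustomTransformerModel = CustomTransformerModel

# Exporting stores a reference (module name + class name), not the class's code.
payload = pickle.dumps(CustomTransformerModel())
assert b"notebook_main" in payload and b"CustomTransformerModel" in payload

# Simulate the Flask process: the module pickle looks in has no such class.
del notebook_main.CustomTransformerModel
load_error = None
try:
    pickle.loads(payload)
except AttributeError as err:
    load_error = err
print(repr(load_error))
```

In the Flask case the stored module name is __main__, and under flask.exe that module is the launcher script, which defines no CustomTransformerModel.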

Any solution?

For deployment, you should put all your custom layers, loss functions and similar code in a separate module, for example a file named utils.py in the same folder, and then load it with import utils.

Use this approach when you export your Learner too, so that the pickled Learner refers to the classes in utils rather than in __main__; then place that file in your deployment environment and it should work.
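A minimal sketch of why the separate-module approach works, again with plain pickle rather than fastai. The module name utils_sketch is hypothetical and stands in for utils.py; it is written to a temporary folder only to keep the example self-contained. Because the class lives in an importable file, the exported bytes reference utils_sketch.CustomTransformerModel, and a fresh process that imports that module can restore the object.

```python
import importlib
import pickle
import sys
import tempfile
from pathlib import Path

# Write a stand-in for utils.py (hypothetical name "utils_sketch") into a temp folder.
workdir = Path(tempfile.mkdtemp())
(workdir / "utils_sketch.py").write_text(
    "class CustomTransformerModel:\n"
    "    def __init__(self, name):\n"
    "        self.name = name\n"
)
sys.path.insert(0, str(workdir))
importlib.invalidate_caches()

# Training side: import the class from the module, then export (pickle) an instance.
utils_sketch = importlib.import_module("utils_sketch")
exported = pickle.dumps(utils_sketch.CustomTransformerModel("bert-base-uncased"))

# Deployment side: forget the module, as a fresh Flask process would not have it
# loaded yet, then import it again before loading (the app's "import utils").
del sys.modules["utils_sketch"]
utils_sketch = importlib.import_module("utils_sketch")
restored = pickle.loads(exported)
print(type(restored).__module__, restored.name)  # utils_sketch bert-base-uncased
```

In a real project utils.py would simply sit next to the training code and app.py, and both would import the class from it.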

You mean that in the training phase I should also import the class from utils.py, so that I can do the same when deploying? Or am I missing something?

Hi,

Were you able to find a solution for this? I am getting the same error.

Yes. As @sgugger suggested, I moved the classes I use into an external module named utils.py, imported that module, retrained the model and exported it. Now, when you reload the model, you need to import utils first, so that when pickle looks up the class it is already loaded (because you imported it).
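The re-export step matters because an export made while the class still lived in the notebook keeps pointing at the old location. In this sketch (hypothetical in-memory modules notebook_main and utils_demo, plain pickle instead of fastai), having the utils module loaded does not rescue the old export, while a new export made after moving the class loads fine.

```python
import pickle
import sys
import types

# Hypothetical stand-ins: "notebook_main" plays the notebook's __main__,
# "utils_demo" plays utils.py. Both are in-memory modules, not real files.
notebook_main = types.ModuleType("notebook_main")
utils_demo = types.ModuleType("utils_demo")
sys.modules["notebook_main"] = notebook_main
sys.modules["utils_demo"] = utils_demo


class CustomTransformerModel:
    """Placeholder for the real model class."""


# First export: the class still lived in the notebook.
CustomTransformerModel.__module__ = "notebook_main"
notebook_main.CustomTransformerModel = CustomTransformerModel
old_export = pickle.dumps(CustomTransformerModel())

# Move the class into the utils module and export again.
del notebook_main.CustomTransformerModel
CustomTransformerModel.__module__ = "utils_demo"
utils_demo.CustomTransformerModel = CustomTransformerModel
new_export = pickle.dumps(CustomTransformerModel())


def loads_ok(blob):
    """Return True if the pickled bytes can be loaded in the current process."""
    try:
        pickle.loads(blob)
        return True
    except AttributeError:
        return False


# utils_demo is loaded, but the old export still points at notebook_main.
print(loads_ok(old_export), loads_ok(new_export))  # False True
```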

@zack404 If you don’t mind, do you have any reference code I could look at to get a clearer idea of the solution?