Trouble saving with SentencePiece tokenizer

In the past I was able to save (pickle) the TextClasDataBunch and TextLMDataBunch, then restore them and overwrite my processor to get things working. I think the way a DataBunch is saved/loaded has changed, and now I cannot save/load anymore.

A full example that reproduces the issue is below. The SentencePiece processor is a SWIG object, so I understand why it cannot be pickled. However, I would like to save all the other data and then re-insert the SWIG object when I .load() again in the future.

Has anyone else had this problem? Any ideas for a solution?
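For reference, the generic Python pattern I have in mind is the __getstate__/__setstate__ pair, which drops the unpicklable attribute at pickle time and rebuilds it at unpickle time. A rough sketch (a hypothetical wrapper, not fastai code):

import sentencepiece as spm

class PicklableSP:
    "Sketch: hold a SentencePiece processor while keeping the pickled state free of SWIG objects."
    def __init__(self, model_path:str):
        self.model_path = model_path
        self.sp = spm.SentencePieceProcessor()
        self.sp.load(model_path)

    def __getstate__(self):
        ## copy the instance dict and drop the SWIG-backed processor
        state = self.__dict__.copy()
        state['sp'] = None
        return state

    def __setstate__(self, state):
        ## restore everything else, then rebuild the processor from the model file
        self.__dict__.update(state)
        self.sp = spm.SentencePieceProcessor()
        self.sp.load(self.model_path)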

from fastai.text import *
path = untar_data(URLs.IMDB_SAMPLE)
data = TextClasDataBunch.from_csv(path, 'texts.csv')
data.save('tmp_data')  ## this works, default example

## dump just the raw text column (no header/index) to a plain file for SentencePiece training
formatted_text_file = './for_sp_train.csv'
pd.read_csv(path/'texts.csv').text.to_frame().to_csv(formatted_text_file, header=False, index=False, quotechar=' ')

import sentencepiece as spm #https://github.com/google/sentencepiece

vocab_size = 600
model_prefix = 'basic_example'

spm.SentencePieceTrainer.Train(f'--input={formatted_text_file}'\
                               f' --model_prefix={model_prefix}'\
                               f' --vocab_size={vocab_size}'\
                               f' --model_type=bpe'\
                               f' --unk_piece={UNK} --bos_piece={BOS} --eos_id=-1 --pad_piece={PAD}')
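## training writes {model_prefix}.model and {model_prefix}.vocab to the working directory;
## the .model file is loaded by the processor below and the .vocab file is parsed into itos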

## load up the Processor
sp = spm.SentencePieceProcessor()
sp.load(f'{model_prefix}.model')

## itos from the {model_prefix}.vocab file: read it directly and populate the dictionary
itos = {}
with open(f'{model_prefix}.vocab','r') as f:
    for line_num,line in enumerate(f):
        itos[line_num] = line.split("\t")[0]
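## Note: the same id-to-piece mapping can also be built straight from the loaded
## processor, which avoids parsing the .vocab file by hand:
##   itos = {i: sp.IdToPiece(i) for i in range(sp.GetPieceSize())}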

        
class CustomTokenizer():
    '''Wrapper for the SentencePiece tokenizer to fit into fastai v1'''
    def __init__(self, sp_processor=None, pre_rules:ListRules=None):
        self.sp = sp_processor
        self.pre_rules = ifnone(pre_rules, defaults.text_pre_rules)

    def __repr__(self) -> str:
        return "Custom Tokenizer"

    def process_text(self, t:str) -> List[str]:
        "Process one text `t`."
        for rule in self.pre_rules: t = rule(t)
        return self.sp.EncodeAsIds(t)
    
    def _process_all_1(self,texts:Collection[str]) -> List[List[str]]:
        'Process a list of `texts` in one process'
        return [self.process_text(t) for t in texts]
                                                                     
    def process_all(self, texts:Collection[str]) -> List[List[str]]: 
        "Process a list of `texts`."                                 
        return self._process_all_1(texts)

cust_tok = CustomTokenizer(sp)
sp_vocab = Vocab(itos)
data = TextClasDataBunch.from_csv(path, 'texts.csv', tokenizer=cust_tok, vocab=sp_vocab)
data.save('tmp_data_sp')  ## this line raises the error below

And here is my error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-6-c83934da3d6f> in <module>
      2 sp_vocab = Vocab(itos)
      3 data = TextClasDataBunch.from_csv(path, 'texts.csv',tokenizer=cust_tok,vocab=sp_vocab)
----> 4 data.save('tmp_data_sp')

~/fast_ai/fastai-fork/fastai/basic_data.py in save(self, fname)
    152             warn("Serializing the `DataBunch` only works when you created it using the data block API.")
    153             return
--> 154         try_save(self.label_list, self.path, fname)
    155 
    156     def add_test(self, items:Iterator, label:Any=None)->None:

~/fast_ai/fastai-fork/fastai/torch_core.py in try_save(state, path, fname)
    407 
    408 def try_save(state:Dict, path:Path, fname:PathOrStr):
--> 409     try: torch.save(state, open(path/fname, 'wb'))
    410     except OSError as e:
    411         raise Exception(f"{e}\n Can't write {path/fname}. Pass an absolute writable pathlib obj `fname`.")

~/anaconda3/envs/fastaiv1_dev/lib/python3.7/site-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol)
    216         >>> torch.save(x, buffer)
    217     """
--> 218     return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
    219 
    220 

~/anaconda3/envs/fastaiv1_dev/lib/python3.7/site-packages/torch/serialization.py in _with_file_like(f, mode, body)
    141         f = open(f, mode)
    142     try:
--> 143         return body(f)
    144     finally:
    145         if new_fd:

~/anaconda3/envs/fastaiv1_dev/lib/python3.7/site-packages/torch/serialization.py in <lambda>(f)
    216         >>> torch.save(x, buffer)
    217     """
--> 218     return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
    219 
    220 

~/anaconda3/envs/fastaiv1_dev/lib/python3.7/site-packages/torch/serialization.py in _save(obj, f, pickle_module, pickle_protocol)
    289     pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    290     pickler.persistent_id = persistent_id
--> 291     pickler.dump(obj)
    292 
    293     serialized_storage_keys = sorted(serialized_storages.keys())

TypeError: can't pickle SwigPyObject objects

FWIW, I tried this hacky workaround, but it did not help in this case.


You should follow the way it's done for SpacyTokenizer and let the Tokenizer instantiate the tokenizer for you. The function name can be pickled; it's the object that can't.
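To illustrate the distinction (a quick sketch, not fastai internals): a module-level class or function is pickled by its qualified name, while a loaded processor instance carries the SWIG handle and cannot be serialized.

import pickle
import sentencepiece as spm

pickle.dumps(spm.SentencePieceProcessor)  ## fine: the class reference is pickled by name

sp = spm.SentencePieceProcessor()
sp.load(f'{model_prefix}.model')          ## model_prefix as trained above
## pickle.dumps(sp)                       ## raises TypeError: can't pickle SwigPyObject objects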


Thanks so much. That worked!

Code for others to reference in the future:

## itos from the {model_prefix}.vocab file: read it directly and populate the dictionary
itos = {}
with open(f'{model_prefix}.vocab','r') as f:
    for line_num,line in enumerate(f):
        itos[line_num] = line.split("\t")[0]

        
class SPTokenizer(BaseTokenizer):
    "Wrapper around a SentncePiece tokenizer to make it a `BaseTokenizer`."
    def __init__(self, model_prefix:str):
        self.tok = spm.SentencePieceProcessor()
        self.tok.load(f'{model_prefix}.model')

    def tokenizer(self, t:str) -> List[str]:
        return self.tok.EncodeAsIds(t)
        
class CustomTokenizer():
    '''Wrapper for the SentencePiece tokenizer to fit into fastai v1'''
    def __init__(self,tok_func:Callable,model_prefix:str, pre_rules:ListRules=None):
        self.tok_func,self.model_prefix = tok_func,model_prefix
        self.pre_rules  = ifnone(pre_rules,  defaults.text_pre_rules )
        
    def __repr__(self) -> str:
        res = f'Tokenizer {self.tok_func.__name__} using `{self.model_prefix}` model with the following rules:\n'
        for rule in self.pre_rules: res += f' - {rule.__name__}\n'
        return res        

    def process_text(self, t:str,tok:BaseTokenizer) -> List[str]:
        "Processe one text `t` with tokenizer `tok`."
        for rule in self.pre_rules: t = rule(t)  
        toks = tok.tokenizer(t)
    
        return toks 
    
    def _process_all_1(self,texts:Collection[str]) -> List[List[str]]:
        'Process a list of `texts` in one process'
        tok = self.tok_func(self.model_prefix)
        return [self.process_text(t,tok) for t in texts]
                                                                     
    def process_all(self, texts:Collection[str]) -> List[List[str]]: 
        "Process a list of `texts`."                                 
        return self._process_all_1(texts)

cust_tok = CustomTokenizer(SPTokenizer, model_prefix)
sp_vocab = Vocab(itos)
data = TextClasDataBunch.from_csv(path, 'texts.csv', tokenizer=cust_tok, vocab=sp_vocab)
data.save('tmp_data_sp')  ## now saves without error
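A quick round-trip check (assuming a fastai version that provides load_data; older releases used a different loader): the pickled tokenizer only stores the SPTokenizer class reference and the model_prefix string, so the SentencePiece processor is simply re-created from {model_prefix}.model whenever new texts need to be processed after loading.

## reload the saved DataBunch later
data2 = load_data(path, 'tmp_data_sp')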