I had been able to save the TextClasDataBunch and TextLMDataBunch in the past (pickle them), then restore them and overwrite my processor to get that to work. I think the method for saving/loading a DataBunch has since changed, and now I cannot save/load anymore.
SentencePiece is a SWIG object, so I understand why/how it cannot be pickled. However, I would like to save all the other data and then re-insert the SWIG object when I .load() again in the future.
Has anyone else hit this problem? Any ideas for a solution?
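To make the goal concrete: the generic pattern I have in mind is the usual __getstate__/__setstate__ trick for objects holding an unpicklable handle. A minimal, fastai-free sketch (PicklableSP is just an illustrative name; it assumes the trained .model file is still on disk when unpickling):

import pickle
import sentencepiece as spm

class PicklableSP:
    '''Drops the SWIG handle when pickled; rebuilds it from the model file when unpickled.'''
    def __init__(self, model_path):
        self.model_path = model_path
        self.sp = spm.SentencePieceProcessor()
        self.sp.load(model_path)
    def __getstate__(self):
        state = self.__dict__.copy()
        state['sp'] = None  # the SwigPyObject itself cannot be pickled
        return state
    def __setstate__(self, state):
        self.__dict__.update(state)
        self.sp = spm.SentencePieceProcessor()
        self.sp.load(self.model_path)  # re-insert a live processor

roundtripped = pickle.loads(pickle.dumps(PicklableSP('basic_example.model')))

Here is the full example that reproduces the issue: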
from fastai.text import *
path = untar_data(URLs.IMDB_SAMPLE)
data = TextClasDataBunch.from_csv(path, 'texts.csv')
data.save('tmp_data') ## this works, default example
## dump just the text column to a plain CSV for SentencePiece training
formatted_text_file = './for_sp_train.csv'
pd.read_csv(path/'texts.csv').text.to_frame().to_csv(formatted_text_file, header=False, index=False, quotechar=' ')
import sentencepiece as spm #https://github.com/google/sentencepiece
vocab_size = 600
model_prefix = 'basic_example'
spm.SentencePieceTrainer.Train(f'--input={formatted_text_file}'
                               f' --model_prefix={model_prefix}'
                               f' --vocab_size={vocab_size}'
                               f' --model_type=bpe'
                               f' --unk_piece={UNK} --bos_piece={BOS} --eos_id=-1 --pad_piece={PAD}')
## load up the Processor
sp = spm.SentencePieceProcessor()
sp.load(f'{model_prefix}.model')
## itos from the {model_prefix}.vocab file: read it directly and populate the dictionary
itos = {}
with open(f'{model_prefix}.vocab','r') as f:
    for line_num, line in enumerate(f):
        itos[line_num] = line.split("\t")[0]
class CustomTokenizer():
    '''Wrapper for the SentencePiece tokenizer to fit into fast.ai v1'''
    def __init__(self, sp_processor=None, pre_rules:ListRules=None, post_rules:ListRules=None):
        self.sp = sp_processor
        self.pre_rules = ifnone(pre_rules, defaults.text_pre_rules)
        # post_rules is accepted for API compatibility but unused: the output is ids, not tokens

    def __repr__(self) -> str:
        return "Custom Tokenizer"

    def process_text(self, t:str) -> List[int]:
        "Process one text `t`: apply the pre-rules, then encode with SentencePiece."
        for rule in self.pre_rules: t = rule(t)
        toks = self.sp.EncodeAsIds(t)  # use the instance's processor, not the global `sp`
        return toks

    def _process_all_1(self, texts:Collection[str]) -> List[List[int]]:
        "Process a list of `texts` in one process."
        return [self.process_text(t) for t in texts]

    def process_all(self, texts:Collection[str]) -> List[List[int]]:
        "Process a list of `texts`."
        return self._process_all_1(texts)
cust_tok = CustomTokenizer(sp)
sp_vocab = Vocab(itos)
data = TextClasDataBunch.from_csv(path, 'texts.csv', tokenizer=cust_tok, vocab=sp_vocab)
data.save('tmp_data_sp')  ## this is the call that fails
And here is my error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-6-c83934da3d6f> in <module>
2 sp_vocab = Vocab(itos)
3 data = TextClasDataBunch.from_csv(path, 'texts.csv',tokenizer=cust_tok,vocab=sp_vocab)
----> 4 data.save('tmp_data_sp')
~/fast_ai/fastai-fork/fastai/basic_data.py in save(self, fname)
152 warn("Serializing the `DataBunch` only works when you created it using the data block API.")
153 return
--> 154 try_save(self.label_list, self.path, fname)
155
156 def add_test(self, items:Iterator, label:Any=None)->None:
~/fast_ai/fastai-fork/fastai/torch_core.py in try_save(state, path, fname)
407
408 def try_save(state:Dict, path:Path, fname:PathOrStr):
--> 409 try: torch.save(state, open(path/fname, 'wb'))
410 except OSError as e:
411 raise Exception(f"{e}\n Can't write {path/fname}. Pass an absolute writable pathlib obj `fname`.")
~/anaconda3/envs/fastaiv1_dev/lib/python3.7/site-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol)
216 >>> torch.save(x, buffer)
217 """
--> 218 return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
219
220
~/anaconda3/envs/fastaiv1_dev/lib/python3.7/site-packages/torch/serialization.py in _with_file_like(f, mode, body)
141 f = open(f, mode)
142 try:
--> 143 return body(f)
144 finally:
145 if new_fd:
~/anaconda3/envs/fastaiv1_dev/lib/python3.7/site-packages/torch/serialization.py in <lambda>(f)
216 >>> torch.save(x, buffer)
217 """
--> 218 return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
219
220
~/anaconda3/envs/fastaiv1_dev/lib/python3.7/site-packages/torch/serialization.py in _save(obj, f, pickle_module, pickle_protocol)
289 pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
290 pickler.persistent_id = persistent_id
--> 291 pickler.dump(obj)
292
293 serialized_storage_keys = sorted(serialized_storages.keys())
TypeError: can't pickle SwigPyObject objects
FWIW, I tried this hacky workaround as well; it did not help in this case.
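In the meantime, the direction I am leaning toward is nulling out the handle before saving and re-attaching it after loading, along these lines (an untested sketch; the attribute path back into the loaded databunch is a guess and may differ across fastai versions):

## drop the SwigPyObject so pickling can succeed
cust_tok.sp = None
data.save('tmp_data_sp')

## later: load the databunch and re-insert a live processor
data = load_data(path, 'tmp_data_sp')  # or TextClasDataBunch.load, depending on version
sp = spm.SentencePieceProcessor()
sp.load(f'{model_prefix}.model')
data.train_ds.x.processor[0].tokenizer.sp = sp  # guessed path to the restored tokenizer

If anyone knows the supported way to do this, I would appreciate a pointer.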