When I call `learner.export()`, the following error is raised:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-21-155ff6ef0adc> in <module>
----> 1 learner._learner.export(pickle_protocol=4)
/opt/conda/lib/python3.7/site-packages/fastai/learner.py in export(self, fname, pickle_module, pickle_protocol)
364 #To avoid the warning that come from PyTorch about model not being checked
365 warnings.simplefilter("ignore")
--> 366 torch.save(self, self.path/fname, pickle_module=pickle_module, pickle_protocol=pickle_protocol)
367 self.create_opt()
368 if state is not None: self.opt.load_state_dict(state)
/opt/conda/lib/python3.7/site-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization)
370 if _use_new_zipfile_serialization:
371 with _open_zipfile_writer(opened_file) as opened_zipfile:
--> 372 _save(obj, opened_zipfile, pickle_module, pickle_protocol)
373 return
374 _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
/opt/conda/lib/python3.7/site-packages/torch/serialization.py in _save(obj, zip_file, pickle_module, pickle_protocol)
474 pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
475 pickler.persistent_id = persistent_id
--> 476 pickler.dump(obj)
477 data_value = data_buf.getvalue()
478 zip_file.write_record('data.pkl', data_value, len(data_value))
AttributeError: Can't pickle local object 'RandomSplitter.<locals>._inner'
I have read many similar threads, but they are all about custom lambda functions. My issue, however, occurs in `RandomSplitter`, which is provided by fastai itself.
Here is my code:
import os
import tempfile
from functools import partial
from math import sqrt
from pathlib import Path
from typing import List, Tuple, Callable

import numpy as np
from fastai.vision.all import *
from torch import nn
class HacktchaLoss(BaseLoss):
    """Sum of per-letter cross-entropy losses for a multi-head captcha model.

    The model emits, per image, one flat row of ``len(vocab) * nletters``
    activations. The row is split along dim 1 into ``nletters`` consecutive
    chunks of ``len(vocab)`` logits, and chunk ``i`` is scored against
    target ``y[i]``.
    """

    def __init__(self, vocab: List[str], nletters: int, axis=-1, **kwargs):
        """
        Args:
            vocab: the characters a single position can take.
            nletters: number of characters (heads) per captcha.
            axis: stored for parity with fastai's ``BaseLoss``; not used here.
        """
        # BaseLoss.__init__ is deliberately not called: it wraps a single torch
        # loss class, whereas this loss combines several heads itself.
        self.func = None
        self.axis = axis
        self._vocab = vocab
        self._nclasses = len(vocab)
        self._nletters = nletters
        self.enc_map = {l: e for e, l in enumerate(self._vocab)}
        self.dec_map = {e: l for l, e in self.enc_map.items()}

    def __call__(self, inp, *y):
        """Return the sum of cross-entropy losses, one term per letter.

        Each letter contributes exactly once (letter 0 used to be counted twice).
        """
        preds = inp.split(self._nclasses, dim=1)
        return sum(nn.functional.cross_entropy(preds[i], y[i])
                   for i in range(self._nletters))

    def decodes(self, x):
        """Return one tensor of predicted class indices per letter.

        Following the fastai convention for loss ``decodes``, indices (not
        characters) are returned. Mapping indices back to labels is left to the
        DataLoaders' decode step. The previous dict lookup keyed by a whole
        argmax tensor could never succeed.
        """
        preds = x.split(self._nclasses, dim=1)
        return [preds[i].argmax(dim=1) for i in range(self._nletters)]
class LabellingWrapper():
    """Picklable ``get_y`` callable that extracts one label character from a file name.

    The file name is assumed to carry the label, as ``"1234.jpg"`` or, when
    several images share a label, ``"1234_xxx.jpg"``. Because this is a
    module-level class rather than a lambda or closure, instances can be
    pickled by ``Learner.export``.
    """

    def __init__(self, pos: int):
        """
        Args:
            pos: index of the label character this instance returns.
        """
        self._pos = pos

    def __call__(self, filepath: Path) -> str:
        """Return the character at index ``pos`` of the label encoded in ``filepath``.

        Args:
            filepath: image path (``Path`` or ``str``) named ``<label>.<ext>``
                or ``<label>_<suffix>.<ext>``.

        Returns:
            The single label character at index ``pos``.

        Raises:
            IndexError: if the label has no character at index ``pos``.
        """
        # Text before the first "." is the stem. Text before the first "_" in
        # the stem is the label. Splitting on "_" unconditionally is a no-op
        # when there is no underscore. (``str.find`` returns -1, which is
        # truthy, when the character is missing, so it cannot be used as a
        # membership test.)
        label = Path(filepath).name.split(".")[0].split("_")[0]
        return label[self._pos]
class HacktchaLearner:
def __init__(self, arch: Callable, nletters:int, vocab: str, image_path:Tuple[str, str] = None):
self._nletters = nletters
self._vocab = vocab
self._bs = 64
self._nclasses = len(self._vocab)
self._arch = arch
self._lr = 1e-3
if image_path is not None:
self._general_path = image_path[0]
self._specific_path = image_path[1]
self._dls = None
self._learner:Learner = None
blocks = (ImageBlock(cls=PILImageBW), *([CategoryBlock] * self._nletters))
self._datasets = DataBlock(
blocks=blocks,
n_inp=1,
get_items=get_image_files,
get_y=self.labelling_funcs,
batch_tfms=[*aug_transforms(do_flip=False), Normalize()],
splitter=RandomSplitter(seed=42))
def create_learner(self, image_path:str, model_path:str=None):
"""create a cnn_learner. If `model_path` exists, then load states from `model_path`
Args:
model_path (str, optional): [description]. Defaults to None.
"""
files = len(os.listdir(image_path))
bs = int(sqrt(sqrt(files)))
self._dls = self._datasets.dataloaders(source=image_path, bs=min(64, bs))
self._learner = cnn_learner(
self._dls,
self._arch,
n_out = (self._nclasses * self._nletters),
loss_func=HacktchaLoss(self._vocab, self._nletters),
lr=self._lr,
metrics=self.accuracy,
cbs = [EarlyStoppingCallback(patience=3)])
if model_path and os.path.exists(model_path):
self._learner.load(model_path)
@property
def labelling_funcs(self):
"""group of `LabellingWrapper` according to `n_letters`
"""
return [LabellingWrapper(i) for i in range(self._nletters)]
def accuracy(self, preds, *y):
"""calcualte accuracy of the prediction
Args:
preds ([type]): [description]
"""
preds = preds.split(self._nclasses, dim=1)
r0 = (preds[0].argmax(dim=1) == y[0]).float().mean()
for i in range(1, self._nletters):
r0 += (preds[i].argmax(dim=1) == y[i]).float().mean()
return r0/self._nletters
def train(self, save_to: str, epoch: int=100):
# do the training on general captcha images dataset
self.create_learner(self._general_path)
lr_min, _ = self._learner.lr_find()
self._learner.freeze_to(-2)
self._learner.fine_tune(epoch, base_lr=lr_min)
_tmp_model = tempfile.mktemp()
self._learner.save(_tmp_model, pickle_protocol = 4)
# fine_tune on specific images found on specific website
self.create_learner(self._specific_path, _tmp_model)
lr_min, _ = self._learner.lr_find()
#self.create_learner(self._specific_path)
self._learner.freeze_to(-1)
self._learner.fine_tune(epoch, base_lr = lr_min)
model_file = os.path.join(save_to, f"hacktcha-{self._arch.__name__}-{self._nletters}-{self._nclasses}")
self._learner.save(model_file)
# General captcha dataset first, site-specific dataset second (see HacktchaLearner.train).
paths = ("/kaggle/input/captcha-gen-5000", "/kaggle/input/captcha-em-340")
learner = HacktchaLearner(resnet18, 4, "0123456789", paths)
learner.train("/kaggle/working/models/", epoch=1)
# export() pickles the whole fastai Learner, so everything it references (loss,
# metrics, get_y callables) must be picklable. Pickle stores module-level
# objects by reference, so the same definitions must also be importable
# wherever the export is loaded again.
learner._learner.export()
Any idea how to solve this problem, or work around it?