Ccan't pickle local object 'RandomSplitter.<locals>._inner'

when I call learner.export(), above error raises.

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-21-155ff6ef0adc> in <module>
----> 1 learner._learner.export(pickle_protocol=4)

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in export(self, fname, pickle_module, pickle_protocol)
    364         #To avoid the warning that come from PyTorch about model not being checked
    365         warnings.simplefilter("ignore")
--> 366         torch.save(self, self.path/fname, pickle_module=pickle_module, pickle_protocol=pickle_protocol)
    367     self.create_opt()
    368     if state is not None: self.opt.load_state_dict(state)

/opt/conda/lib/python3.7/site-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization)
    370         if _use_new_zipfile_serialization:
    371             with _open_zipfile_writer(opened_file) as opened_zipfile:
--> 372                 _save(obj, opened_zipfile, pickle_module, pickle_protocol)
    373                 return
    374         _legacy_save(obj, opened_file, pickle_module, pickle_protocol)

/opt/conda/lib/python3.7/site-packages/torch/serialization.py in _save(obj, zip_file, pickle_module, pickle_protocol)
    474     pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    475     pickler.persistent_id = persistent_id
--> 476     pickler.dump(obj)
    477     data_value = data_buf.getvalue()
    478     zip_file.write_record('data.pkl', data_value, len(data_value))

AttributeError: Can't pickle local object 'RandomSplitter.<locals>._inner'

I have read many threads similiar, but they’re all related to customized lambda functions. However, my issue occurs at RandomeSplitter, it’s a class provided by fastai.

Here is my code:

from math import sqrt
from pathlib import Path
from typing import List, Tuple, Callable
import numpy as np
from fastai.vision.all import *
from torch import nn
import tempfile

class HacktchaLoss(BaseLoss):
    def __init__(self, vocab: List[str], nletters:int, axis=-1, **kwargs):
        self.func = None
        self._vocab = vocab
        self._nclasses = len(vocab)
        self._nletters = nletters

        self.enc_map = {l:e for e,l in enumerate(self._vocab)}
        self.dec_map = {e:l for l,e in self.enc_map.items()}

    def __call__(self, inp, *y):
        preds = inp.split(self._nclasses, dim = 1)
        
        _loss = nn.CrossEntropyLoss()(preds[0], y[0])
        for i in range(self._nletters):
            _loss += nn.CrossEntropyLoss()(preds[i], y[i])
    
        return _loss

    def decodes(self, x):
        preds = x.split(self._nclasses, dim = 1)
        return [self.dec_map[preds[i].argmax(dim=1)] for i in range(self._nletters)]

class LabellingWrapper():
    def __init__(self, pos: int):
        self._pos = pos
        
    def __call__(self, filepath: Path):
        """get label from file name

        It's assumed that the filename, or parts of it contains the label. The filename should be in either "1234.jpg" or "1234_xxx.jpg" format. The latter form is used when there's multiple image belongs to same label

        Args:
            filepath (Path): [description]

        Returns:
            [type]: [description]
        """
        label = filepath.parts[-1].split(".")[0]
    
        if label.find("_"):
            label = label.split("_")[0]
            
        return label[self._pos]

class HacktchaLearner:
    def __init__(self, arch: Callable, nletters:int, vocab: str, image_path:Tuple[str, str] = None):
        self._nletters = nletters
        self._vocab = vocab
        self._bs = 64
        self._nclasses = len(self._vocab)
        self._arch = arch
        self._lr = 1e-3
        
        if image_path is not None:
            self._general_path = image_path[0]
            self._specific_path = image_path[1]

        self._dls = None
        self._learner:Learner = None

        blocks = (ImageBlock(cls=PILImageBW), *([CategoryBlock] * self._nletters))

        self._datasets = DataBlock(
            blocks=blocks, 
            n_inp=1, 
            get_items=get_image_files, 
            get_y=self.labelling_funcs,
            batch_tfms=[*aug_transforms(do_flip=False), Normalize()],
           splitter=RandomSplitter(seed=42))


    def create_learner(self, image_path:str, model_path:str=None):
        """create a cnn_learner. If `model_path` exists, then load states from `model_path`

        Args:
            model_path (str, optional): [description]. Defaults to None.
        """
        files = len(os.listdir(image_path))
        bs = int(sqrt(sqrt(files)))

        self._dls = self._datasets.dataloaders(source=image_path, bs=min(64, bs))

        self._learner = cnn_learner(
            self._dls, 
            self._arch, 
            n_out = (self._nclasses * self._nletters), 
            loss_func=HacktchaLoss(self._vocab, self._nletters), 
            lr=self._lr,
            metrics=self.accuracy,
            cbs = [EarlyStoppingCallback(patience=3)])

        if model_path and os.path.exists(model_path):
            self._learner.load(model_path)
    
    @property
    def labelling_funcs(self):
        """group of `LabellingWrapper` according to `n_letters`
        """
        return [LabellingWrapper(i) for i in range(self._nletters)]

    def accuracy(self, preds, *y):
        """calcualte accuracy of the prediction

        Args:
            preds ([type]): [description]
        """
        preds = preds.split(self._nclasses, dim=1)

        r0 = (preds[0].argmax(dim=1) == y[0]).float().mean()
        for i in range(1, self._nletters):
            r0 += (preds[i].argmax(dim=1) == y[i]).float().mean()

        return r0/self._nletters

    def train(self, save_to: str, epoch: int=100):
        # do the training on general captcha images dataset
        self.create_learner(self._general_path)
        lr_min, _ = self._learner.lr_find()
        
        self._learner.freeze_to(-2)
        self._learner.fine_tune(epoch, base_lr=lr_min)

        _tmp_model = tempfile.mktemp()
        self._learner.save(_tmp_model, pickle_protocol = 4)

        # fine_tune on specific images found on specific website
        self.create_learner(self._specific_path, _tmp_model)
        lr_min, _ = self._learner.lr_find()

        #self.create_learner(self._specific_path)

        self._learner.freeze_to(-1)
        self._learner.fine_tune(epoch, base_lr = lr_min)

        model_file = os.path.join(save_to, f"hacktcha-{self._arch.__name__}-{self._nletters}-{self._nclasses}")
        self._learner.save(model_file)
        
paths = ("/kaggle/input/captcha-gen-5000", "/kaggle/input/captcha-em-340")
learner = HacktchaLearner(resnet18, 4, "0123456789", paths)

learner.train("/kaggle/working/models/", epoch=1)
learner._learnre.export()

any idea how to solve this problem, or workaroud it?

If I just remove splitter = RandomSplitter(seed42), it’ll export just fine and the model seems to work.

So what’s difference if have/havn’t splitter specified?

The pickle module doesnt know how to pickle the RandomSplitter type.

from fastai.data.transforms import RandomSplitter
import pickle
pickle.dumps(RandomSplitter())

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-4-4260ef7da67d> in <module>
----> 1 pickle.dumps(RandomSplitter())

AttributeError: Can't pickle local object 'RandomSplitter.<locals>._inner'

To make RandomSplitter work with pickle, see here.

Hi idraja:

Thanks!

How couldn’t I find the same way to identify the root cause.

Have submitted this as a bug to fastai. And to workaround this issue, I guess I need to write my own splitter.