Can't pickle local object 'RandomSplitter.<locals>._inner'

When I call learner.export(), the error above is raised.

AttributeError                            Traceback (most recent call last)
<ipython-input-21-155ff6ef0adc> in <module>
----> 1 learner._learner.export(pickle_protocol=4)

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in export(self, fname, pickle_module, pickle_protocol)
    364         #To avoid the warning that come from PyTorch about model not being checked
    365         warnings.simplefilter("ignore")
--> 366         torch.save(self, self.path/fname, pickle_module=pickle_module, pickle_protocol=pickle_protocol)
    367     self.create_opt()
    368     if state is not None: self.opt.load_state_dict(state)

/opt/conda/lib/python3.7/site-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization)
    370         if _use_new_zipfile_serialization:
    371             with _open_zipfile_writer(opened_file) as opened_zipfile:
--> 372                 _save(obj, opened_zipfile, pickle_module, pickle_protocol)
    373                 return
    374         _legacy_save(obj, opened_file, pickle_module, pickle_protocol)

/opt/conda/lib/python3.7/site-packages/torch/serialization.py in _save(obj, zip_file, pickle_module, pickle_protocol)
    474     pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    475     pickler.persistent_id = persistent_id
--> 476     pickler.dump(obj)
    477     data_value = data_buf.getvalue()
    478     zip_file.write_record('data.pkl', data_value, len(data_value))

AttributeError: Can't pickle local object 'RandomSplitter.<locals>._inner'

I have read many similar threads, but they are all related to custom lambda functions. My issue, however, occurs in RandomSplitter, which is provided by fastai itself.

Here is my code:

from math import sqrt
from pathlib import Path
from typing import List, Tuple, Callable
import numpy as np
from fastai.vision.all import *
from torch import nn
import tempfile

class HacktchaLoss(BaseLoss):
    """Sum of per-letter cross-entropy losses for a fixed-length captcha.

    The model output is split along dim 1 into consecutive groups of
    `len(vocab)` values; group `i` is scored against target `y[i]`.
    """

    def __init__(self, vocab: List[str], nletters:int, axis=-1, **kwargs):
        """Store the vocabulary and build the index <-> letter lookup tables.

        Args:
            vocab (List[str]): the symbols a single letter position can take.
            nletters (int): number of letter positions per captcha.
            axis: accepted for signature compatibility; not used here.
            **kwargs: accepted for signature compatibility; not used here.
        """
        # NOTE(review): BaseLoss.__init__ is not invoked; `func` is set to
        # None directly - confirm this is intended.
        self.func = None
        self._vocab = vocab
        self._nclasses = len(vocab)
        self._nletters = nletters

        # letter -> class index, and the inverse used by `decodes`
        self.enc_map = {l:e for e,l in enumerate(self._vocab)}
        self.dec_map = {e:l for l,e in self.enc_map.items()}

    def __call__(self, inp, *y):
        """Return the summed cross-entropy over all letter positions.

        Args:
            inp: model output; split along dim 1 into chunks of
                `self._nclasses` (assumes shape (batch, nletters * nclasses)
                - TODO confirm).
            *y: one target tensor per letter position.

        Returns:
            The sum of `nn.CrossEntropyLoss` over positions 0..nletters-1.
        """
        preds = inp.split(self._nclasses, dim = 1)
        criterion = nn.CrossEntropyLoss()
        total = criterion(preds[0], y[0])
        # Start at 1: position 0 is already in `total`. Starting at 0 would
        # count the first letter twice and bias training towards it.
        for i in range(1, self._nletters):
            total = total + criterion(preds[i], y[i])
        return total

    def decodes(self, x):
        """Map model output back to vocabulary symbols.

        Args:
            x: model output, split along dim 1 like in `__call__`.

        Returns:
            A list with one entry per letter position; each entry is the list
            of decoded symbols for that position across the batch.
        """
        preds = x.split(self._nclasses, dim = 1)
        # `dec_map` is keyed by plain ints, so the argmax tensor has to be
        # converted with tolist() before the lookup.
        return [
            [self.dec_map[idx] for idx in preds[i].argmax(dim=1).tolist()]
            for i in range(self._nletters)
        ]

class LabellingWrapper():
    """Callable that returns one character of the label encoded in a file name."""

    def __init__(self, pos: int):
        """Remember which character of the label this instance returns.

        Args:
            pos (int): index into the label string.
        """
        self._pos = pos

    def __call__(self, filepath: Path):
        """Get the label character at `self._pos` from the file name.

        It's assumed that the filename, or part of it, contains the label.
        The filename should be in either "1234.jpg" or "1234_xxx.jpg" format.
        The latter form is used when multiple images belong to the same label.

        Args:
            filepath (Path): path (or path string) of the image file.

        Returns:
            str: the character of the label at position `self._pos`.

        Raises:
            IndexError: if the label is shorter than `self._pos + 1`.
        """
        # Drop the directories, then everything from the first "." onwards.
        stem = Path(filepath).name.split(".")[0]
        # split() returns the whole string when there is no "_", so no
        # find() check is needed (find() returns -1, which is truthy, when
        # the separator is absent).
        label = stem.split("_")[0]
        return label[self._pos]

class HacktchaLearner:
    """Builds a DataBlock and a learner for captchas of `nletters` characters.

    NOTE(review): several statements in this class are truncated in this copy
    of the source (see the inline notes); the class does not parse as shown.
    """

    def __init__(self, arch: Callable, nletters:int, vocab: str, image_path:Tuple[str, str] = None):
        """Store the configuration and start building the DataBlock.

        Args:
            arch (Callable): model architecture; its `__name__` is used in
                the model file name built by `train`.
            nletters (int): number of letter positions per captcha.
            vocab (str): symbols a letter can take; its length is the number
                of classes per position.
            image_path (Tuple[str, str], optional): (general, specific) image
                locations. When None, `_general_path` and `_specific_path`
                are never assigned, yet `train` reads `_specific_path`.
        """
        self._nletters = nletters
        self._vocab = vocab
        self._bs = 64
        self._nclasses = len(self._vocab)
        self._arch = arch
        self._lr = 1e-3
        if image_path is not None:
            self._general_path = image_path[0]
            self._specific_path = image_path[1]

        self._dls = None
        self._learner:Learner = None

        # One image input followed by one CategoryBlock per letter position.
        blocks = (ImageBlock(cls=PILImageBW), *([CategoryBlock] * self._nletters))

        # NOTE(review): this DataBlock(...) call is never closed and only
        # `batch_tfms` is visible; the remaining arguments (presumably
        # including `blocks`, the labelling functions and the splitter) are
        # missing from this copy - restore from the original source.
        self._datasets = DataBlock(
            batch_tfms=[*aug_transforms(do_flip=False), Normalize()],

    def create_learner(self, image_path:str, model_path:str=None):
        """Create dataloaders for `image_path` and a cnn_learner on top of them.

        Args:
            image_path (str): directory of images; its entries are counted to
                derive the batch size and it is passed as the dataloaders
                source.
            model_path (str, optional): checked for existence at the end of
                the method; presumably saved state is loaded from it, but that
                statement is not visible here. Defaults to None.
        """
        files = len(os.listdir(image_path))
        # Batch size is the fourth root of the file count, capped at 64 below.
        bs = int(sqrt(sqrt(files)))

        self._dls = self._datasets.dataloaders(source=image_path, bs=min(64, bs))

        # NOTE(review): only keyword arguments are visible; the dataloaders /
        # architecture arguments appear to be missing from this copy.
        self._learner = cnn_learner(
            n_out = (self._nclasses * self._nletters), 
            loss_func=HacktchaLoss(self._vocab, self._nletters), 
            cbs = [EarlyStoppingCallback(patience=3)])

        # NOTE(review): the body of this `if` is missing from this copy.
        if model_path and os.path.exists(model_path):
    def labelling_funcs(self):
        """Return one `LabellingWrapper` per letter position (0..nletters-1)."""
        return [LabellingWrapper(i) for i in range(self._nletters)]

    def accuracy(self, preds, *y):
        """Calculate the accuracy of the prediction, averaged over letters.

        Args:
            preds: model output; split along dim 1 into chunks of
                `self._nclasses`, one chunk per letter position.
            *y: one target tensor per letter position.

        Returns:
            The mean over letter positions of the per-position accuracy.
        """
        preds = preds.split(self._nclasses, dim=1)

        r0 = (preds[0].argmax(dim=1) == y[0]).float().mean()
        # Position 0 is already in r0, so accumulate from position 1.
        for i in range(1, self._nletters):
            r0 += (preds[i].argmax(dim=1) == y[i]).float().mean()

        return r0/self._nletters

    def train(self, save_to: str, epoch: int=100):
        """Fine-tune the current learner, then again on the specific images.

        Args:
            save_to (str): directory used to build the final model file name.
            epoch (int): passed to both `fine_tune` calls. Defaults to 100.
        """
        # do the training on general captcha images dataset
        # (uses whatever learner is already in self._learner; no
        # create_learner call for the general path is visible here)
        lr_min, _ = self._learner.lr_find()
        self._learner.fine_tune(epoch, base_lr=lr_min)

        # NOTE(review): this line is garbled in this copy (unbalanced
        # parenthesis); presumably a temp file name is created and the
        # learner is saved/exported to it with pickle_protocol=4 - restore
        # from the original source.
        _tmp_model = tempfile.mktemp(), pickle_protocol = 4)

        # fine_tune on specific images found on specific website
        self.create_learner(self._specific_path, _tmp_model)
        lr_min, _ = self._learner.lr_find()


        self._learner.fine_tune(epoch, base_lr = lr_min)

        # NOTE(review): `model_file` is computed but not used in the visible
        # code; the final save/export statement appears to be missing.
        model_file = os.path.join(save_to, f"hacktcha-{self._arch.__name__}-{self._nletters}-{self._nclasses}")
# Image locations handed to HacktchaLearner as (general, specific).
paths = (
    "/kaggle/input/captcha-gen-5000",
    "/kaggle/input/captcha-em-340",
)
learner = HacktchaLearner(
    arch=resnet18,
    nletters=4,
    vocab="0123456789",
    image_path=paths,
)

# Single epoch per stage.
learner.train(save_to="/kaggle/working/models/", epoch=1)

Any idea how to solve this problem, or work around it?

If I just remove splitter = RandomSplitter(seed=42), it exports just fine and the model seems to work.

So what is the difference between having and not having a splitter specified?

The pickle module doesn't know how to pickle the RandomSplitter type.

from fastai.data.transforms import RandomSplitter
import pickle

AttributeError                            Traceback (most recent call last)
<ipython-input-4-4260ef7da67d> in <module>
----> 1 pickle.dumps(RandomSplitter())

AttributeError: Can't pickle local object 'RandomSplitter.<locals>._inner'

To make RandomSplitter work with pickle, see here.

Hi idraja:


Why didn't I think of the same way to identify the root cause?

I have submitted this as a bug to fastai. To work around the issue, I guess I need to write my own splitter.