I am trying to export a model with learn.export but am getting a pickle error. I know that lambda functions cannot be pickled, and I have followed the advice in Fastai v2 Recipes (Tips and Tricks) - Wiki - #2 by farid to use a named function instead. However, I am still getting:
AttributeError: Can't pickle local object 'get_dls.<locals>.get_y'
Here's the full traceback:
AttributeError Traceback (most recent call last)
----> learn.export(filesavename+'_export.pkl')
/usr/local/lib/python3.7/dist-packages/fastai/learner.py in export(self, fname, pickle_module, pickle_protocol)
373 #To avoid the warning that come from PyTorch about model not being checked
374 warnings.simplefilter("ignore")
--> 375 torch.save(self, self.path/fname, pickle_module=pickle_module, pickle_protocol=pickle_protocol)
376 self.create_opt()
377 if state is not None: self.opt.load_state_dict(state)
/usr/local/lib/python3.7/dist-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization)
377 if _use_new_zipfile_serialization:
378 with _open_zipfile_writer(opened_file) as opened_zipfile:
--> 379 _save(obj, opened_zipfile, pickle_module, pickle_protocol)
380 return
381 _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
/usr/local/lib/python3.7/dist-packages/torch/serialization.py in _save(obj, zip_file, pickle_module, pickle_protocol)
482 pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
483 pickler.persistent_id = persistent_id
--> 484 pickler.dump(obj)
485 data_value = data_buf.getvalue()
486 zip_file.write_record('data.pkl', data_value, len(data_value))
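For reference, the "Can't pickle local object" message can be reproduced with the standard-library pickle module alone, without fastai. This is a minimal sketch, assuming CPython 3 run as a script; the function names are hypothetical. It compares a lambda, a named function defined inside another function (like get_y inside get_dls), and a named function defined at module top level. Pickle stores functions by module and qualified name rather than by their code, so only the module-level one can be looked up again.

```python
import pickle


def get_y_module_level(o):
    # Defined at module top level, so pickle can store it as a reference
    # (module name + "get_y_module_level") and find it again when loading.
    return o + "_target"


def make_nested():
    # A function defined inside another function gets a qualified name like
    # "make_nested.<locals>.get_y", which pickle cannot look up by name.
    def get_y(o):
        return o + "_target"
    return get_y


def try_pickle(fn):
    """Return True if fn survives a pickle round trip, False otherwise."""
    try:
        restored = pickle.loads(pickle.dumps(fn))
    except (pickle.PicklingError, AttributeError):
        # Depending on the Python version, an unpicklable function raises
        # either PicklingError or AttributeError.
        return False
    return restored("x") == fn("x")


results = {
    "lambda": try_pickle(lambda o: o + "_target"),
    "nested def": try_pickle(make_nested()),
    "module-level def": try_pickle(get_y_module_level),
}
for name, ok in results.items():
    print(f"{name:>16}: {'picklable' if ok else 'NOT picklable'}")
```

Only the module-level function should report as picklable; giving a nested function a name is not enough, because its qualified name still contains "<locals>".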
My data loader is defined as follows:
from fastai.vision.all import *

def get_dls(bs: int, size: int):
    path = Path('/content/gdrive/MyDrive/encoder_dataset')
    path_src = path/'src'
    path_tar = path/'tar'
    # Named function, but defined inside get_dls
    def get_y(o):
        return path_tar/o.name
    dblock = DataBlock(blocks=(ImageBlock, ImageBlock),
                       get_items=get_image_files,
                       get_y=get_y,  # previously: lambda x: path_tar/x.name
                       splitter=RandomSplitter(),
                       item_tfms=Resize(size),
                       batch_tfms=[*aug_transforms(),
                                   Normalize.from_stats(*imagenet_stats)])
    dls = dblock.dataloaders(path_src, bs=bs, path=path)
    dls.c = 3
    return dls
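One possible way to keep get_dls parameterized while giving pickle something it can find is to move the labelling logic to module level and bind path_tar with functools.partial. This is a sketch under my own assumptions, not an official fastai recipe; label_from_target_dir, make_get_y and the img_001.png file name are hypothetical. Inside get_dls it would presumably be passed as get_y=make_get_y(path_tar). As with any pickled function, label_from_target_dir would also need to be defined or importable wherever the exported file is loaded again.

```python
import pickle
from functools import partial
from pathlib import Path


def label_from_target_dir(o, path_tar):
    # Hypothetical module-level replacement for the nested get_y:
    # maps a source image path to the file with the same name in path_tar.
    return Path(path_tar) / Path(o).name


def make_get_y(path_tar):
    # partial() binds path_tar while the underlying function stays at module
    # level, so the resulting callable can be pickled by reference.
    return partial(label_from_target_dir, path_tar=path_tar)


get_y = make_get_y(Path("/content/gdrive/MyDrive/encoder_dataset/tar"))

# Round-trip through pickle, as learn.export would (via torch.save).
restored = pickle.loads(pickle.dumps(get_y))
print(restored(Path("/content/gdrive/MyDrive/encoder_dataset/src/img_001.png")))
```

The partial object pickles as "call label_from_target_dir with these arguments", and both the function reference and the Path argument are themselves picklable.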
And the learner is defined as:
data = get_dls(224,224)
learn = unet_learner(data, resnet34, loss_func=F.mse_loss)
Training works fine; the error only appears when I call learn.export. Can someone please point out my mistake?