I have a notebook for an image segmentation problem. I am trying to export a base model and reload it later for inference and further training. However, learner.export() fails, and the problem seems to be the get_y function I defined to return masks in the DataBlock API.
The error returned is:
AttributeError: Can't pickle local object 'get_y.<locals>.get_msk'
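This error comes from how pickle stores functions. It saves them by reference, as a module name plus a qualified name, not by their code. A function defined inside another function has `<locals>` in its qualified name, so pickle cannot look it up again at load time. Below is a minimal stdlib reproduction of the same nested-function pattern. The lookup body is a toy and not the real mask logic.

```python
import pickle


def get_y(clas_dic):
    # Same shape as the post's factory: get_msk is created inside get_y,
    # so its qualified name is 'get_y.<locals>.get_msk'.
    def get_msk(fn):
        return clas_dic.get(fn)
    return get_msk


f = get_y({"a.png": 1})
print(f.__qualname__)  # get_y.<locals>.get_msk

try:
    pickle.dumps(f)
    pickled = True
except (AttributeError, pickle.PicklingError) as exc:
    # Many CPython versions raise AttributeError here; some may raise
    # PicklingError instead, so both are caught.
    pickled = False
    print(type(exc).__name__, exc)
```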
The function is as follows:
def get_y(clas_dic):
    def get_msk(fn):
        mask = masks(fn)
        mask_img = PILMask.create(mask)
        mask_tensor = tensor(mask_img)
        for i in vals:
            mask_tensor[mask_tensor == i] = clas_dic[i]
        return mask_tensor
    return get_msk
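One common way around this is to move the inner function to module level and bind `clas_dic` with `functools.partial`. The sketch below makes some assumptions. A `partial` object pickles as long as its underlying function can be found by name. `load_mask` is a hypothetical stand-in for the notebook's `masks(fn)` helper, and the sketch works on a NumPy array rather than a fastai `PILMask`/tensor. It also builds the output from the untouched input mask. Remapping in place inside a loop, as `get_msk` above does, can remap a pixel twice when one class's new value equals a later key.

```python
import pickle
from functools import partial

import numpy as np


def load_mask(fn):
    # Hypothetical stand-in for the notebook's masks(fn) helper;
    # it ignores fn and returns a fixed 2x2 label array.
    return np.array([[0, 1], [2, 1]], dtype=np.uint8)


def get_msk(fn, clas_dic):
    """Module-level mask getter: picklable because pickle can find it by name."""
    mask = load_mask(fn)
    out = mask.copy()
    for src, dst in clas_dic.items():
        # Compare against the original mask, not `out`, so a pixel that was
        # just remapped is never matched again by a later key.
        out[mask == src] = dst
    return out


mask_getter = partial(get_msk, clas_dic={1: 2, 2: 3})
restored = pickle.loads(pickle.dumps(mask_getter))
print(restored("img_0001.png").tolist())  # [[0, 2], [3, 2]]
```

In the DataBlock, this would correspond to `get_y=partial(get_msk, clas_dic=p2d)`, with the real `masks`/`PILMask` logic inside `get_msk`. Because pickle stores only the function's name, the session that calls `load_learner` would also need `get_msk` (and the helpers it calls) defined before loading.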
And the datablock is defined as follows:
dblock = DataBlock(blocks=(ImageBlock, MaskBlock(codes=codes)),
                   get_items=items,
                   get_y=get_y(p2d),
                   splitter=custom_split(0.5),
                   item_tfms=[Resize(128)],
                   batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)])
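Before calling `export()`, it can help to run `pickle.dumps` on each custom callable passed to the DataBlock to see which ones pickle rejects. Below is a minimal hypothetical helper for that check. It runs on toy getters written four different ways, not on the notebook's code. Closures and lambdas fail, while a `partial` over a module-level function and an instance of a module-level callable class both pickle.

```python
import pickle
from functools import partial


def find_unpicklable(named_objs):
    """Return the names of the objects that pickle.dumps rejects."""
    bad = []
    for name, obj in named_objs.items():
        try:
            pickle.dumps(obj)
        except (pickle.PicklingError, AttributeError, TypeError):
            bad.append(name)
    return bad


# Toy stand-ins for a label getter, written four different ways.
def make_getter(offset):
    def _get(x):
        return x + offset
    return _get


def add_offset(x, offset):
    return x + offset


class OffsetGetter:
    def __init__(self, offset):
        self.offset = offset

    def __call__(self, x):
        return x + self.offset


candidates = {
    "closure": make_getter(1),
    "lambda": lambda x: x + 1,
    "partial": partial(add_offset, offset=1),
    "callable_class": OffsetGetter(1),
}
print(find_unpicklable(candidates))  # ['closure', 'lambda']
```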
I would appreciate it if someone could take a look at the notebook and explain what is causing this. Thanks!