I am unable to load a Learner back after exporting it.
I get the following error:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-11-ff0632d397f6> in <module>
----> 1 load_learner(Path('/nfs_storage/fs-mnt6/vaibhavg/Hikemoji3D/Assets/PythonScripts/Trained_models/Lips/trp_wt=grp_wt=1_non_progressive_full_face.pkl'))
/nfs_storage/fs-mnt6/vaibhavg/conda3/envs/hikemoji/lib/python3.8/site-packages/fastai/learner.py in load_learner(fname, cpu, pickle_module)
372 "Load a `Learner` object in `fname`, optionally putting it on the `cpu`"
373 distrib_barrier()
--> 374 res = torch.load(fname, map_location='cpu' if cpu else None, pickle_module=pickle_module)
375 if hasattr(res, 'to_fp32'): res = res.to_fp32()
376 if cpu: res.dls.cpu()
/nfs_storage/fs-mnt6/vaibhavg/conda3/envs/hikemoji/lib/python3.8/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
592 opened_file.seek(orig_position)
593 return torch.jit.load(opened_file)
--> 594 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
595 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
596
/nfs_storage/fs-mnt6/vaibhavg/conda3/envs/hikemoji/lib/python3.8/site-packages/torch/serialization.py in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
851 unpickler = pickle_module.Unpickler(data_file, **pickle_load_args)
852 unpickler.persistent_load = persistent_load
--> 853 result = unpickler.load()
854
855 torch._utils._validate_loaded_sparse_tensors()
AttributeError: Can't get attribute 'get_x' on <module '__main__'>
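For context, here is my understanding of what is going on. pickle, which torch.load drives under the hood (see unpickler.load() in the traceback), does not save a function's code. It saves only a reference of the form module.name, which here is __main__.get_x, and on load that name has to exist in that module. Below is a small self-contained sketch of that behaviour with plain pickle. fake_main is a hypothetical module standing in for the notebook's __main__, and no Learner is involved.

```python
import pickle
import sys
import types

# Hypothetical stand-in for the notebook's __main__ module.
fake_main = types.ModuleType("fake_main")
sys.modules["fake_main"] = fake_main

GET_X_SRC = "def get_x(row):\n    return row['human_img_path']\n"

# "Training session": define get_x in that module and pickle it.
exec(GET_X_SRC, fake_main.__dict__)
data = pickle.dumps(fake_main.get_x)
# Only the module name and the function name are stored, not the function body.
assert b"fake_main" in data and b"get_x" in data

# "Fresh session": get_x does not exist in the module, so loading fails.
del fake_main.get_x
try:
    pickle.loads(data)
    failed = False
except AttributeError as exc:
    print(exc)  # an error naming get_x, like the one in the traceback
    failed = True

# Defining get_x in the same module *before* loading lets the reference resolve.
exec(GET_X_SRC, fake_main.__dict__)
restored = pickle.loads(data)
print(restored({"human_img_path": "a.png"}))  # a.png
```

So the name has to resolve in exactly the module that was recorded at export time, and it has to be defined before the load call runs.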
If I call load_learner in the same file where the learner was exported, everything works as expected and the learner loads fine.
I have tried copying the get_x and get_y functions into the new file (as suggested in this post), but that did not help either.
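One possible workaround for an already exported file would be to intercept name lookups during unpickling with pickle.Unpickler.find_class and point __main__.get_x at a module that actually defines it. This is a minimal sketch with plain pickle, assuming a hypothetical module called lips_utils. The bytes are a hand-written protocol-0 pickle that contains only a reference to __main__.get_x, standing in for what the .pkl holds. load_learner in the traceback takes a pickle_module argument, and torch.load builds its unpickler as pickle_module.Unpickler(...) (line 851 above), so passing an object whose Unpickler attribute is this class might work. I have not verified that with this fastai/torch version, so treat that part as an assumption.

```python
import io
import pickle
import sys
import types


class RemapUnpickler(pickle.Unpickler):
    """Resolve selected __main__ globals from another module instead.

    REMAP is a hypothetical mapping: (module, name) -> replacement module.
    """

    REMAP = {
        ("__main__", "get_x"): "lips_utils",
        ("__main__", "get_y"): "lips_utils",
    }

    def find_class(self, module, name):
        # Redirect mapped names; everything else is looked up normally.
        module = self.REMAP.get((module, name), module)
        return super().find_class(module, name)


# Hypothetical module that now holds the getter (built in memory for the demo).
lips_utils = types.ModuleType("lips_utils")
exec("def get_x(row):\n    return row['human_img_path']\n", lips_utils.__dict__)
sys.modules["lips_utils"] = lips_utils

# Hand-written protocol-0 pickle: GLOBAL opcode "c" for __main__.get_x, then STOP ".".
data = b"c__main__\nget_x\n."

fn = RemapUnpickler(io.BytesIO(data)).load()
print(fn({"human_img_path": "a.png"}))  # a.png
```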
In case it helps, here are my get_x and get_y definitions:
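```python
from pathlib import Path

def get_x(row):
    return Path(row['human_img_path'])

def get_y(row):
    # get_lips_params_avatar_imagename is my own helper, defined elsewhere in the notebook
    return (get_lips_params_avatar_imagename(row['winner']),
            get_lips_params_avatar_imagename(row['loser']),
            row['round_num'])
```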
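If I understand the mechanism correctly, a more robust setup going forward would be to keep get_x and get_y in a separate .py file and import them in both the training notebook and the loading script, so the exported pickle records that module instead of __main__. A sketch, assuming a hypothetical file name lips_transforms.py, written to a temporary directory here only so the snippet is self-contained. This would only help for learners exported after the change; the existing .pkl still refers to __main__.get_x.

```python
import importlib
import pickle
import sys
import tempfile
from pathlib import Path

# Contents of a hypothetical helper file, lips_transforms.py. In a real project it
# would simply sit next to the notebook; get_y (and get_lips_params_avatar_imagename)
# would go in the same file.
SRC = '''
from pathlib import Path

def get_x(row):
    return Path(row["human_img_path"])
'''

# Write the module to a temporary directory and make it importable (demo only).
tmp_dir = tempfile.mkdtemp()
(Path(tmp_dir) / "lips_transforms.py").write_text(SRC)
sys.path.insert(0, tmp_dir)
importlib.invalidate_caches()

from lips_transforms import get_x  # both training and loading code would do this

data = pickle.dumps(get_x)
print(get_x.__module__)  # lips_transforms, not __main__
restored = pickle.loads(data)
print(restored({"human_img_path": "a.png"}))  # a.png
```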
Any help with this would be much appreciated.