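I did some research and overrode the `save` and `load` methods with the code below, so that a model can be saved to and loaded from any directory (for example, a mounted Google Drive folder).
Does anyone have comments? @jeremy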
def custom_path_save(self, name:PathOrStr, path='', return_path:bool=False, with_opt:bool=True):
    """Save model and optimizer state (if `with_opt`) to `{name}.pth`.

    Replacement for `Learner.save` that can also write to an arbitrary directory.

    Args:
        name: base file name; `.pth` is appended.
        path: destination directory. `''` (the default) or `None` means
            `self.path/self.model_dir`, as in the stock method. A custom
            directory is created if it does not exist yet.
        return_path: if True, return the location of the written file.
        with_opt: also store the optimizer state. Forced to False when the
            learner has no `opt` attribute, since there is nothing to store.

    Returns:
        The `pathlib.Path` of the saved file if `return_path`, otherwise None.

    NOTE: `path` is the second parameter, so positional calls written for the
    stock `save(name, return_path, with_opt)` would bind to `path`; pass the
    flags by keyword.
    """
    from pathlib import Path
    if path:
        folder = Path(path)
        # torch.save does not create missing parent directories.
        folder.mkdir(parents=True, exist_ok=True)
    else:
        folder = self.path/self.model_dir
    target = folder/f'{name}.pth'
    # `self.opt` may not exist yet (no optimizer created); reading it would raise.
    if not hasattr(self, 'opt'): with_opt = False
    model_state = get_model(self.model).state_dict()
    state = {'model': model_state, 'opt': self.opt.state_dict()} if with_opt else model_state
    torch.save(state, target)
    if return_path: return target
def custom_path_load(self, name:PathOrStr, path='', device:torch.device=None, strict:bool=True, with_opt:bool=None):
    """Load model and optimizer state (if `with_opt`) from `{name}.pth` onto `device`.

    Replacement for `Learner.load` that can also read from an arbitrary directory.

    Args:
        name: base file name; `.pth` is appended.
        path: source directory. `''` (the default) or `None` means
            `self.path/self.model_dir`, as in the stock method.
        device: `map_location` for `torch.load`; defaults to `self.data.device`.
        strict: forwarded to the model's `load_state_dict`.
        with_opt: restore the optimizer state when the file contains one
            (`None` is treated as True). If True but the file holds only model
            weights, a warning is issued.

    Returns:
        `self`, so calls can be chained.
    """
    from pathlib import Path
    if device is None: device = self.data.device
    folder = Path(path) if path else self.path/self.model_dir
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    state = torch.load(folder/f'{name}.pth', map_location=device)
    if set(state.keys()) == {'model', 'opt'}:
        # Checkpoint that bundles model weights with the optimizer state.
        get_model(self.model).load_state_dict(state['model'], strict=strict)
        if ifnone(with_opt, True):
            # create_opt is presumably what sets `self.opt`; its return value is unused.
            if not hasattr(self, 'opt'): self.create_opt(defaults.lr, self.wd)
            # Best effort: a mismatched optimizer state must not abort loading the
            # weights, but the failure should not go unnoticed either.
            try: self.opt.load_state_dict(state['opt'])
            except Exception as e:
                warn(f"Could not restore the optimizer state ({e!r}); keeping the current optimizer.")
    else:
        # Weights-only checkpoint: `state` is the model's state_dict itself.
        if with_opt: warn("Saved file doesn't contain an optimizer state.")
        get_model(self.model).load_state_dict(state, strict=strict)
    return self
# Bind the path-aware functions to this one Learner instance, shadowing its own
# save/load; the Learner class itself is left untouched.
learn.save, learn.load = (custom_path_save.__get__(learn),
                          custom_path_load.__get__(learn))
# Alternative that keeps the stock methods and adds the new ones under other names:
#   learn.custom_path_save = custom_path_save.__get__(learn)
#   learn.custom_path_load = custom_path_load.__get__(learn)

# Folder (here a mounted Google Drive directory) that holds the checkpoint file.
model_path = '/content/gdrive/My Drive/fastai-v3/data/'
learn.save('new-model-name', path=model_path)
learn.load('new-model-name', path=model_path)