Platform: Colab ✅

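I did some research and replaced (monkey-patched) the learner's `save` and `load` methods with the code below, so models can be saved to and loaded from any directory (for example Google Drive on Colab).
Anyone care to comment? :slight_smile: @jeremy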

def custom_path_save(self, name:PathOrStr, path:PathOrStr='', return_path:bool=False, with_opt:bool=True):
    """Save model and optimizer state (if `with_opt`) to `{name}.pth`.

    The file is written into `path` when one is given; otherwise into
    `self.path/self.model_dir`, the learner's default model directory.

    Args:
        name: file stem; `.pth` is appended.
        path: target directory. Empty string or None means the default directory.
        return_path: if True, return the `Path` of the written file.
        with_opt: if True, save `{'model': ..., 'opt': ...}`; otherwise save only
            the model's state dict.

    Returns:
        The `Path` of the saved file when `return_path` is True, else None.
    """
    from pathlib import Path  # local import so the snippet does not depend on star-imports

    # Joining with pathlib (instead of f'{path}/{name}.pth') avoids a doubled
    # separator when `path` ends in '/', and both branches now return a Path.
    folder = Path(path) if path else self.path/self.model_dir
    file_path = folder/f'{name}.pth'

    model_state = get_model(self.model).state_dict()
    state = {'model': model_state, 'opt': self.opt.state_dict()} if with_opt else model_state
    torch.save(state, file_path)
    if return_path: return file_path

def custom_path_load(self, name:PathOrStr, path:PathOrStr='', device:torch.device=None, strict:bool=True, with_opt:bool=None):
    """Load model and optimizer state (if `with_opt`) from `{name}.pth` onto `device`.

    The file is read from `path` when one is given; otherwise from
    `self.path/self.model_dir`, the learner's default model directory.

    Args:
        name: file stem; `.pth` is appended.
        path: source directory. Empty string or None means the default directory.
        device: `map_location` for `torch.load`; defaults to `self.data.device`.
        strict: forwarded to the model's `load_state_dict`.
        with_opt: whether to restore the optimizer state. None behaves like True
            when the file contains one.

    Returns:
        self, so calls can be chained.
    """
    from pathlib import Path  # local import so the snippet does not depend on star-imports

    if device is None: device = self.data.device
    folder = Path(path) if path else self.path/self.model_dir
    # torch.load unpickles the file: only load checkpoints you trust.
    state = torch.load(folder/f'{name}.pth', map_location=device)

    if set(state.keys()) == {'model', 'opt'}:
        get_model(self.model).load_state_dict(state['model'], strict=strict)
        if ifnone(with_opt, True):
            # The next line relies on create_opt populating self.opt; its return
            # value was previously bound to an unused local.
            if not hasattr(self, 'opt'): self.create_opt(defaults.lr, self.wd)
            try:
                self.opt.load_state_dict(state['opt'])
            except Exception as e:
                # Still best-effort (a mismatched optimizer must not abort the load),
                # but no longer silent, and Ctrl-C is no longer swallowed.
                warn(f"Could not restore optimizer state: {e}")
    else:
        if with_opt: warn("Saved file doesn't contain an optimizer state.")
        get_model(self.model).load_state_dict(state, strict=strict)
    return self

# Replace this learner's save/load with the path-aware versions defined above.
# Calling `fn.__get__(learn)` binds each function to `learn`, so `self` is the learner.
learn.save, learn.load = custom_path_save.__get__(learn), custom_path_load.__get__(learn)
# To keep the stock save/load untouched, attach them under new names instead:
#learn.custom_path_save = custom_path_save.__get__(learn)
#learn.custom_path_load = custom_path_load.__get__(learn)


model_path = '/content/gdrive/My Drive/fastai-v3/data/'
learn.save('new-model-name', path=model_path)
learn.load('new-model-name', path=model_path)
1 Like