Could Learn.save return the saved path?

I have been playing around with my Colab setup and wanted an easy way to save just my models to a separate location, without the overhead of the data. I found that this addition to the code lets me do so without many changes to paths or to the lesson code.

This might be of interest:

def save(self, name:PathOrStr, path:PathOrStr='', return_path:bool=False, with_opt:bool=True):
        """Save model and optimizer state (if `with_opt`) with `name` to `self.model_dir`.

        If `path` is given, the file is written to `path/name.pth` instead of
        `self.path/self.model_dir/name.pth`.
        When `with_opt` is True the file holds a dict `{'model': ..., 'opt': ...}`;
        otherwise it holds only the model's state dict.
        Returns the path of the saved file when `return_path` is True, otherwise None.
        """
        from pathlib import Path
        # Build the custom location with pathlib (not string formatting) so both
        # branches yield a path object and separators are handled consistently.
        if not path: path = self.path/self.model_dir/f'{name}.pth'
        else: path = Path(path)/f'{name}.pth'
        if not with_opt: state = get_model(self.model).state_dict()
        else: state = {'model': get_model(self.model).state_dict(), 'opt':self.opt.state_dict()}
        torch.save(state, path)
        if return_path: return path

def load(self, name:PathOrStr, path:PathOrStr='', device:torch.device=None, strict:bool=True, with_opt:bool=None):
        """Load model and optimizer state (if `with_opt`) `name` from `self.model_dir` using `device`.

        If `path` is given, the file is read from `path/name.pth` instead of
        `self.path/self.model_dir/name.pth`. `device` defaults to `self.data.device`.
        Returns `self` so calls can be chained.
        """
        from pathlib import Path
        if device is None: device = self.data.device
        # Mirror `save`: use pathlib for the custom directory so both branches give a path object.
        if not path: path = self.path/self.model_dir/f'{name}.pth'
        else: path = Path(path)/f'{name}.pth'
        state = torch.load(path, map_location=device)
        # A checkpoint written with the optimizer is a dict with exactly these two keys;
        # anything else is treated as a bare model state dict.
        if set(state.keys()) == {'model', 'opt'}:
            get_model(self.model).load_state_dict(state['model'], strict=strict)
            if ifnone(with_opt,True):
                # Called for its side effect (presumably setting `self.opt`); the return value is unused.
                if not hasattr(self, 'opt'): self.create_opt(defaults.lr, self.wd)
                try: self.opt.load_state_dict(state['opt'])
                except Exception as e:
                    # Best-effort: a mismatched optimizer state must not block loading the model,
                    # but the user should be told it was skipped rather than silently ignored.
                    warn(f"Could not load the optimizer state: {e}")
        else:
            if with_opt: warn("Saved file doesn't contain an optimizer state.")
            get_model(self.model).load_state_dict(state, strict=strict)
        return self

In my own notebook I have named these functions `custom_path_save` and `custom_path_load`, and for now I simply override the fastai methods on the learner instance:

# Attach the custom functions to this particular learner instance. Calling a
# function's __get__ with an instance yields a bound method, so `self` is `learn`.
learn.save = custom_path_save.__get__(learn, type(learn))
learn.load = custom_path_load.__get__(learn, type(learn))

Anyway, just a thought.

3 Likes