Could Learner.save return the saved path?

Currently, the Learner.save method saves the model to a fixed file:

class Learner:
    ...
    def save(self, name:PathOrStr):
        "Save model with `name` to `self.model_dir`."
        torch.save(self.model.state_dict(), self.path/self.model_dir/f'{name}.pth')

However, it doesn't return the path of the saved file. One can, of course, rebuild the path manually from the learner's properties (path and model_dir), but it would be convenient if the learner returned the path of the saved model.

For example:

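import logging, shutil
log = logging.getLogger(__name__)
learn = ...  # create and train the learner
path = learn.save('final_model')  # proposed: save returns the file path
log.info('The trained model was saved into: %s', path)
shutil.copy(path, '/var/www/host/models')  # copy the weights into the serving folder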
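Until then, the path can be rebuilt from the learner's own attributes, since save writes to `self.path/self.model_dir/f'{name}.pth'`. A minimal sketch: `saved_model_path` is a hypothetical helper, and a `SimpleNamespace` stands in for a real `Learner` (only its `path` and `model_dir` attributes are used):

```python
from pathlib import Path
from types import SimpleNamespace

def saved_model_path(learn, name: str) -> Path:
    """Hypothetical helper: rebuild the file path that `learn.save(name)` writes to."""
    # Same expression as in Learner.save: <path>/<model_dir>/<name>.pth
    return Path(learn.path) / learn.model_dir / f'{name}.pth'

# Stand-in for a trained learner; only the two attributes above matter here
learn = SimpleNamespace(path=Path('data/pets'), model_dir='models')
print(saved_model_path(learn, 'final_model'))  # data/pets/models/final_model.pth on Linux/macOS
```

The drawback is that the helper silently goes stale if save ever changes where it writes, which is why having save return the path is the more robust option.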

Probably not a big deal, just a convenience for someone who writes a lot of CLI scripts :slight_smile:


Done! (in master)

    def save(self, name:PathOrStr, return_path:bool=False)->Optional[Path]:
        "Save model with `name` to `self.model_dir`, and return path if `return_path`."
        path = self.path/self.model_dir/f'{name}.pth'
        torch.save(self.model.state_dict(), path)
        if return_path: return path

I have been playing around with a Colab setup and wanted an easy way to save just my models elsewhere, without all the data overhead. With the addition below, I can do that without many changes to paths or to the lesson code.

Might be of interest:

# Assumes fastai v1 is imported (e.g. `from fastai.basic_train import *`),
# which provides torch, PathOrStr and get_model.
def save(self, name:PathOrStr, path:PathOrStr='', return_path:bool=False, with_opt:bool=True):
    "Save model and optimizer state (if `with_opt`) with `name` to `self.model_dir`, or to the directory `path` if given."
    if not path: path = self.path/self.model_dir/f'{name}.pth'  # default location
    else: path = f'{path}/{name}.pth'                            # custom directory
    if not with_opt: state = get_model(self.model).state_dict()
    else: state = {'model': get_model(self.model).state_dict(), 'opt':self.opt.state_dict()}
    torch.save(state, path)
    if return_path: return path

from warnings import warn
# Assumes fastai v1 is imported (e.g. `from fastai.basic_train import *`),
# which provides torch, PathOrStr, get_model, ifnone and defaults.
def load(self, name:PathOrStr, path:PathOrStr='', device:torch.device=None, strict:bool=True, with_opt:bool=None):
    "Load model and optimizer state (if `with_opt`) `name` from `self.model_dir`, or from the directory `path` if given, using `device`."
    if device is None: device = self.data.device
    if not path: path = self.path/self.model_dir/f'{name}.pth'  # default location
    else: path = f'{path}/{name}.pth'                            # custom directory
    state = torch.load(path, map_location=device)  # map tensors onto the target device
    if set(state.keys()) == {'model', 'opt'}:
        # File was saved together with the optimizer state
        get_model(self.model).load_state_dict(state['model'], strict=strict)
        if ifnone(with_opt,True):
            if not hasattr(self, 'opt'): self.create_opt(defaults.lr, self.wd)
            try:    self.opt.load_state_dict(state['opt'])
            except: pass  # ignore an optimizer state that doesn't match
    else:
        # File holds only the model weights
        if with_opt: warn("Saved file doesn't contain an optimizer state.")
        get_model(self.model).load_state_dict(state, strict=strict)
    return self
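The path logic shared by both functions can be checked in isolation. A minimal sketch: `resolve_model_path` is a hypothetical name, plain arguments stand in for `self.path` and `self.model_dir`, and unlike the code above it returns a `Path` in both branches rather than a string for the custom one:

```python
from pathlib import Path

def resolve_model_path(learner_path, model_dir, name, path=''):
    """Hypothetical helper mirroring the branching in the custom save/load."""
    if not path:
        # Default: <learner path>/<model_dir>/<name>.pth
        return Path(learner_path) / model_dir / f'{name}.pth'
    # Custom directory supplied by the caller (e.g. a mounted Drive folder)
    return Path(path) / f'{name}.pth'

print(resolve_model_path('data/pets', 'models', 'stage-1'))
print(resolve_model_path('data/pets', 'models', 'stage-1', path='/content/gdrive/models'))
```

The directory in the second call is only an illustration; any writable folder works the same way.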

I have actually named my versions custom_path_save and custom_path_load, and for now I simply override the fastai methods on the learner instance:

learn.save = custom_path_save.__get__(learn)
learn.load = custom_path_load.__get__(learn)
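The `__get__` call uses Python's descriptor protocol: it returns the plain function bound to `learn`, so `self` is the learner when the method runs, and only that one instance is affected. `types.MethodType(func, obj)` is an equivalent, more explicit spelling. A minimal sketch with a hypothetical stand-in class rather than a real fastai `Learner`:

```python
import types

class FakeLearner:
    """Hypothetical stand-in for a fastai Learner."""
    def __init__(self, model_dir):
        self.model_dir = model_dir

    def save(self, name):
        return f'{self.model_dir}/{name}.pth'

def custom_path_save(self, name, path=''):
    # `self` is whatever object this function gets bound to
    return f'{path or self.model_dir}/{name}.pth'

learn = FakeLearner('models')
learn.save = custom_path_save.__get__(learn)  # bind via the descriptor protocol
print(learn.save('stage-1', path='/tmp/out'))  # /tmp/out/stage-1.pth

other = FakeLearner('models')
other.save = types.MethodType(custom_path_save, other)  # equivalent binding
print(other.save('stage-1'))  # models/stage-1.pth
```

Because the bound method is stored as an instance attribute, it shadows the class's save for that object only; new learners still get the original method.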

Anyway, just a thought.


This is super helpful.
I am currently working on a Kaggle kernel, and it doesn't seem to allow saving/loading models outside the kernel's “input” directory.
Because the kernel notebook is not in the “input” directory, I need to be able to save the model parameters to that specific folder to retrieve them.
Changing the save/load path should make this possible.

Yes, I agree: it would be great to have more flexible control over the path variable. Being able to override this parameter would be very convenient.

I don't get what's inflexible. When creating your Learner you can pass model_dir ('models' by default), and if you pass an absolute path as a Path object, it overrides the data path.
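This follows from pathlib semantics: save and load build `self.path/self.model_dir/...`, and joining onto an absolute path discards everything to its left (an absolute string behaves the same way). A minimal sketch with illustrative Kaggle-style paths; `PurePosixPath` is used only so the output is identical on every OS:

```python
from pathlib import PurePosixPath

# Illustrative paths only: a read-only data folder and a writable output folder
data_path = PurePosixPath('/kaggle/input/my-dataset')

# Relative model_dir (the default 'models'): saved under the data path
print(data_path / 'models' / 'stage-1.pth')
# /kaggle/input/my-dataset/models/stage-1.pth

# Absolute model_dir: the join discards data_path entirely
print(data_path / PurePosixPath('/kaggle/working/models') / 'stage-1.pth')
# /kaggle/working/models/stage-1.pth
```

So passing an absolute directory as model_dir when creating the Learner is enough to redirect saves away from a read-only data folder.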


Sorry, I guess it had been a while since I last updated my fastai library; I believe everything is fine now. I was just under the impression that the models/ folder path was still hardcoded.
