SaveModelCallback Saves Metric Score in Filename

Hello all,

I made a small modification to `SaveModelCallback` so that the monitored metric value is included in the saved file name. It's a simple change, but I've found it extremely helpful for tracking performance while training a model. I hope this helps someone. If you have other tips, please share them! :slight_smile:

class MySaveModelCallback(TrackerCallback):
    """A `TrackerCallback` that saves the model and embeds the monitored value in the filename.

    Two saving modes, selected by `every`:
      * ``'improvement'``: after each epoch, if the monitored quantity beats the
        best value seen so far (per ``self.operator``), save ``{name}_v{value}``.
      * ``'epoch'``: save ``{name}_ep{epoch}_v{value}`` after every epoch.

    Args:
        learn: the `Learner` whose model is saved and loaded.
        monitor: name of the tracked quantity, forwarded to `TrackerCallback`.
        mode: comparison mode, forwarded to `TrackerCallback`.
        every: ``'improvement'`` or ``'epoch'``. Any other value triggers a
            warning and falls back to ``'improvement'``.
        name: filename prefix for saved checkpoints.

    NOTE(review): this class defines no ``on_train_end``, so nothing here
    reloads the best checkpoint automatically once training finishes.
    """

    def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto', every:str='improvement', name:str='bestmodel'):
        super().__init__(learn, monitor=monitor, mode=mode)
        self.every, self.name = every, name
        if self.every not in ['improvement', 'epoch']:
            warn(f'SaveModel every {self.every} is invalid, falling back to "improvement".')
            self.every = 'improvement'

    def _epoch_fname(self, epoch, value) -> str:
        "Checkpoint name used by `every='epoch'`; shared by saving and resuming so the two cannot drift apart."
        return f'{self.name}_ep{epoch}_v{value}'

    def jump_to_epoch(self, epoch:int)->None:
        """Reload the checkpoint saved at the end of epoch ``epoch - 1`` so training can resume at `epoch`.

        Only checkpoints written in ``every='epoch'`` mode carry an epoch number.
        The monitored value in the name is not known in advance, so the file is
        located by globbing ``{name}_ep{epoch-1}_v*.pth`` under
        ``learn.path/learn.model_dir``. If nothing matches, a message is printed
        and the current weights are kept. Errors raised while loading a file
        that does exist are not suppressed.
        """
        prev = epoch - 1
        # `learn.save(n)` writes `<path>/<model_dir>/<n>.pth`; match any monitored value.
        pattern = self._epoch_fname(prev, '*') + '.pth'
        candidates = list((self.learn.path/self.learn.model_dir).glob(pattern))
        if not candidates:
            print(f'Model {self.name}_ep{prev} not found.')
            return
        # If the same epoch was saved more than once (e.g. by an earlier run), prefer the newest file.
        stem = max(candidates, key=lambda p: p.stat().st_mtime).stem
        self.learn.load(stem, purge=False)
        print(f"Loaded {stem}")

    def on_epoch_end(self, epoch:int, **kwargs:Any)->None:
        "Compare the value monitored to its best score and maybe save the model."
        current = self.get_monitor_value()
        if self.every == "epoch":
            self.learn.save(self._epoch_fname(epoch, current))
        elif current is not None and self.operator(current, self.best):
            # every == "improvement": only save when the monitored value improves.
            print(f'Better model found at epoch {epoch} with {self.monitor} value: {current}.')
            self.best = current
            self.learn.save(f'{self.name}_v{current}')
3 Likes