Hello all,
I modified the `SaveModelCallback` so that the monitored metric's value is included in the saved model's file name. It's a very simple change, but I've found it extremely helpful for tracking performance while training a model. I hope this helps someone — if you have other tips, please share them!
class MySaveModelCallback(TrackerCallback):
    """A `TrackerCallback` that saves the model and embeds the monitored value in the file name.

    - ``every='improvement'``: each time the monitored quantity beats the best value seen
      so far, the model is saved as ``{name}_v{value}``. Every improvement gets its own
      file, so earlier checkpoints are kept.
    - ``every='epoch'``: the model is saved after every epoch as ``{name}_ep{epoch}_v{value}``,
      or as ``{name}_ep{epoch}`` when no monitored value is available.

    Values are written with 6 significant digits (e.g. ``bestmodel_v0.693147``). This keeps
    file names short and free of ``tensor(...)`` wrappers.

    Args:
        learn: the `Learner` whose model is saved.
        monitor: name of the tracked quantity (forwarded to `TrackerCallback`).
        mode: comparison mode (forwarded to `TrackerCallback`).
        every: ``'improvement'`` or ``'epoch'``. Any other value triggers a warning and
            falls back to ``'improvement'``.
        name: prefix used for every saved file name.
    """
    def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto', every:str='improvement', name:str='bestmodel'):
        super().__init__(learn, monitor=monitor, mode=mode)
        self.every,self.name = every,name
        if self.every not in ['improvement', 'epoch']:
            warn(f'SaveModel every {self.every} is invalid, falling back to "improvement".')
            self.every = 'improvement'

    @staticmethod
    def _value_str(value:Any)->str:
        "Render a monitored value for use in a file name (numbers/0-d tensors -> 6 significant digits)."
        try: return f'{float(value):.6g}'
        # Non-numeric values: fall back to their plain string form.
        except (TypeError, ValueError): return str(value)

    def jump_to_epoch(self, epoch:int)->None:
        """Load the checkpoint saved at the end of epoch ``epoch-1`` so training can resume at ``epoch``.

        `on_epoch_end` appends the monitored value to the file name, so the exact name is not
        known in advance. This method searches the model directory for ``{name}_ep{epoch-1}.pth``
        and ``{name}_ep{epoch-1}_v*.pth`` instead.
        NOTE(review): assumes fastai v1 `Learner.save`/`load` use
        ``learn.path/learn.model_dir/{file}.pth`` — confirm for your fastai version.
        Prints a message instead of raising if nothing is found or loading fails.
        """
        prefix = f'{self.name}_ep{epoch-1}'
        model_dir = self.learn.path/self.learn.model_dir
        # `_v` right after the epoch number keeps `ep1_v*` from matching `ep10_v...`.
        candidates = sorted(p.stem for pattern in (f'{prefix}.pth', f'{prefix}_v*.pth')
                            for p in model_dir.glob(pattern))
        if not candidates:
            print(f'Model {prefix} not found.')
            return
        stem = candidates[-1]
        try:
            self.learn.load(stem, purge=False)
            print(f"Loaded {stem}")
        # Best-effort resume: report the failure rather than abort training, but do not
        # swallow KeyboardInterrupt/SystemExit as the old bare `except:` did.
        except Exception as e: print(f'Model {stem} could not be loaded: {e}')

    def on_epoch_end(self, epoch:int, **kwargs:Any)->None:
        "Compare the value monitored to its best score and maybe save the model."
        current = self.get_monitor_value()
        if self.every=="epoch":
            # `current` may be None; omit the suffix rather than writing `_vNone`.
            suffix = '' if current is None else f'_v{self._value_str(current)}'
            self.learn.save(f'{self.name}_ep{epoch}{suffix}')
        elif current is not None and self.operator(current, self.best):  # every="improvement"
            print(f'Better model found at epoch {epoch} with {self.monitor} value: {current}.')
            self.best = current
            self.learn.save(f'{self.name}_v{self._value_str(current)}')