I have faced a somewhat similar situation.
I created a logger callback that saves training progress to a CSV file. I can then load the CSV even if the server shuts down. And if it doesn't shut down, once you reconnect you can just open a new cell, call logger.plot_metric, and you will see the results of the training (even on an older version of Jupyter).
The code is nasty and I think I wrote it half asleep. I'm not sure it works with the vanilla fastai lib, but it can be modified. Add to that periodically saving the model (or saving on best validation results) and your work is never lost.
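For example, after a reconnect it looks something like this (a minimal sketch assuming the CsvLogger class posted below and a log file called training_log.csv, which is just a made-up name):

# Reload the saved CSV and plot, no retraining needed
logger = CsvLogger('training_log.csv', mode='a')        # 'a' loads the existing log
logger.plot_metric('val_loss')                          # per-epoch validation loss
logger.plot_metric('trn_loss', callback_type='batch')   # per-batch training loss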
Also, strangely enough (and I know because I run an older version on AWS and a newer one locally), the newer version of Jupyter Notebook seems to replay sent messages if they did not reach your notebook! Meaning your notebook will update if you have it open and you reconnect!
These are the versions that I seem to be running on my own box:
jupyter 1.0.0 py36h9896ce5_0
jupyter_client 5.1.0 py36h614e9ea_0
jupyter_console 5.2.0 py36he59e554_1
jupyter_contrib_core 0.3.3 py36_0 conda-forge
jupyter_contrib_nbextensions 0.3.3 py36_0 conda-forge
jupyter_core 4.4.0 py36h7c827e3_0
jupyter_highlight_selected_word 0.1.0 py36_0 conda-forge
jupyter_latex_envs 1.3.8.2 py36_1 conda-forge
jupyter_nbextensions_configurator 0.2.8 py36_0 conda-forge
jupyterlab 0.30.6 py36h2f9c1c0_0
jupyterlab_launcher 0.6.0 py36_0
I think conda update --all should give you something similar.
For dropped ssh connections I use autossh to automatically reconnect.
Here is the code, and sorry, it is a bit rough! I quite dislike using it myself, so I should probably rewrite it.
On a related note, I also use a UPS, as the area where my server is located has power grid issues (like many other places outside of cities, even in relatively affluent countries). I was surprised how easy the UPS was to set up. Maybe at some point a post on how to do deep learning when faced with infrastructure issues might be a good idea, if there is interest.
import pandas as pd

from fastai.sgdr import Callback  # base Callback class in the old fastai lib; adjust the import to your setup


class CsvLogger(Callback):
    """Log lr and losses to a CSV after every epoch so progress survives a dead kernel."""
    def __init__(self, path, optim=None, mode='a'):
        self.path = path
        self.optim = optim
        if mode == 'a':                      # append: pick up an existing log
            self.df = pd.read_csv(path)
        elif mode == 'w':                    # write: start a fresh log
            self.df = pd.DataFrame()
        self.rows = []

    def on_batch_end(self, batch_loss):
        self.rows.append(self.get_row([batch_loss]))

    def on_epoch_end(self, metrics):
        self.rows.append(self.get_row(metrics, 'epoch'))
        # flush the buffered rows to disk once per epoch
        self.df = pd.concat([self.df, pd.DataFrame(self.rows)], ignore_index=True)
        self.df.to_csv(self.path, index=False)
        self.rows = []

    def plot_lr(self):
        ax = self.df[self.df['type'] == 'batch'].reset_index()['lr'].plot()
        ax.set(xlabel='batch', ylabel='lr')

    def plot_metric(self, metric='trn_loss', callback_type='epoch'):
        ax = self.df[self.df['type'] == callback_type].reset_index()[metric].plot()
        ax.set(xlabel=callback_type, ylabel=metric)

    def get_row(self, metrics, row_for='batch'):
        # record the current learning rate alongside the metrics
        if hasattr(self.optim, 'get_lr'):
            lr = self.optim.get_lr()
        else:
            lr = self.optim.param_groups[0]['lr']
        row = {'type': row_for, 'lr': lr, 'trn_loss': metrics[0]}
        if len(metrics) > 1:
            row['val_loss'] = metrics[1]
        for i, m in enumerate(metrics[2:]):
            row[f'm{i}'] = m
        return row
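To give an idea of how this hooks into training, here is a minimal sketch. Everything other than CsvLogger (n_epochs, train_batches, train_step, validate, optimizer, training_log.csv) is a placeholder for your own training loop, so adapt it to however you call the on_batch_end/on_epoch_end hooks.

# Minimal sketch of how the hooks get called; all names besides CsvLogger are placeholders
logger = CsvLogger('training_log.csv', optim=optimizer, mode='w')  # 'w' starts a fresh log
for epoch in range(n_epochs):
    for xb, yb in train_batches:
        loss = train_step(xb, yb)          # your forward / backward / optimizer step
        logger.on_batch_end(loss)          # records lr + training loss for this batch
    val_loss = validate()                  # your validation pass
    logger.on_epoch_end([loss, val_loss])  # appends the epoch row and flushes the CSV
logger.plot_lr()                           # learning rate recorded per batch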
Callback for saving the model:
import numpy as np


class SaveBest(Callback):
    """Save the model whenever training or validation loss improves."""
    def __init__(self, model, path, save_on_train_loss=True):
        self.model = model
        self.path = path
        self.save_on_train_loss = save_on_train_loss
        self.best_val_loss = np.inf
        self.best_trn_loss = np.inf

    def on_epoch_end(self, metrics):
        # metrics[0] is the training loss, metrics[1] (if present) the validation loss
        if self.save_on_train_loss and self.best_trn_loss > metrics[0]:
            self.best_trn_loss = metrics[0]
            self.model.save(f'{self.path}_trn.m')
        if len(metrics) > 1 and self.best_val_loss > metrics[1]:
            self.best_val_loss = metrics[1]
            self.model.save(f'{self.path}_val.m')
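To tie it together, a rough sketch of using it next to the logger: model here is whatever object exposes a save(path) method (with the old fastai Learner you would use its own save instead), and checkpoints/run1 is a made-up path.

# Hypothetical: checkpoint whenever a loss improves; 'model' must expose save(path)
saver = SaveBest(model, path='checkpoints/run1')
# ... at the end of each epoch, with the same metrics list passed to the logger:
saver.on_epoch_end([trn_loss, val_loss])   # writes checkpoints/run1_trn.m and _val.m only on improvement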