Saving and reloading a model to preserve all parameters and resume from the same point

Hi friends,
using code from lesson 11, I’m trying to save a specific checkpoint of my model and later reload it to continue from the same point, so I do:

import torch

torch.save({
    'model_state_dict': learn.model.state_dict(),
    'optimizer_state_dict': learn.opt.state,
}, f"{root_path}{name}-model.pkl")

to save the state of the model plus the fast.ai stateful optimizer,

and then reload with:

checkpoint = torch.load(f"{root_path}{name}-model.pkl")
learn.model.load_state_dict(checkpoint['model_state_dict'])
learn.opt.state = checkpoint['optimizer_state_dict']
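For comparison, here is a minimal self-contained sketch of the same save/reload round trip in plain PyTorch. It swaps in `torch.optim.SGD` with momentum for the lesson 11 stateful optimizer, a toy `nn.Linear` model, and an in-memory `io.BytesIO` buffer instead of the `.pkl` path; all of these are assumptions for illustration. One possible pitfall to check against your notebook: if the lesson 11 optimizer keeps `opt.state` as a dict keyed by the parameter tensors themselves, the dict returned by `torch.load` is keyed by new tensor objects, so the model's parameters may not find their saved state. PyTorch's built-in `Optimizer.state_dict()` avoids this by referring to parameters by index.

```python
import io

import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# Toy stand-ins (hypothetical): a tiny model and a momentum optimizer.
model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer has momentum buffers worth saving.
xb, yb = torch.randn(8, 4), torch.randn(8, 2)
loss = F.mse_loss(model(xb), yb)
loss.backward()
opt.step()
opt.zero_grad()

# Save both state dicts; an in-memory buffer replaces the .pkl file.
buf = io.BytesIO()
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': opt.state_dict(),
}, buf)

# Build fresh objects and restore them from the checkpoint.
buf.seek(0)
checkpoint = torch.load(buf)
new_model = nn.Linear(4, 2)
new_opt = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
new_model.load_state_dict(checkpoint['model_state_dict'])
new_opt.load_state_dict(checkpoint['optimizer_state_dict'])

# The restored momentum buffer is attached to new_model's own parameter.
restored_momentum = new_opt.state[new_model.weight]['momentum_buffer']
original_momentum = opt.state[model.weight]['momentum_buffer']
```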

However, I don't seem to be able to pick up right where I left off. Could this be due to the stochastic nature of these processes, or am I doing something wrong with the saving and loading?

PS: I just got it working. I was saving a range of these dictionaries and had to use a deep copy to preserve each version; now I can reload the data consistently.

Any tips? :slight_smile:
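Why the deep copy was needed when keeping several versions in memory can be sketched like this: `state_dict()` returns tensors that share storage with the live parameters, so later in-place updates also show up in any snapshot stored by reference. The toy `nn.Linear` model and the in-place `add_` standing in for an optimizer step are illustrative assumptions.

```python
import copy

import torch
from torch import nn

model = nn.Linear(3, 1)

aliased = []  # snapshots stored by reference
copied = []   # snapshots stored with copy.deepcopy

for step in range(3):
    aliased.append(model.state_dict())
    copied.append(copy.deepcopy(model.state_dict()))
    # Stand-in for an optimizer step: update the weights in place.
    with torch.no_grad():
        model.weight.add_(1.0)

# Every aliased snapshot now shows the latest weights...
all_aliased_equal_current = all(torch.equal(sd['weight'], model.weight) for sd in aliased)
# ...while each deep copy kept the weights from its own step (3, 2, 1 steps behind).
deltas = [(model.weight - sd['weight']).mean().item() for sd in copied]
```

Note that `torch.save` writes the tensor values at the moment it is called, so a deep copy taken right before `torch.save` should not change what ends up in the file; it matters when the snapshots stay in memory, as in the PS above.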

Hey @javismiles

How did you deep copy this? I'm doing the same thing (saving state_dict and opt.state), but after loading the model the validation loss is higher than it was before saving!


@shruti_01, yes, I'm doing this:

import copy

copy.deepcopy(learn.model.state_dict())
copy.deepcopy(learn.opt.state)

and it all works well. Try it out :wink:


I did this:

torch.save({
    'model_state_dict': copy.deepcopy(learn.model.state_dict()),
    'optimizer_state_dict': copy.deepcopy(learn.opt.state),
}, f"{root_path}{name}-model.pkl")

To load:

checkpoint = torch.load(f"{root_path}{name}-model.pkl")
learn.model.load_state_dict(checkpoint['model_state_dict'])
learn.opt.state = checkpoint['optimizer_state_dict']

The validation loss just before saving is 5.42.
After loading and training the model for 1 epoch, the loss is 5.47.

@shruti_01
After you load, if you want to get exactly the same results, it's very important to use the very same data, the same batch size, and the same batch content and distribution. If there is any difference in the data being processed, the loss values will come out different, so I would check and make sure that the batch content is the same before and after.
Best :wink:
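One way to make sure the batches really are the same before and after reloading is to give the shuffling its own seeded random generator. This sketch assumes PyTorch's built-in `DataLoader` with a `torch.Generator` (the lesson 11 code uses its own data loader and sampler classes), and the toy `TensorDataset` is hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset: 10 samples, labels are just the sample indices.
ds = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))

def batch_order(seed, bs=4):
    """Return the label batches produced by one shuffled epoch."""
    g = torch.Generator().manual_seed(seed)
    dl = DataLoader(ds, batch_size=bs, shuffle=True, generator=g)
    return [yb.tolist() for _, yb in dl]

run_a = batch_order(seed=42)
run_b = batch_order(seed=42)  # same seed: identical batches, in the same order
run_c = batch_order(seed=7)   # different seed: almost certainly a different order
```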

@javismiles

Umm, I am performing exactly the same training steps. When loading the model, instead of calling learn.fit directly, I first call learn.load, set a new lr and n_epochs, and then call learn.fit. Everything else stays the same: the batch_size and the data (I load a pickled databunch).

[EDIT] In fact, if I train the model for, say, 10 epochs and then resume training after setting a different lr, I get different accuracy/validation values, and sometimes, after decreasing, the validation loss goes back up for a good 10 epochs. And here I'm not even loading the model, just resuming the training!

@shruti_01

Check how the mini-batches are being sampled. If the sampling involves randomness that changes between runs, you would expect small deviations in the loss.
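When training is stopped and resumed, the position of the shuffling generator matters too, not just its seed. A hedged sketch, assuming the sampler draws its permutations from a `torch.Generator`: save the generator state with the checkpoint and restore it before resuming, and the following epochs are shuffled exactly as in an uninterrupted run. For PyTorch's global generator, the equivalent calls are `torch.get_rng_state()` and `torch.set_rng_state()`.

```python
import torch

def epoch_orders(g, n_epochs, n=8):
    """Draw one shuffled index order per epoch, as a random sampler would."""
    return [torch.randperm(n, generator=g).tolist() for _ in range(n_epochs)]

# Uninterrupted run: 4 epochs in one go.
g = torch.Generator().manual_seed(0)
uninterrupted = epoch_orders(g, 4)

# Interrupted run: 2 epochs, then save the generator state with the checkpoint...
g = torch.Generator().manual_seed(0)
first_half = epoch_orders(g, 2)
saved_state = g.get_state()

# ...later, in a fresh session, restore that state and continue.
g_resumed = torch.Generator()
g_resumed.set_state(saved_state)
second_half = epoch_orders(g_resumed, 2)

# Re-seeding instead replays the orders of epochs 1-2, not epochs 3-4.
g_reseeded = torch.Generator().manual_seed(0)
reseeded = epoch_orders(g_reseeded, 2)
```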

One thing you can do to make sure is this:

After finishing training, sample one specific batch like this:
with torch.no_grad():
    txb, tyb = get_batch(data.train_dl)

Then save it:
torch.save(txb, f"{root_path}{name}-xb.pkl")
torch.save(tyb, f"{root_path}{name}-yb.pkl")

Perform inference on that batch and see what loss you get:

learn.model.eval()  # eval mode: dropout off, batch norm uses its running stats
with torch.no_grad():
    cost = loss_func(learn.model(txb), tyb)

Then reload the whole thing from scratch, including that batch:

txb = torch.load(f"{root_path}{name}-xb.pkl")
tyb = torch.load(f"{root_path}{name}-yb.pkl")

The loss should then be exactly the same when you perform inference on it:

learn.model.eval()  # eval mode: dropout off, batch norm uses its running stats
with torch.no_grad():
    cost = loss_func(learn.model(txb), tyb)
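Putting those steps together, a self-contained sketch of the fixed-batch check might look like this. The tiny model, the randomly generated batch, `F.cross_entropy` as the loss function, and the in-memory buffers in place of the `-xb.pkl`/`-yb.pkl` files are all assumptions; the point is that with the same weights, the same batch, and the model in eval mode, the two losses match exactly.

```python
import io

import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

def make_model():
    # Hypothetical model; dropout makes train-mode outputs random.
    return nn.Sequential(nn.Linear(5, 16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 3))

model = make_model()
txb, tyb = torch.randn(8, 5), torch.randint(0, 3, (8,))

# Save the weights and the fixed batch.
weights_buf, batch_buf = io.BytesIO(), io.BytesIO()
torch.save(model.state_dict(), weights_buf)
torch.save({'xb': txb, 'yb': tyb}, batch_buf)

model.eval()
with torch.no_grad():
    cost_before = F.cross_entropy(model(txb), tyb).item()

# "Reload the whole thing": a fresh model plus the saved batch.
weights_buf.seek(0)
batch_buf.seek(0)
reloaded = make_model()
reloaded.load_state_dict(torch.load(weights_buf))
batch = torch.load(batch_buf)

reloaded.eval()
with torch.no_grad():
    cost_after = F.cross_entropy(reloaded(batch['xb']), batch['yb']).item()
```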

Got it. I'll try it when I get to inference.

Right now, though, I'm still training the model, so no inference yet.

Hey @shruti_01,
When we train, we also check how the training is going by running the model over our validation set, so that's actually what I meant: checking one of your batches in eval mode.

So just take a consistent batch, save it, and run it through your model in eval mode. Then reload everything, including that batch, run it through the model in eval mode again, and check that both results match.
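Why eval mode matters for this comparison: layers such as dropout (and batch norm, which also updates its running statistics while training) behave differently in training mode. A small sketch with a standalone `nn.Dropout` layer, chosen purely for illustration: in train mode two passes over the same input give different outputs, while in eval mode they are identical.

```python
import torch
from torch import nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
train_a, train_b = drop(x), drop(x)  # a new random mask on every call

drop.eval()
eval_a, eval_b = drop(x), drop(x)    # dropout passes the input through unchanged
```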
