Unexpected behavior while storing data using a Callback

I was looking for a way to store the evolving weight values during training (see here). I got help and it works, but I have stumbled upon unexpected behavior when I try to access the saved data.
from fastai.vision import *   # fastai v1; also brings torch, Learner and Callback into scope

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)

class MyCallback(Callback):
    def __init__(self, learn:Learner):
        self.imparo = learn
    def on_train_begin(self, **kwargs):
        self.imparo.WholeStuff = []
    def on_batch_begin(self, num_batch, **kwargs):
        fileName = f'file_{str(num_batch)}.pt'
        temp = self.imparo.model[1][0].weight.data
        torch.save(temp, fileName)
#        self.imparo.WholeStuff.append(torch.load(fileName))
        self.imparo.WholeStuff.append(temp)

learn = Learner(data, simple_cnn((3,16,16,2)), callback_fns=MyCallback)
learn.fit(1, 0.1)

Disappointingly, every tensor stored in WholeStuff is exactly the same.


That would mean that training (200 batches) has no effect on the weight tensor at all! But I do see the loss improve during training.
The problem is solved if I use the slightly changed version below: I save the weights to disk, load them again, and append the loaded tensor to the attribute WholeStuff. And now it works! I can see the effect of training on the weights: they change from batch to batch, as expected.

class MyCallback(Callback):
    def __init__(self, learn:Learner):
        self.imparo = learn
    def on_train_begin(self, **kwargs):
        self.imparo.WholeStuff = []
    def on_batch_begin(self, num_batch, **kwargs):
        fileName = f'file_{str(num_batch)}.pt'
        temp = self.imparo.model[1][0].weight.data
        torch.save(temp, fileName)
        self.imparo.WholeStuff.append(torch.load(fileName))   # I exchanged the comments
#        self.imparo.WholeStuff.append(temp)                  #  only on these two lines

It seems that appending the tensor directly to self.imparo.WholeStuff is somehow wrong. If I take a detour and use the weights loaded back from the hard disk, everything works as expected.
I have the feeling I am overlooking something.
Can you please help me to understand this behavior?

I believe it is a problem of reference versus value: in the first case you are storing a reference to the weights, so when the weights are updated in place, every stored entry points to the updated values.
Saving to disk and loading again, on the other hand, ensures that you hold a value (an independent copy) and not a reference.

Using something like temp = temp.clone().detach() should solve the problem.

(I have not checked whether this is truly the source of your bug, but you have all the usual symptoms.)
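
The reference-versus-value explanation can be reproduced without fastai. The snippet below is a minimal sketch, not the original callback: the variable names are hypothetical, a three-element parameter stands in for the layer weights, an in-place add_ stands in for the optimizer step, and an in-memory buffer stands in for the file on disk.

```python
import io
import torch

# Hypothetical stand-in for self.imparo.model[1][0].weight
weight = torch.nn.Parameter(torch.zeros(3))

by_reference, by_clone, by_disk = [], [], []
for step in range(3):
    # .data returns a tensor that shares storage with the live weights
    by_reference.append(weight.data)
    # clone() allocates new storage, so this is a true snapshot
    by_clone.append(weight.data.clone().detach())
    # a save/load round trip also produces a tensor with its own storage
    buffer = io.BytesIO()
    torch.save(weight.data, buffer)
    buffer.seek(0)
    by_disk.append(torch.load(buffer))
    # stand-in for an optimizer step, which updates parameters in place
    with torch.no_grad():
        weight.add_(1.0)

print([t.tolist() for t in by_reference])  # every entry shows the final weights
print([t.tolist() for t in by_clone])      # one distinct snapshot per step
print([t.tolist() for t in by_disk])       # same as the cloned snapshots
```

Only the first list collapses to identical entries, which matches the symptom described in the question; the cloned and the reloaded tensors both keep the per-step values.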

It works!
With some hindsight, it is very similar to deep and shallow copy in Python…
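
The analogy with plain Python can be sketched as follows. This is only an illustration with hypothetical names: a nested list stands in for the weight matrix and an in-place increment stands in for a training step.

```python
import copy

# Hypothetical stand-in for a weight matrix: a list containing one row
weights = [[0.0, 0.0]]

history_alias, history_shallow, history_deep = [], [], []
for step in range(3):
    history_alias.append(weights)                  # same object every time
    history_shallow.append(copy.copy(weights))     # new outer list, shared inner row
    history_deep.append(copy.deepcopy(weights))    # fully independent copy
    weights[0][0] += 1.0                           # in-place "training step"

print(history_alias)    # all entries show the final state
print(history_shallow)  # also the final state, because the inner row is shared
print(history_deep)     # one distinct state per step
```

Only the deep copies preserve the history, which plays the same role as clone() for tensors.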

Thanks a lot, nestorDemeure!

It's the same bug :) (and a common one in Python…)