Unexpected behavior while storing data using a Callback

I was looking for a way to store the evolving weight values during training (see here). I got help and it works, but I have stumbled upon an unexpected behavior when I try to access the saved data.
Here is my code:

from fastai.vision import *        # fastai v1: untar_data, URLs, ImageDataBunch, Learner, simple_cnn, Callback
from dataclasses import dataclass
import torch

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)

@dataclass
class MyCallback(Callback):
    def __init__(self, learn:Learner):
        super().__init__()
        self.imparo = learn
    
    def on_train_begin(self, **kwargs):
        self.imparo.WholeStuff = []
    
    def on_batch_begin(self, num_batch, **kwargs):
        fileName = f'file_{str(num_batch)}.pt'
        temp = self.imparo.model[1][0].weight.data    # weight tensor of one layer
        torch.save(temp, fileName)                    # snapshot on disk
#        self.imparo.WholeStuff.append(torch.load(fileName))
        self.imparo.WholeStuff.append(temp)           # snapshot kept in memory

learn = Learner(data, simple_cnn((3,16,16,2)), callback_fns=MyCallback)
learn.fit(1, 0.1)

Now I have access to my stored weights via:

learn.WholeStuff[0]

Disappointingly, this gives exactly the same weight tensor as

learn.WholeStuff[200]

That would mean the training (after 200 batches) has no effect on the weight tensor at all! Yet I do see the loss improving during training.
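For instance, a quick check along these lines (my own addition, not part of the callback) confirms that the two entries are identical:

torch.equal(learn.WholeStuff[0], learn.WholeStuff[200])   # returns True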
The problem is solved if I use the slightly changed version below: I save the weights to disk, load them back, and append the loaded copy to the WholeStuff attribute. With this version I can see the effect of the training on the weights: they change from batch to batch, as expected.

@dataclass
class MyCallback(Callback):
    def __init__(self, learn:Learner):
        super().__init__()
        self.imparo = learn
    
    def on_train_begin(self, **kwargs):
        self.imparo.WholeStuff = []
    
    def on_batch_begin(self, num_batch, **kwargs):
        fileName = f'file_{str(num_batch)}.pt'
        temp = self.imparo.model[1][0].weight.data
        torch.save(temp, fileName)
        self.imparo.WholeStuff.append(torch.load(fileName))   # I exchanged the comments
#        self.imparo.WholeStuff.append(temp)                  #  only on these two lines

It seems that appending the tensor directly to self.imparo.WholeStuff is somehow wrong, while taking the detour through the hard disk makes everything work as expected.
I have the feeling I am overlooking something.
Can you please help me to understand this behavior?
Thanks!

I believe it is a problem of reference versus value: in the first case you are storing a reference to the weight tensor, so when the weights are updated in place, every stored entry ends up showing the new, updated values.
Saving to disk and loading back, on the other hand, ensures that you store a value (an independent copy) and not a reference.
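Here is a minimal sketch of the effect in plain PyTorch (my own illustration, outside of fastai):

import torch

w = torch.zeros(3)        # stands in for the model's weight tensor
history = [w]             # stores a reference to the same tensor object
snapshot = [w.clone()]    # stores an independent copy

w.add_(1.0)               # in-place update, like an optimizer step

print(history[0])         # tensor([1., 1., 1.]) -- the stored "old" entry changed too
print(snapshot[0])        # tensor([0., 0., 0.]) -- the copy is unaffected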

Using something like temp = temp.clone().detach() should solve the problem.

(I have not checked whether it is truly the source of your bug, but you have all the usual symptoms.)
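Applied to the callback above, on_batch_begin could look like this (a sketch assuming the rest of MyCallback stays as posted; the detour through the disk is no longer needed):

    def on_batch_begin(self, num_batch, **kwargs):
        # clone() copies the data and detach() cuts the copy out of the autograd
        # graph, so the stored tensor no longer aliases the live weights
        temp = self.imparo.model[1][0].weight.data.clone().detach()
        self.imparo.WholeStuff.append(temp)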


It works!
With hindsight, it is very similar to the difference between shallow and deep copies in Python…
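The same thing in plain Python, with the copy module (my own illustration):

import copy

outer = [[0, 0]]
shallow = copy.copy(outer)      # new outer list, same inner list object
deep = copy.deepcopy(outer)     # fully independent copy

outer[0][0] = 1
print(shallow[0])               # [1, 0] -- the shallow copy sees the change
print(deep[0])                  # [0, 0] -- the deep copy does not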

Thanks a lot, nestorDemeure!

It's the same bug 🙂 (and a common one in Python…)