I was looking for a way to store the evolving weight values during training (see here). I got help and it works, but I have stumbled upon unexpected behavior when I try to access the saved data.
Here is my code (fastai v1):

```python
from dataclasses import dataclass

import torch
from fastai.vision import *  # fastai v1: untar_data, URLs, ImageDataBunch, simple_cnn, Learner, Callback

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)

@dataclass
class MyCallback(Callback):
    def __init__(self, learn: Learner):
        super().__init__()
        self.imparo = learn

    def on_train_begin(self, **kwargs):
        self.imparo.WholeStuff = []

    def on_batch_begin(self, num_batch, **kwargs):
        fileName = f'file_{str(num_batch)}.pt'
        temp = self.imparo.model[1][0].weight.data
        torch.save(temp, fileName)
        # self.imparo.WholeStuff.append(torch.load(fileName))
        self.imparo.WholeStuff.append(temp)

learn = Learner(data, simple_cnn((3,16,16,2)), callback_fns=MyCallback)
learn.fit(1, 0.1)
```
I can then access the stored weights via `learn.WholeStuff[0]`. Disappointingly, this gives exactly the same weight tensor as `learn.WholeStuff[200]`.
That would mean that training (after 200 batches) has no effect on the weight tensor at all! Yet I do see the loss improve during training.
The problem goes away if I use the slightly changed version below: I save the weights to disk, load them again, and append the loaded tensor to the attribute `WholeStuff`. With this change I can see the effect of training on the weights: they change from batch to batch, as expected.
```python
@dataclass
class MyCallback(Callback):
    def __init__(self, learn: Learner):
        super().__init__()
        self.imparo = learn

    def on_train_begin(self, **kwargs):
        self.imparo.WholeStuff = []

    def on_batch_begin(self, num_batch, **kwargs):
        fileName = f'file_{str(num_batch)}.pt'
        temp = self.imparo.model[1][0].weight.data
        torch.save(temp, fileName)
        self.imparo.WholeStuff.append(torch.load(fileName))  # I swapped which of these
        # self.imparo.WholeStuff.append(temp)                # two lines is commented out
```
It seems that appending the tensor directly to `self.imparo.WholeStuff` is somehow wrong: if I take a detour and use the weights read back from the hard disk, everything works as expected.
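A minimal reproduction outside fastai may help narrow this down. It is a sketch under assumptions: plain PyTorch, a hypothetical `nn.Linear` layer standing in for `model[1][0]`, SGD standing in for fastai's training loop, and an in-memory `io.BytesIO` buffer standing in for the `.pt` files. It records the weights three ways at each step: appending `weight.data` directly, appending `weight.data.clone()`, and a `torch.save`/`torch.load` round-trip.

```python
import io

import torch

torch.manual_seed(0)

# Hypothetical stand-in for self.imparo.model[1][0]; the question uses a fastai simple_cnn.
layer = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)

aliased, cloned, reloaded = [], [], []
for step in range(3):
    temp = layer.weight.data            # shares storage with the parameter
    aliased.append(temp)                # keeps a reference to that shared storage
    cloned.append(temp.clone())         # independent snapshot in memory

    buf = io.BytesIO()                  # in-memory stand-in for the .pt file
    torch.save(temp, buf)
    buf.seek(0)
    reloaded.append(torch.load(buf))    # freshly allocated tensor read back from the buffer

    # One optimisation step; SGD updates layer.weight in place.
    loss = layer(torch.randn(8, 4)).pow(2).sum()
    opt.zero_grad()
    loss.backward()
    opt.step()

# Every "aliased" entry now holds the final weights...
print(all(torch.equal(a, layer.weight.data) for a in aliased))
# ...while the cloned and reloaded snapshots still differ between the first and last step.
print(torch.equal(cloned[0], cloned[-1]), torch.equal(reloaded[0], reloaded[-1]))
```

If this sketch reflects what happens inside the callback, `weight.data` hands back a tensor that shares memory with the parameter, the optimizer overwrites that memory in place, and so every list entry shows the latest weights. Both `clone()` and the save/load round-trip allocate new memory, which would explain why the detour through the disk appears to fix it.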
I have the feeling I am overlooking something.
Can you please help me to understand this behavior?
Thanks!