Adding a callback that shouldn't do anything increases validation error but not training error, not sure why

Hi there,

I wanted to add a callback to the learner that runs the model a few extra times per batch, and noticed that it was causing poorer results on the validation set but not the training set. Even after I removed every part of the callback that interacted with the actual training process, the problem remained. Here’s the most pared-down version I could create that reproduces the phenomenon:

!pip install -Uqq fastbook
import fastbook
import scipy.stats
fastbook.setup_book()
from fastbook import *

path = untar_data(URLs.IMAGENETTE_320)
Path.BASE_PATH = path
path2 = (path/"train")
img_size = 128
pets = DataBlock(blocks = (ImageBlock, CategoryBlock),
                 get_items=get_image_files, 
                 splitter=GrandparentSplitter(train_name='train', valid_name='val'),
                 get_y=parent_label,
                 item_tfms=Resize(260),
                 batch_tfms=aug_transforms(size=img_size, min_scale=0.75))
dls = pets.dataloaders(path, num_workers=0, bs=32)

#baseline result:
resmodel = resnet18(pretrained=False)
reslearn = Learner(dls, resmodel, loss_func=None, metrics=accuracy,
                   opt_func=Adam)
reslearn.fit_one_cycle(5, 0.001)

#this gets a final training loss of 0.71, valid loss of 0.68, and accuracy of 78%

class UselessLoop(Callback):
    def before_batch(self):
        for n in range(10):
            throwaway_grad = self.model(self.x[0].unsqueeze(dim=0))[:, 5]  
#here I'm calling the model on just one example from the batch in order to speed things up; 
#the effect seems the same whether you call it on one example or on the whole minibatch

resmodel = resnet18(pretrained=False)
reslearn = Learner(dls, resmodel, loss_func=None, metrics=accuracy, cbs = [UselessLoop],
                   opt_func=Adam)
reslearn.fit_one_cycle(5, 0.001)

#this gets a final training loss of 0.74, validation loss of 1.57, and accuracy of 60%

The more times the callback calls the model, the worse the generalization gets: if it only loops once, the difference is minor, and if it loops 100 times, the gap is very large. The training loss always seems unaffected, or at least not affected by much. What’s going on here? This has to be by far the weirdest problem I’ve run into.

The issue is that your UselessLoop callback is still interacting with the training process.

class UselessLoop(Callback):
    def before_batch(self):
        for n in range(10):
            throwaway_grad = self.model(self.x[0].unsqueeze(dim=0))[:, 5]  

Without wrapping the call in a torch.no_grad() context manager, your callback is still accumulating gradients.

From the Zeroing out gradients in PyTorch recipe:

torch.Tensor is the central class of PyTorch. When you create a tensor, if you set its attribute .requires_grad as True, the package tracks all operations on it. This happens on subsequent backward passes. The gradient for this tensor will be accumulated into .grad attribute. The accumulation (or sum) of all the gradients is calculated when .backward() is called on the loss tensor.
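
To see that accumulation in isolation, here is a tiny plain-PyTorch example (nothing fastai-specific, just illustrating the quoted behaviour):

import torch

w = torch.ones(3, requires_grad=True)

loss = (w * 2).sum()
loss.backward()
print(w.grad)  # tensor([2., 2., 2.])

# without zeroing first, the next backward() adds to the existing .grad
loss = (w * 2).sum()
loss.backward()
print(w.grad)  # tensor([4., 4., 4.])

w.grad.zero_()  # roughly what opt.zero_grad() does for each parameter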

Use torch.no_grad, and perhaps set the model to eval mode if you want to turn off dropout (don’t forget to set it back to train mode), and it should train normally even with the UselessLoop callback.
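
Something along these lines (an untested sketch) is what I mean:

class UselessLoop(Callback):
    def before_batch(self):
        # optionally also self.model.eval() here / self.model.train() at the end to turn off dropout
        with torch.no_grad():  # don't track these extra forward passes for autograd
            for n in range(10):
                throwaway = self.model(self.x[0].unsqueeze(dim=0))[:, 5]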

The issue still persists even if I change it to:

class UselessLoop(Callback):
    def before_batch(self):
        for n in range(10):
            with torch.no_grad():
                throwaway_grad = self.model(self.x[0].unsqueeze(dim=0))

And calling backward() and zero_grad() on it like this also doesn’t fix the problem:

class UselessLoop(Callback):
    def before_batch(self):
        if not self.learn.training: return
        for n in range(10):
            throwaway_grad = self.model(self.x[0].unsqueeze(dim=0))[:, 5]
            throwaway_grad.backward()
            self.opt.zero_grad()

So I don’t think that’s the issue.

Edit: Actually, adding .backward() and .zero_grad() did reduce the generalization gap quite a bit, although it was still worse than the baseline run. Also, calling the model on the whole minibatch instead of just one example from it seems to fix the problem almost entirely (though not quite, I think), even though I said in the OP that it didn’t (I thought I’d tested that, but I guess I hadn’t).

Sorry, I forgot that BatchNorm updates its running statistics during the forward pass. So the callback is still affecting the model, especially since the batchnorm statistics are being updated on the same single item 10 times per batch. That is probably the primary problem, since the training loss is similar.
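
You can verify this with a quick standalone check (plain torchvision resnet18, just to show the running statistics moving even under no_grad):

import torch
from torchvision.models import resnet18

m = resnet18(pretrained=False)
x = torch.randn(1, 3, 128, 128)

before = m.bn1.running_mean.clone()
m.train()
with torch.no_grad():  # no autograd graph at all, yet...
    m(x)
print(torch.allclose(before, m.bn1.running_mean))  # False: the running stats still moved

m.eval()
before = m.bn1.running_mean.clone()
with torch.no_grad():
    m(x)
print(torch.allclose(before, m.bn1.running_mean))  # True: eval mode freezes the stats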

Setting the model to eval prevents BatchNorm from updating its running statistics during those extra forward passes (eval mode doesn’t stop autograd tracking, but since nothing calls backward() on those outputs, no gradients accumulate from them anyway).

class UselessLoop(Callback):
    run_valid=False  # only run this callback during training batches
    def before_batch(self):
        self.model.eval()  # freeze batchnorm running stats (and dropout) for the extra passes
        for n in range(10):
            a = self.model(self.x[0].unsqueeze(dim=0))[:, 5]
        self.model.train()  # restore training behaviour before the real forward pass

Then, where get_dls is a function that creates the dataloaders, you can run this code with and without the UselessLoop callback and the results will be unchanged.

with no_random():
    dls = get_dls(size=128, bs=64)
    learn = Learner(dls, xse_resnet18(n_out=dls.c), metrics=accuracy).to_fp16()
    learn.fit_one_cycle(2, 3e-3)
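
For the run with the callback, the same snippet just passes it in via cbs:

with no_random():
    dls = get_dls(size=128, bs=64)
    learn = Learner(dls, xse_resnet18(n_out=dls.c), metrics=accuracy, cbs=[UselessLoop]).to_fp16()
    learn.fit_one_cycle(2, 3e-3)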

Oh wow, I had no idea batchnorm did that. Thanks, that solves the issue I was having!