Automatic (involuntary) change of batch size during training

Hi all, while trying to troubleshoot a bug I added

print(len(xb))
print(xb[0].shape)

after line 17 in basic_train.py. If you then run the frozen resnet34 training ([In] 13 in the notebook), the output shows the batch size going all the way up to 128. You can scroll to the bottom of the output to see it, but it also happens in the middle. Why is that?
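As an aside, here is a framework-free sketch of the same debugging idea, in case you'd rather not edit the library source: wrap the batch iterator and log each batch's structure as it passes through. The loader and shapes below are stand-ins, not fastai objects.

```python
# Hypothetical sketch (not fastai code): wrap a batch iterator to log each
# batch's structure instead of editing basic_train.py in place.

def logged_batches(loader):
    """Yield batches unchanged, printing len(xb) and the shape stand-in."""
    for xb in loader:
        print(len(xb))   # 1: xb is a one-element list holding the input tensor
        print(xb[0])     # stand-in for xb[0].shape, e.g. torch.Size([64, 3, 224, 224])
        yield xb

# Fake loader: each batch is a one-element list holding a shape tuple.
fake_loader = [[(64, 3, 224, 224)], [(64, 3, 224, 224)], [(63, 3, 224, 224)]]

seen = [xb[0] for xb in logged_batches(fake_loader)]
```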

Here is the output:

1
torch.Size([64, 3, 224, 224])
1
torch.Size([64, 3, 224, 224])
... (the [64, 3, 224, 224] line repeats for all 91 full training batches) ...
1
torch.Size([63, 3, 224, 224])
1
torch.Size([128, 3, 224, 224])
1
torch.Size([128, 3, 224, 224])
... (the [128, 3, 224, 224] line repeats for all 11 full validation batches) ...
1
torch.Size([95, 3, 224, 224])
... (the same pattern — 92 training batches ending in 63, then 12 validation batches at 128 ending in 95 — repeats for each of the four epochs) ...
Total time: 04:08
epoch  train loss  valid loss  error_rate
1      1.136672    0.301131    0.085828    (01:06)
2      0.491456    0.243652    0.073852    (01:02)
3      0.312260    0.235624    0.073187    (01:00)
4      0.230828    0.213188    0.066534    (00:59)

Ah, never mind. Shortly after writing this up, it dawned on me: this only happens during inference (the validation pass). It seems fastai is increasing the batch size during inference.


I just noticed this as well.
Here is where it happens: https://github.com/fastai/fastai/blob/3a9d07a00ff46d73d1a84df8cfbb4e5efbda90a8/fastai/vision/data.py#L291
Does anyone know the reason for this?

When you’re doing inference on the validation set, your model isn’t keeping track of gradients (because you’re not training). This frees up a ton of GPU memory. fastai takes advantage of this by doubling the batch size for the validation set. This lets you process the validation dataset that much faster.
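To make the arithmetic concrete, here is a small sketch of how the doubled validation batch size lines up with the logged output. The dataset sizes below are inferred from the log (92 training batches ending in 63, 12 validation batches ending in 95) and are assumptions, not values read from the notebook; the doubling itself is what the linked data.py line implements.

```python
def batch_sizes(n_items, bs):
    """Sizes of the batches a non-dropping DataLoader would produce."""
    n_full, remainder = divmod(n_items, bs)
    return [bs] * n_full + ([remainder] if remainder else [])

train_bs = 64
valid_bs = train_bs * 2          # fastai doubles bs for the validation set

# Sizes consistent with the logged output above (assumed, not confirmed):
n_train = 91 * 64 + 63           # = 5887 training images
n_valid = 11 * 128 + 95          # = 1503 validation images

print(batch_sizes(n_train, train_bs)[-2:])   # [64, 63]
print(batch_sizes(n_valid, valid_bs)[-2:])   # [128, 95]
```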


Thank you, this makes sense!