Fix batch of size 1 raises exception during batchnorm in training

@jeremy Continuing our conversation from

For certain unlucky combinations of batch size (bs) and the length of the training set it can happen that training fails with the exception shown in this notebook (the notebook reproduces the error reliably, just make sure you run it from start to finish after restarting the kernel). The cause of the error is a batch of size 1 being fed into the network which doesn’t work with the batchnorm layers.

The training dataset has 65 elements and bs is set to 64. In that case one batch will be fed into the network with size 64 and the next (last) batch will be fed with size 1.
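The arithmetic behind this can be sketched in a few lines of plain Python. This is only an illustration; `batch_sizes` is a hypothetical helper, not part of fastai, and it assumes the loader keeps the last partial batch.

```python
def batch_sizes(n_items, bs):
    """Sizes of the batches produced when n_items are chunked into
    batches of bs and the last partial batch is kept."""
    full, rest = divmod(n_items, bs)
    return [bs] * full + ([rest] if rest else [])

print(batch_sizes(65, 64))   # [64, 1]  -> the last batch has a single item
print(batch_sizes(128, 64))  # [64, 64] -> no partial batch at all
```

Any training set whose length satisfies `n_items % bs == 1` ends with a batch of size 1.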

This error can appear (almost) randomly, since the dataset is randomly split into train and validation sets and the number of items in each set isn’t necessarily the same between subsequent instantiations of ImageDataBunch. However, one is more likely to run into this error when using small batch sizes.

Steps towards a solution:
This pull request fixes this issue by removing the last element from the training set if the combination of bs and the length of the training set would result in a batch of size 1. Of course, removing a training sample isn’t necessarily the best solution, but the other solutions that I considered also had downsides. In particular I considered:

  • Passing bs into the random_split method and rerunning the random split until len(is_train) % bs != 1. This has the downside of being a brute-force method, and bs isn’t even available when random_split is called (the default bs is set in ImageDataBunch.create). One could specify the default value for bs earlier (e.g. in the from_lists method), but there are other methods that call ImageDataBunch.create, and the default value of bs would no longer be defined in a central location. One could make bs a class variable, but the current library design seems to minimize the use of class variables whenever possible.

  • One could be more strategic about doing the train and validation split, i.e. as a first step partition an array of len(arrs[0]) such that the train size satisfies certain criteria (at the very least that the last batch has size > 1, but maybe even size > some number N). But this again raises the need to change where the default value of bs is set (see above).
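A rough sketch of the trimming approach and of the "strategic split" alternative, in plain Python. Both function names and their signatures are hypothetical illustrations, not the code in the pull request and not fastai API.

```python
import random

def drop_last_if_single(items, bs):
    """Drop the final item when the last batch would hold exactly one item."""
    items = list(items)
    if bs > 1 and len(items) > 1 and len(items) % bs == 1:
        return items[:-1]
    return items

def split_avoiding_single_batch(n_items, valid_pct, bs, seed=None):
    """Shuffle indices, then pick a train size whose last batch is not of size 1."""
    rng = random.Random(seed)
    idxs = list(range(n_items))
    rng.shuffle(idxs)
    n_train = n_items - int(n_items * valid_pct)
    if bs > 1 and n_train > 1 and n_train % bs == 1:
        # The extra sample goes to the validation set instead of being discarded.
        n_train -= 1
    return idxs[:n_train], idxs[n_train:]

print(len(drop_last_if_single(range(65), bs=64)))  # 64
train, valid = split_avoiding_single_batch(81, valid_pct=0.2, bs=64, seed=0)
print(len(train), len(valid))  # 64 17
```

The second function keeps every sample by moving the odd one into the validation set, but it needs bs at split time, which is exactly the design problem described above.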

@jeremy I know you also requested some tests to show the issue. If you want I can write some, but that would require me to add a new dataset to the test datasets (since the mnist_tiny dataset already has a train and validation folder). Thoughts?


Many thanks for this! Tests would be great - there’s no need to add any data to the repo. Simply generate a dataset from (e.g.) range(n) or similar. Let us know if you need any help with this.

Hmm, what exactly do you want me to write tests for? That the data generator for image data never returns a batch of size 1? That I could do with range(n). But if you want an integration test of whether the learner is able to train, I wouldn’t know how to do that with a simple list like range(n).
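For the first kind of test, something along these lines might work. It is a sketch only: `iter_batches` is a plain-Python stand-in for the real data loader, and all names here are hypothetical.

```python
def iter_batches(items, bs):
    """Stand-in for a data loader: yield consecutive chunks of at most bs items."""
    items = list(items)
    for start in range(0, len(items), bs):
        yield items[start:start + bs]

def find_single_item_batches(n, bs):
    """Return every batch of size 1 produced from range(n)."""
    return [b for b in iter_batches(range(n), bs) if len(b) == 1]

def test_single_item_batch_is_detected():
    # 65 items with bs=64 leaves one trailing item in its own batch.
    assert find_single_item_batches(65, 64) == [[64]]
    # 128 items with bs=64 splits evenly, so nothing is reported.
    assert find_single_item_batches(128, 64) == []

test_single_item_batch_is_detected()
```

A real test would replace `iter_batches` with the library's own batching and assert that the returned list is empty after the fix.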

Other people have run into this issue as well. There is an issue on the PyTorch issue tracker:

The PyTorch folks say that it should fail. However, I don’t see the reason. Using a running average should prevent the problem with the zero from appearing in all but the first batch. Also, this check is only enforced for BatchNorm1d; BatchNorm2d works fine (or at least it doesn’t throw an error; I checked that using the integration tests).
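To make "the problem with the zero" concrete, here is a small plain-Python sketch of the normalization step for a single channel. It assumes that in training mode the batch is normalized with its own mean and (biased) variance, with running averages only being updated for later use at evaluation time; `batchnorm_train` is a hypothetical helper, not PyTorch code, and it leaves out the learnable scale and shift.

```python
def batchnorm_train(values, eps=1e-5):
    """Normalize one channel of a batch using the batch's own statistics."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n  # biased variance
    return [(v - mean) / (var + eps) ** 0.5 for v in values]

print(batchnorm_train([3.7]))       # [0.0]
print(batchnorm_train([-100.0]))    # [0.0]
print(batchnorm_train([1.0, 3.0]))  # two values of opposite sign, close to -1 and 1
```

Under that assumption, a single value per channel always normalizes to zero, whatever the input was, so the layer passes no information forward for that batch.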

Also, fastai v0.7 ran into this issue in the past:

@jeremy Who is right here? The pytorch developers saying that a batch with size 1 should fail? Or does the running average prevent problems? And how should fastai handle this?

The more I google, the more posts I find about this, e.g. "Understanding code - error Expected more than 1 value per channel when training".

This does trip people up every once in a while. It can be especially annoying with the new fastai v1 library, since the sizes of the train and validation sets vary between calls to random_split, which makes this error unpredictable.

Jeremy, I am happy to implement whatever solution you think is best, but I believe this does require a design choice. Shall we throw an error explaining what is going on and let the user handle it? Or shall we silently remove the offending batch?

Specifically, what I’d need is a test that fails due to this issue. Then I’ll try to make it pass! :slight_smile: The test should run quickly.

OK, I created a test that fails due to the batch-size-of-1 issue: