Batchnorm in reinforcement learning


I’ve been working on an episodic policy learning problem using linear models as the agent, and I’ve noticed an odd dependence on batch normalization in this particular problem that I hope someone here may recognize. To see any convergence to non-trivial solutions, you have to train with batchnorm in train mode and an extremely high momentum (>0.9), i.e. a large weight on each new batch, so that the running statistics do not average too many batches together. However, if you switch to eval mode and evaluate on the training set right after training in train mode, the model usually fails entirely. Instead, after training (with learning) in train mode, you have to run a few more batches in train mode without any learning before switching to eval mode, and then everything works fine. If you train the model purely in eval mode the entire time (never enabling train mode), it fails to converge to a non-trivial solution.
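To make the mechanics concrete, here is a toy, pure-Python stand-in for a single batchnorm feature. The names `RunningBatchStats` and `recalibrate` are hypothetical, not a library API. It assumes PyTorch's convention, `running = (1 - momentum) * running + momentum * batch_stat`, which matches the description above: a momentum above 0.9 means the running stats are dominated by the last few batches. `recalibrate` mirrors the "train mode without learning" step; in PyTorch the rough equivalent would be a few forward passes after `model.train()` inside `torch.no_grad()`, followed by `model.eval()`.

```python
import statistics


class RunningBatchStats:
    """Toy single-feature model of a batchnorm layer's running statistics (hypothetical).

    Assumes PyTorch's convention: running = (1 - momentum) * running + momentum * batch_stat,
    so a momentum close to 1 makes the running stats track the most recent batches.
    """

    def __init__(self, momentum: float = 0.1, eps: float = 1e-5):
        self.momentum = momentum
        self.eps = eps
        self.running_mean = 0.0  # PyTorch initializes running_mean to 0
        self.running_var = 1.0   # and running_var to 1
        self.training = True

    def __call__(self, batch):
        if self.training:
            # Train mode: normalize with the current batch's statistics (biased variance)...
            mean = statistics.fmean(batch)
            var = statistics.pvariance(batch, mu=mean)
            # ...and fold them into the running estimates (PyTorch stores the unbiased variance).
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * statistics.variance(batch)
        else:
            # Eval mode: normalize with the stored running statistics only.
            mean, var = self.running_mean, self.running_var
        return [(x - mean) / (var + self.eps) ** 0.5 for x in batch]


def recalibrate(layer, batches):
    """Refresh running stats with forward passes in train mode (no parameter updates),
    then switch the layer to eval mode."""
    layer.training = True
    for batch in batches:
        layer(batch)
    layer.training = False
```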

My interpretation of this (which is based purely on speculation) is that, for whatever reason, the rewards are causing large updates to the batch statistics, in a way that makes averaging them out harmful. However, I don’t really know how to test this hypothesis, and I’m not sure what the solution is. Any help or thoughts would be greatly appreciated.
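One way to test the hypothesis is to record the per-batch mean (and variance) that each batchnorm layer sees during training, for example through a PyTorch forward hook registered with `register_forward_hook`, and check whether a few batches sit far away from the rest. The sketch below assumes you already have a list of per-batch means for one feature; `flag_outlier_batches` is a hypothetical helper, and the threshold of 5 is an arbitrary choice.

```python
import statistics


def flag_outlier_batches(batch_means, threshold=5.0):
    """Return indices of batches whose mean is far from the typical batch mean.

    Uses the median and the median absolute deviation (MAD) instead of the mean and
    standard deviation, because those robust statistics are not themselves dragged
    around by the outliers they are meant to detect.
    """
    center = statistics.median(batch_means)
    mad = statistics.median([abs(m - center) for m in batch_means])
    if mad == 0:
        # Almost all batches are identical; flag anything that differs at all.
        return [i for i, m in enumerate(batch_means) if m != center]
    return [i for i, m in enumerate(batch_means) if abs(m - center) / mad > threshold]
```

If the flagged batches line up with particular episodes or reward values, that would support the idea that the running statistics are being distorted by a handful of extreme batches.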

Hello! This is an interesting problem. I’m not sure how to solve it, but I would definitely like to discuss it and brainstorm approaches with you here.

First, can we get some data, graphs, or numbers to help assess the problem? For example, the loss, agent performance, etc. over time for each batch, to see how the numbers change across the train-to-eval transition.

Second, I think a phone call could go a long way toward grokking the problem (especially for me, as a new deep learning student). If you’re open to it, I’d love to chat about your project and this specific problem. :-]

Hi Charlie, thanks for the interest. It turned out that a few outliers were causing large shifts in the running mean/var. With a high momentum, each shift was quickly overwritten by the batches that followed, so most of the time when I inspected the running stats, they were not affected by the outliers. With a low momentum, the outliers were averaged into the running stats and kept a greater weight there for many more batches. Sorry for the delay.
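The explanation follows directly from the running-average update, again assuming PyTorch's convention `running = (1 - momentum) * running + momentum * batch_stat`. A single outlier batch contributes `momentum * (1 - momentum) ** k` to the running estimate after `k` further batches, so it fades within a few batches at momentum 0.9 but lingers at 0.1. The function names below are hypothetical, and the numbers are synthetic rather than taken from the original problem.

```python
def outlier_weight(momentum, batches_since):
    """Weight an outlier batch's statistic still carries in the running estimate
    after `batches_since` further updates."""
    return momentum * (1 - momentum) ** batches_since


def simulate_running_mean(momentum, batch_means, start=0.0):
    """Apply the running-average update to a sequence of per-batch means."""
    running = start
    for m in batch_means:
        running = (1 - momentum) * running + momentum * m
    return running


# Synthetic history: 20 clean batches (mean 0), one outlier (mean 100), then 5 clean batches.
history = [0.0] * 20 + [100.0] + [0.0] * 5
high = simulate_running_mean(0.9, history)  # outlier has almost completely faded
low = simulate_running_mean(0.1, history)   # outlier still shifts the running mean noticeably
```

The flip side is visible in the same formula: immediately after an outlier batch (`batches_since=0`), the outlier carries weight 0.9 at high momentum versus 0.1 at low momentum, which is consistent with the fix only holding "most of the time" and with a few clean train-mode batches before switching to eval mode flushing the outlier out.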