I believe worrying about batchnorm is important, but having mean 0 and std 1 matters a lot for getting anything to work at all, even without batchnorm, and this comes back to proper initialization.
I wrote a little bit of code to show this concept:
```python
import numpy as np
from math import sqrt

nums = 10000
# Sums and products of independent N(0,1) samples have simple, predictable means and stds
y = np.random.normal(size=nums) * np.random.normal(size=nums)
print('Ex:%f %f' % (y.mean(), y.std()))  # mean=0, std=1
y = np.random.normal(size=nums) + np.random.normal(size=nums)
print('Ex:%f %f' % (y.mean(), y.std()))  # mean=0, std=1.4
# Here the std is about 1.4. This is expected: variances of independent values add, and std = sqrt(var),
# so sqrt(var + var) = sqrt(2 * var) = sqrt(2) * sqrt(var). sqrt(2) is about 1.4, so we can now adjust for it:
y = np.random.normal(scale=1/sqrt(2), size=nums) + np.random.normal(scale=1/sqrt(2), size=nums)
print('Ex:%f %f' % (y.mean(), y.std()))  # mean=0, std=1.0
```
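The sqrt(2) correction generalizes: adding n independent N(0,1) terms gives a std of sqrt(n), so scaling each term by 1/sqrt(n) brings the sum back to std 1. A small sketch of that, assuming NumPy's `default_rng` for a repeatable seed; the helper name `sum_stds` and the sample count are arbitrary choices for illustration:

```python
import numpy as np

def sum_stds(n, samples=20_000, seed=0):
    """Std of a sum of n independent N(0, 1) terms, without and with 1/sqrt(n) scaling."""
    rng = np.random.default_rng(seed)
    terms = rng.normal(size=(samples, n))             # each row holds n independent terms
    raw = terms.sum(axis=1).std()                     # expected: about sqrt(n)
    scaled = (terms / np.sqrt(n)).sum(axis=1).std()   # expected: about 1
    return raw, scaled

for n in (2, 10, 100):
    raw, scaled = sum_stds(n)
    print(f"n={n:3d}  raw std={raw:.3f} (sqrt(n)={np.sqrt(n):.3f})  scaled std={scaled:.3f}")
```

This is exactly the situation inside a matrix multiply, where each output element is a sum over the inner dimension.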
```python
# Now we are ready for matrix multiply! Each output element sums n products, so the weights need
# a std of 1/sqrt(n), where n is the number of terms summed (the inner dimension, per column)
# We apply the scale factor to only one of the two matrices, to mimic Kaiming/Xavier initialization
# We need an extra dimension here, because matrix multiply sums up our values (1xn @ nx1 = 1x1),
# and we want enough outputs to calculate a mean/std
nums = 100
size = (nums, nums)
std = 1/sqrt(nums)
y = np.random.normal(size=size) @ np.random.normal(scale=std, size=size)
print('Ex:%f %f' % (y.mean(), y.std()))  # mean=0, std=1.0
print(y.shape)
```
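Why 1/sqrt(n) works: each output element y[i, j] = sum over k of x[i, k] * w[k, j] is a sum of n products of independent zero-mean values, and each product has variance var(x) * var(w). So var(y) = n * var(x) * var(w), and choosing std(w) = 1/sqrt(n) makes var(y) = var(x). A quick check over a few widths, as a sketch; the helper `matmul_output_std`, the row count and the seed are illustrative assumptions:

```python
import numpy as np

def matmul_output_std(n, w_std, rows=1000, seed=0):
    """Std of x @ w for x ~ N(0, 1) with shape (rows, n) and w ~ N(0, w_std) with shape (n, n)."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(rows, n))
    w = rng.normal(scale=w_std, size=(n, n))
    return (x @ w).std()

for n in (10, 100, 1000):
    scaled = matmul_output_std(n, 1 / np.sqrt(n))   # rule applied: should be close to 1
    unscaled = matmul_output_std(n, 1.0)            # no scaling: should be close to sqrt(n)
    print(f"n={n:4d}  w_std=1/sqrt(n): {scaled:.3f}   w_std=1: {unscaled:.2f} (sqrt(n)={np.sqrt(n):.2f})")
```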
```python
# Now if we do this for more than one layer (six weight matrices in a row)
y = np.random.normal(size=size)
for _ in range(6):
    y = y @ np.random.normal(scale=std, size=size)
print('Ex:%f %f' % (y.mean(), y.std()))  # mean=0, std=1.0
```
```python
# Changing the std quickly has consequences
std = 2.0
y = np.random.normal(size=size)
for _ in range(6):
    y = y @ np.random.normal(scale=std, size=size)
print('Ex:%f %f' % (y.mean(), y.std()))  # mean=variable, std=60000000+
```
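The size of the blow-up is predictable: since var(y) = n * var(x) * var(w) for each matrix multiply, every layer multiplies the std by sqrt(n) * w_std, so after k layers the std is roughly (sqrt(n) * w_std)^k. With n=100 and w_std=2.0 that is a factor of 20 per layer and 20^6 = 64,000,000 after six layers, in line with the ~60,000,000 above. A sketch that prints measured vs predicted std after each layer (the helper name `layer_stds` and the seed are just for illustration):

```python
import numpy as np

def layer_stds(w_std, n=100, layers=6, seed=0):
    """Std of the activations after each of `layers` matrix multiplies with N(0, w_std) weights."""
    rng = np.random.default_rng(seed)
    y = rng.normal(size=(n, n))  # input ~ N(0, 1)
    stds = []
    for _ in range(layers):
        y = y @ rng.normal(scale=w_std, size=(n, n))
        stds.append(y.std())
    return stds

n = 100
for w_std in (1 / np.sqrt(n), 2.0):
    gain = w_std * np.sqrt(n)  # expected growth factor per layer
    for k, s in enumerate(layer_stds(w_std, n), start=1):
        print(f"w_std={w_std:.2f} layer {k}: measured std={s:.4g}  predicted={gain**k:.4g}")
```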
```python
# Changing the mean quickly has consequences
mean = 1.0
std = 1/sqrt(nums)
y = np.random.normal(size=size)
for _ in range(6):
    y = y @ np.random.normal(loc=mean, scale=std, size=size)
print('Ex:%f %f' % (y.mean(), y.std()))  # mean=variable, std=variable
```
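Why a nonzero weight mean blows up: with weights of mean 1, every output column y[:, j] is roughly the row sum of the input plus a little noise, so all the columns become nearly identical. The next layer then adds up n nearly equal values, multiplying the scale by roughly n per layer. A sketch checking both effects over one and two layers, using the same n=100 and weight std 1/sqrt(n) as above (the seed is arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100
x = rng.normal(size=(n, n))                                 # input ~ N(0, 1)
w = rng.normal(loc=1.0, scale=1 / np.sqrt(n), size=(n, n))  # weights with mean 1
y = x @ w

# Correlation between output columns (off-diagonal entries of the column correlation matrix)
corr = np.corrcoef(y, rowvar=False)
off_diag = corr[~np.eye(n, dtype=bool)]
print(f"mean correlation between output columns: {off_diag.mean():.3f}")
print(f"std after one layer: {y.std():.2f}")

# A second mean-1 layer sums n nearly identical columns
y2 = y @ rng.normal(loc=1.0, scale=1 / np.sqrt(n), size=(n, n))
print(f"growth in std from layer 1 to layer 2: {y2.std() / y.std():.1f}")
```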
Output of the above code (first number is the mean, second is the std; (100, 100) is the shape of the matrix multiply output):
```
Ex:-0.006755 0.995809
Ex:-0.003384 1.408405
Ex:0.011278 1.001778
Ex:0.007975 1.023256
(100, 100)
Ex:0.008841 1.039414
Ex:243008.276950 60611755.826758
Ex:-2637638925.689581 95386006060.297302
```
Explanation of results (here N(mean, std) means a normal distribution with that mean and standard deviation):
```
Ex:-0.006755 0.995809 - mean, std after N(0,1)*N(0,1)
Ex:-0.003384 1.408405 - mean, std after N(0,1)+N(0,1); std is about 1.4
Ex:0.011278 1.001778 - mean, std after N(0,1/sqrt(2))+N(0,1/sqrt(2)); std is 1.0 after the correction
Ex:0.007975 1.023256 - matrix multiply N(0,1)@N(0,1/sqrt(n)); mean 0, std 1
(100, 100) - shape of the matrix multiply output
Ex:0.008841 1.039414 - matrix multiply stays stable even after 6 layers
Ex:243008.276950 60611755.826758 - weight std of 2.0 instead of 1/sqrt(n): the std explodes
Ex:-2637638925.689581 95386006060.297302 - weight mean of 1.0 instead of 0: the mean and std both explode
```
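These numbers have gotten so big that they overflow half precision (fp16 tops out at 65504), a format commonly used for mixed-precision training on GPUs; the calculations above were done on the CPU with NumPy's default float64.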
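To check the overflow claim, this sketch casts the mean/std values reported above to float16 and float32 using NumPy's `finfo` and `astype` (`np.errstate` is only there to silence the overflow warning from the cast):

```python
import numpy as np

# Mean/std values reported in the output above
reported = np.array([243008.276950, 60611755.826758,
                     -2637638925.689581, 95386006060.297302])

print("float16 max:", np.finfo(np.float16).max)  # 65504.0
print("float32 max:", np.finfo(np.float32).max)

with np.errstate(over="ignore"):  # silence the overflow warning from the cast
    as_fp16 = reported.astype(np.float16)
as_fp32 = reported.astype(np.float32)
print(as_fp16)  # every value overflows to +/-inf
print(as_fp32)  # all still finite in float32, with reduced precision
```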
So here I have set aside all of the other deep learning reasons for layers to have N(0,1) outputs: even on its own, N(0,1) is important to keep the mean and std from exploding into completely unreasonable values.
That said, all of this really only shows that you want predictable, linear changes from one layer to the next. In the above examples it doesn’t really matter whether your original data is N(0,1) or N(0.5,2); what matters is that if you start with N(0,1) your output is N(0,1), or more generally a linear transformation of your input.
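To make the N(0.5,2) point concrete, here is a sketch that feeds input with mean 0.5 and std 2 through six 1/sqrt(n)-scaled layers like the ones above. Each layer preserves the mean square of its input, so the std settles near sqrt(0.5² + 2²) ≈ 2.06 and stays there instead of exploding. The seed is arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100
y = rng.normal(loc=0.5, scale=2.0, size=(n, n))  # input with mean 0.5, std 2
print(f"input:   mean={y.mean():.3f} std={y.std():.3f}")
for k in range(1, 7):
    y = y @ rng.normal(scale=1 / np.sqrt(n), size=(n, n))  # weights ~ N(0, 1/sqrt(n))
    print(f"layer {k}: mean={y.mean():.3f} std={y.std():.3f}")
```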
(Xavier and Kaiming initialization are basically there to make up for sigmoid and ReLU activations, respectively.)
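Also, more recent work (Santurkar et al., “How Does Batch Normalization Help Optimization?”) suggests that batch norm’s success has little to do with reducing internal covariate shift: https://arxiv.org/pdf/1805.11604.pdf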
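For contrast with fixing things through initialization, here is a minimal sketch of a batch norm forward pass (training mode, using batch statistics, with the learned scale and shift left out for simplicity) applied after every matrix multiply of a badly scaled std=2.0 stack like the one above. Normalizing resets each layer's output to mean 0 and std 1, so the bad weight scale no longer compounds:

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    """Training-mode batch norm without learned scale/shift: normalize each column over the batch."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

rng = np.random.default_rng(0)
n = 100
y = rng.normal(size=(n, n))  # rows are the batch, columns are features
for _ in range(6):
    y = y @ rng.normal(scale=2.0, size=(n, n))  # badly scaled weights, as in the std=2.0 example
    y = batch_norm(y)
print(f"mean={y.mean():.3f} std={y.std():.3f}")
```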
I believe the above paper is referenced in part 2 or part 1 at some point.