I understand the reasoning for why ‘a’ is needed in the leaky ReLU case, but why not just use abs(a) instead of a ** 2?
It seems to me that a ** 2 would be less accurate than abs(a) at scaling the weights so that the outputs have a standard deviation of about 1 after the pass through the weights.
I think of (1 + a ** 2) as the proportion of the variance coming from the positive half (1) and the negative half (a ** 2) of the ReLU.
When no ReLU is used, a = 1, so the negative half is left unchanged and the "total variance" remains at 1 + 1 ** 2 = 2.
When the standard ReLU is used, a = 0, so the negative half contributes nothing, giving a "total variance" of 1 + 0 ** 2 = 1.
When a = 0.5, all negative values are halved, so the negative half's contribution becomes 0.5 * 0.5 = 0.25, making the "total variance" 1 + 0.5 ** 2 = 1.25.
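As a quick check of these numbers, here is a Monte Carlo sketch using only the Python standard library. It assumes standard normal pre-activations. Because a ReLU output is not zero-mean, it measures the second moment E[y ** 2] (which is what Kaiming-style derivations actually propagate) rather than the variance proper. Each half of N(0, 1) contributes 0.5 to E[x ** 2], so the estimate is doubled to line up with the 1 + a ** 2 scale used above. The function names are illustrative only.

```python
import random

def leaky_relu(x, a):
    """Leaky ReLU with negative slope a."""
    return x if x >= 0 else a * x

def total_variance(a, n=100_000, seed=0):
    """Estimate 2 * E[leaky_relu(x, a) ** 2] for x ~ N(0, 1).

    The positive half contributes about 1 and the negative half about a ** 2,
    so the result should be close to 1 + a ** 2.
    """
    rng = random.Random(seed)
    second_moment = sum(leaky_relu(rng.gauss(0.0, 1.0), a) ** 2 for _ in range(n)) / n
    return 2 * second_moment

for a in (0.0, 0.5, 1.0):
    print(f"a={a}: estimated {total_variance(a):.3f}, 1 + a ** 2 = {1 + a ** 2:.3f}")
```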
I suspect using abs(a) will not give you the right variance proportion: scaling values by a scales their variance by a ** 2, whereas abs(a) grows only linearly with a, so the two agree only at a = 0 and a = 1.
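Here is a sketch of why the square matters, under stated assumptions: a Kaiming/He-style init that draws weights with standard deviation gain / sqrt(fan_in), where gain = sqrt(2 / (1 + s)) and s is either a ** 2 or the proposed abs(a), plus zero-mean, symmetric pre-activations. Under those assumptions each layer multiplies the pre-activation variance by (1 + a ** 2) / 2 from the leaky ReLU and by 2 / (1 + s) from the weights. The first part checks that scaling by a scales the variance by a ** 2; the second compounds the per-layer factor over a hypothetical 20-layer stack. `layer_factor` is a hypothetical helper, not a library function.

```python
import random
import statistics

rng = random.Random(1)
xs = [rng.gauss(0.0, 1.0) for _ in range(2_000)]

# Scaling every value by a scales the variance by a ** 2, not by abs(a).
a = 0.5
scaled_var = statistics.pvariance([a * x for x in xs])
ratio = scaled_var / statistics.pvariance(xs)
print(f"variance ratio {ratio:.4f} vs a ** 2 = {a ** 2} vs abs(a) = {abs(a)}")

def layer_factor(slope, use_square=True):
    """Per-layer change in pre-activation variance under a Kaiming-style init.

    The leaky ReLU multiplies E[x ** 2] by (1 + slope ** 2) / 2; a weight gain
    of sqrt(2 / (1 + s)) multiplies the variance by 2 / (1 + s), where s is
    slope ** 2 (use_square=True) or abs(slope) (use_square=False).
    """
    s = slope ** 2 if use_square else abs(slope)
    return (1 + slope ** 2) / 2 * (2 / (1 + s))

depth = 20
for slope in (0.0, 0.5, 1.0):
    with_square = layer_factor(slope, True) ** depth
    with_abs = layer_factor(slope, False) ** depth
    print(f"a={slope}: after {depth} layers, a ** 2 -> {with_square:.3f}, abs(a) -> {with_abs:.3f}")
```

With s = a ** 2 the factor is exactly 1 for every slope, so the scale is preserved at any depth. With abs(a) it is below 1 for any 0 < a < 1 (about 0.83 per layer at a = 0.5), and that shrinkage compounds with depth.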