Why does Kaiming initialization use a ** 2 instead of abs(a) in its formula?

The formula for the standard deviation in Kaiming normal initialization is

(2 / ((1 + a ** 2) * fan_in)) ** .5
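A minimal sketch of that formula in plain Python (the function names are my own, hypothetical). As far as I know this is how PyTorch's torch.nn.init.kaiming_normal_ computes the std with mode='fan_in': calculate_gain('leaky_relu', a) returns sqrt(2 / (1 + a ** 2)), and the std is that gain divided by sqrt(fan_in).

```python
import math

def leaky_relu_gain(a: float) -> float:
    # Gain for a leaky ReLU with negative slope a: sqrt(2 / (1 + a**2)).
    return math.sqrt(2.0 / (1.0 + a ** 2))

def kaiming_normal_std(fan_in: int, a: float = 0.0) -> float:
    # Std of the normal distribution the weights are drawn from:
    # (2 / ((1 + a ** 2) * fan_in)) ** .5, written as gain / sqrt(fan_in).
    return leaky_relu_gain(a) / math.sqrt(fan_in)

for a in (0.0, 0.5, 1.0):
    print(a, kaiming_normal_std(512, a))
```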

I understand why ‘a’ is needed for the leaky ReLU case, but why not just use abs(a) instead of a ** 2?

It seems to me that a ** 2 would be less accurate at scaling the weights so that the outputs have a standard deviation of about 1 after passing through them, and that abs(a) would be more accurate.

Am I missing something?

I think of (1 + a ** 2) as the variance contributed by the positive half (1) plus the variance contributed by the negative half (a ** 2) of the leaky ReLU, assuming a zero-mean, symmetric input.
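A quick Monte Carlo sketch of that split, assuming x ~ N(0, 1) as the zero-mean, symmetric input (the helper name and sample size are arbitrary choices of mine). It adds up x ** 2 over the positive half and (a * x) ** 2 over the negative half; the ratio of the two should come out close to a ** 2, for example about 0.25 when a = 0.5.

```python
import random

def half_contributions(a: float, n: int = 200_000, seed: int = 0):
    """Estimate E[f(x)**2] split by the sign of x, where f is a leaky ReLU
    with slope a and x ~ N(0, 1) (assumed zero-mean, symmetric input)."""
    rng = random.Random(seed)
    pos = neg = 0.0
    for _ in range(n):
        x = rng.gauss(0.0, 1.0)
        if x > 0:
            pos += x * x            # positive half passes through unchanged
        else:
            neg += (a * x) ** 2     # negative half is scaled by a, then squared
    return pos / n, neg / n

pos, neg = half_contributions(0.5)
print(pos, neg, neg / pos)  # compare neg / pos with 0.5 ** 2 = 0.25
```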

When no ReLU is used, that is equivalent to a = 1 (a leaky ReLU with slope 1 is the identity): the negative half is unchanged, so the "total variance" stays at 1 + 1**2 = 2.

When the standard ReLU is used (a = 0), the negative half contributes nothing, giving a "total variance" of 1 + 0**2 = 1.

When a = 0.5, all negative values are halved. Halving a value quarters its square, so the negative half now contributes 0.5 * 0.5 = 0.25, making the "total variance" 1 + 0.5**2 = 1.25.
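The three cases side by side, together with what abs(a) would give instead (a small sketch; both helper names are hypothetical). The two versions agree at a = 0 and a = 1, which is why the difference only shows up for an actual leaky slope such as 0.5.

```python
def total_variance(a: float) -> float:
    # "Total variance" as used in the Kaiming formula: 1 + a**2
    return 1 + a ** 2

def total_variance_abs(a: float) -> float:
    # The alternative proposed in the question: 1 + abs(a)
    return 1 + abs(a)

for name, a in [("no ReLU", 1.0), ("ReLU", 0.0), ("leaky ReLU", 0.5)]:
    print(f"{name} (a = {a}): 1 + a**2 = {total_variance(a)}, "
          f"1 + abs(a) = {total_variance_abs(a)}")
```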

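I suspect using abs(a) would not give the right variance proportion: the leaky ReLU scales the negative half linearly by a, but variance scales with the square of that factor (Var(a * x) = a ** 2 * Var(x)), so the correction has to use a ** 2.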
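To check this numerically, here is a rough simulation (my own sketch, not PyTorch's code; the names, width of 64, depth of 8, batch of 32, standard normal inputs and bias-free fully connected layers are all arbitrary assumptions). It draws weights with std sqrt(2 / (denom * fan_in)), where denom is either 1 + a ** 2 or 1 + abs(a), and compares the mean square of the pre-activations (roughly their variance, since they are zero-mean) at the last layer with the first.

```python
import math
import random

def leaky_relu(x: float, a: float) -> float:
    return x if x > 0 else a * x

def second_moment_ratio(a, denom, width=64, depth=8, batch=32, seed=0):
    """Push a batch through `depth` bias-free layers with leaky ReLU of slope a,
    weights ~ N(0, 2 / (denom * fan_in)). Returns the mean squared
    pre-activation at the last layer divided by that at the first layer."""
    rng = random.Random(seed)
    h = [[rng.gauss(0.0, 1.0) for _ in range(width)] for _ in range(batch)]
    std = math.sqrt(2.0 / (denom * width))  # fan_in == width for every layer here
    first = last = None
    for _ in range(depth):
        w = [[rng.gauss(0.0, std) for _ in range(width)] for _ in range(width)]
        # pre-activations y = W h for every example in the batch
        ys = [[sum(wij * hj for wij, hj in zip(row, hv)) for row in w] for hv in h]
        m = sum(y * y for yv in ys for y in yv) / (batch * width)
        if first is None:
            first = m
        last = m
        h = [[leaky_relu(y, a) for y in yv] for yv in ys]
    return last / first

a = 0.5
ratio_sq = second_moment_ratio(a, 1 + a ** 2)
ratio_abs = second_moment_ratio(a, 1 + abs(a))
print("1 + a**2  :", ratio_sq)
print("1 + abs(a):", ratio_abs)
```

With a ** 2 the ratio should stay near 1. With abs(a), each layer-to-layer step multiplies the mean square by about (1 + a ** 2) / (1 + abs(a)) = 1.25 / 1.5 ≈ 0.83, so after the 7 steps here it should land near 0.83 ** 7 ≈ 0.28, and it keeps shrinking as the network gets deeper.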