In this post we’ll delve into the math behind the code for std_agg(), the function that @jeremy created to compute the standard deviation of a one-dimensional data vector, discussed in the Lesson #7 video starting at the 45:32 point.
In the std_agg() function, @jeremy employs the following expression for the mean squared deviation:
sum((x - <x>)**2)/N = <x**2> - <x>**2
If you’re curious about this basic result from statistics, read on!
The standard deviation is a measure of the spread or dispersion in a set of real values. Mathematically, it is computed as the root mean squared deviation. Don’t worry if you are unfamiliar with these terms; we’ll discuss them below.
The definition of the standard deviation of a vector of values x is
std(x) = sqrt( sum( (x - <x>)**2 )/N )
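As a point of reference, here is a minimal sketch of this definition computed literally in plain Python. The helper name `std_direct` is hypothetical (it is not part of @jeremy’s code). Note that it needs two passes over the data: one to find the mean, and a second to sum the squared deviations from it. Python’s standard `statistics.pstdev()` uses the same divide-by-N convention, so the two should agree.

```python
import math
import statistics

def std_direct(x):
    """Standard deviation computed literally from the definition (divides by N)."""
    n = len(x)
    mean = sum(x) / n                              # <x>: first pass over the data
    msd = sum((v - mean) ** 2 for v in x) / n      # mean squared deviation: second pass
    return math.sqrt(msd)                          # root mean squared deviation

data = [2, 4, 4, 4, 5, 5, 7, 9]
print(std_direct(data))            # 2.0
print(statistics.pstdev(data))     # 2.0 (population std, also divides by N)
```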
In this formula,
- N is the number of samples (x values), and angle brackets <> around a vector x denote the expectation value, or mean, of x: <x> = sum(x)/N.
- x - <x> is called the deviation of a sample x from the mean.
- for simplicity, we’ve used N instead of the usual N - 1 in the denominator; the difference is negligible as long as N is large.
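To see why the choice between N and N - 1 matters little for large N, here is a small sketch using Python’s standard `statistics` module, where `pstdev()` divides by N and `stdev()` divides by N - 1. The data are arbitrary and purely illustrative. The ratio of the two results is sqrt(N/(N - 1)), which approaches 1 as N grows.

```python
import statistics

# pstdev divides by N (the convention used in this post);
# stdev divides by N - 1 (the usual sample standard deviation).
for n in (5, 50, 5000):
    x = [float(i % 7) for i in range(n)]   # arbitrary illustrative data
    pop = statistics.pstdev(x)
    samp = statistics.stdev(x)
    # samp / pop equals sqrt(N / (N - 1)), which approaches 1 as N grows
    print(n, pop, samp, samp / pop)
```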
Let’s expand the squared deviation inside the sum in the preceding equation:
(x - <x>)**2 = x**2 - 2*x*<x> + <x>**2.
Summing the squared deviations over the samples and dividing by N gives the mean squared deviation:
sum((x - <x>)**2)/N = sum(x**2)/N - 2*sum(x*<x>)/N + sum(<x>**2)/N
Let’s examine the right-hand side of the preceding equation term by term:
- The first term is the mean of x**2, which is by definition the expectation value <x**2>.
- In the second term, sum(x*<x>)/N is the same as <x>*sum(x)/N, since <x> is a constant factor and can be taken outside the sum.
- But look! sum(x)/N is the mean of x, which is by definition equal to <x>, the expectation value of x. Therefore sum(x*<x>)/N = <x>*sum(x)/N = <x>*<x> = <x>**2, so the second term becomes -2*<x>**2.
- The third term is just <x>**2: since <x>**2 is a constant, summing it over the N samples gives N*<x>**2, and dividing by N leaves <x>**2.
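The term-by-term argument above can be checked numerically on a tiny example. This is only an illustration with arbitrary values chosen so that every intermediate result is exact in floating point; `term1`, `term2` and `term3` are our own names for the three terms on the right-hand side.

```python
# Check each term of the expanded mean squared deviation on a small example.
x = [1.0, 2.0, 4.0, 8.0]      # arbitrary illustrative data
N = len(x)
mean = sum(x) / N             # <x> = 3.75

term1 = sum(v ** 2 for v in x) / N           # first term: <x**2>
term2 = -2 * sum(v * mean for v in x) / N    # second term
term3 = sum(mean ** 2 for _ in x) / N        # third term: N copies of <x>**2, divided by N

print(term1)                     # 21.25
print(term2, -2 * mean ** 2)     # -28.125 -28.125
print(term3, mean ** 2)          # 14.0625 14.0625
print(term1 + term2 + term3)     # 7.1875
```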
Adding together the three terms, we find that the mean squared deviation is
sum((x - <x>)**2)/N = <x**2> - 2*<x>**2 + <x>**2 = <x**2> - <x>**2
Note that this is the mean of the squares minus the square of the mean.
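A quick numerical sanity check of this identity: a minimal sketch on randomly generated data (the distribution parameters are arbitrary), comparing the mean squared deviation with the mean of the squares minus the square of the mean. The two agree up to floating-point rounding.

```python
import random

random.seed(0)
x = [random.gauss(10.0, 3.0) for _ in range(1000)]   # illustrative random data
N = len(x)
mean = sum(x) / N                                      # <x>

msd = sum((v - mean) ** 2 for v in x) / N              # mean squared deviation
shortcut = sum(v * v for v in x) / N - mean ** 2       # <x**2> - <x>**2

print(msd, shortcut)   # agree up to floating-point rounding
```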
By its definition, the standard deviation is the root mean squared deviation, so now we just need to take the square root:
std(x) = sqrt(sum((x - <x>)**2)/N) = sqrt(<x**2> - <x>**2)
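Because the right-hand side only needs <x> and <x**2>, both sums can be accumulated together in a single loop, unlike the definition, which needs the mean before it can sum the squared deviations. Here is a sketch of that one-pass version; the name `std_one_pass` is hypothetical and used only for illustration.

```python
import math
import statistics

def std_one_pass(x):
    """sqrt(<x**2> - <x>**2), accumulating both sums in a single loop.
    Hypothetical helper for illustration."""
    n = s1 = s2 = 0
    for v in x:
        n += 1
        s1 += v        # running sum of x
        s2 += v * v    # running sum of x**2
    mean = s1 / n               # <x>
    mean_sq = s2 / n            # <x**2>
    return math.sqrt(mean_sq - mean ** 2)

data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
print(std_one_pass(data), statistics.pstdev(data))   # 2.0 2.0
```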
In @jeremy’s notation:
- cnt is the number of samples N,
- s1 is the sum of the sample values x, so <x> = s1/N = s1/cnt, and
- s2 is the sum of the squares of the sample values, so <x**2> = s2/N = s2/cnt.
The formula for standard deviation becomes
std(x) = sqrt( s2/cnt - (s1/cnt)**2 )
And thus @jeremy’s function definition is
```python
import math
def std_agg(cnt, s1, s2): return math.sqrt((s2/cnt) - (s1/cnt)**2)
```
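This formula speeds up the computation of the standard deviation: its inputs are just the count, the sum of the data and the sum of the squares of the data. All three can be accumulated in a single O(N) pass, and each can be updated in O(1) when a sample is added to or removed from the data, so no second pass over the data is needed to subtract the mean.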
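One place the O(1) updates pay off is when the standard deviation must be evaluated for every way of splitting a vector into a left group y[:k] and a right group y[k:], as when searching for a split point in a regression tree. The sketch below is our own illustration, not @jeremy’s code: `split_stds` is a hypothetical helper, and the `max(0.0, ...)` guard is our addition, because floating-point rounding can make `s2/cnt - (s1/cnt)**2` a tiny negative number for (nearly) constant data, and `math.sqrt` raises `ValueError` on negative input.

```python
import math
import statistics

def std_agg(cnt, s1, s2):
    # @jeremy's formula; the max(0.0, ...) guard is our addition, since rounding
    # can push s2/cnt - (s1/cnt)**2 slightly below zero for near-constant data.
    return math.sqrt(max(0.0, (s2 / cnt) - (s1 / cnt) ** 2))

def split_stds(y):
    """Hypothetical helper: return [(std of y[:k], std of y[k:]) for k = 1 .. N-1],
    updating the running sums in O(1) as each value moves from right to left."""
    n = len(y)
    l_cnt, l_s1, l_s2 = 0, 0.0, 0.0
    r_cnt, r_s1, r_s2 = n, sum(y), sum(v * v for v in y)   # one O(N) pass
    result = []
    for v in y[:-1]:
        # move v from the right-hand group to the left-hand group
        l_cnt, l_s1, l_s2 = l_cnt + 1, l_s1 + v, l_s2 + v * v
        r_cnt, r_s1, r_s2 = r_cnt - 1, r_s1 - v, r_s2 - v * v
        result.append((std_agg(l_cnt, l_s1, l_s2), std_agg(r_cnt, r_s1, r_s2)))
    return result

y = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]
for k, (left, right) in enumerate(split_stds(y), start=1):
    # compare against the direct computation on each side
    print(k, left, statistics.pstdev(y[:k]), right, statistics.pstdev(y[k:]))
```

Each split costs O(1) after the initial O(N) pass, so all N - 1 splits take O(N) in total, whereas recomputing each side from scratch with the two-pass definition would take O(N**2).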