Understanding mean((-1 , -2)) in mnist_distance

It might sound like a dumb question, but I am starting Fastai all over again after 2 months. While going through it, I came across the mnist_distance function in mnist_basics.ipynb:

def mnist_distance(a, b):
  return (a-b).abs().mean((-1 , -2))

When I tried it myself, I used mean() without the tuple inside it, and it still gave me the same answer as the mnist_distance function. What difference does (-1, -2) make to the function?

Sorry if this sounds dumb, lol. Thank you for your answer!

The axis parameter (called dim in PyTorch) can be an int or a tuple of ints. Those ints are the axes along which the mean is computed. The default is to compute the mean of the flattened array (1).

Negative ints mean ‘count from the last axis’. Axis -1 is the last axis (which in this case is axis 1) and axis -2 is the one before it (which in this case is axis 0).
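To make the negative indexing concrete, here is a minimal sketch with NumPy on a small 2-D array. The array x and the two result names are made up for illustration and are not from the notebook.

```python
import numpy as np

# A small 2-D array: axis 0 has length 2, axis 1 has length 3.
x = np.arange(6, dtype=float).reshape(2, 3)

# For a 2-D array, axis -1 refers to axis 1 and axis -2 refers to axis 0.
last_axis_same = np.array_equal(x.mean(axis=-1), x.mean(axis=1))
first_axis_same = np.array_equal(x.mean(axis=-2), x.mean(axis=0))

print(x.mean(axis=-1))                  # one mean per row
print(last_axis_same, first_axis_same)  # both comparisons should hold
```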

In this case, since you are passing a 2-dimensional array, you see the same behaviour either way. If you pass in higher-dimensional arrays, you will see the difference:

import numpy as np
foo = np.random.randn(4, 3, 2)
foo.mean(axis=(-1, -2)) # array of 4 numbers
foo.mean() # a single number

(1) numpy.mean


Well, you are right.
In the example you are referencing, mean() and mean((-1, -2)) produce the same result. This is because we are computing the mean absolute distance between two images, a sample 3 and the ideal 3, both of which are 28x28 pixels. mean((-1, -2)) takes the mean over the last and second-to-last axes, but we only have two axes, so it averages over everything, just like mean().
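A quick way to check this for a single image, sketched with NumPy and random data standing in for the real images (sample_3 and ideal_3 are hypothetical placeholders, not the notebook's tensors):

```python
import numpy as np

rng = np.random.default_rng(0)
sample_3 = rng.random((28, 28))  # hypothetical stand-in for one image of a 3
ideal_3 = rng.random((28, 28))   # hypothetical stand-in for mean3

diff = np.abs(sample_3 - ideal_3)

full_mean = diff.mean()               # mean over every element
axis_mean = diff.mean(axis=(-1, -2))  # mean over the last two axes

# With only two axes, both calls reduce over everything and return a scalar.
print(np.shape(full_mean), np.shape(axis_mean))
print(np.isclose(full_mean, axis_mean))
```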

But the reason why mean((-1, -2)) matters becomes evident in the next section, when we use mnist_distance with broadcasting to calculate the distance to the ideal 3 (mean3) for every image in the validation set using:
mnist_distance(valid_3_tens, mean3)

Running (valid_3_tens-mean3).abs().shape returns: torch.Size([1010, 28, 28])

Here, if we ran (valid_3_tens-mean3).abs().mean() it would return just one value, which is not what we want. It would effectively calculate the mean distance of every validation 3 image from the ideal 3, then take the mean of all those distances, returning one value. (Actually, try running it this way and see for yourself.)

That's why we need to run (valid_3_tens-mean3).abs().mean((-1, -2)), which returns 1010 values (one for each entry along the remaining axis, the one we are not averaging over). That is exactly what we want: a vector of size 1010 holding the distance of each image in the validation set from the ideal 3.
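Here is a rough sketch of the batched case using NumPy and random data in place of the real tensors. The names valid_3 and mean3 below are stand-ins with the same shapes as in the thread. PyTorch tensors accept the same tuple, so the notebook's version behaves the same way.

```python
import numpy as np

def mnist_distance(a, b):
    # NumPy version of the notebook function: average over the last two axes.
    return np.abs(a - b).mean(axis=(-1, -2))

rng = np.random.default_rng(0)
valid_3 = rng.random((1010, 28, 28))  # stand-in for valid_3_tens
mean3 = rng.random((28, 28))          # stand-in for the ideal 3

# Broadcasting subtracts mean3 from each of the 1010 images.
diff = np.abs(valid_3 - mean3)
overall = diff.mean()                 # collapses everything to one number
dists = mnist_distance(valid_3, mean3)

print(diff.shape)         # (1010, 28, 28)
print(np.shape(overall))  # ()
print(dists.shape)        # (1010,)
```

Since every image has the same number of pixels, dists.mean() matches overall up to floating-point error, which is the "mean of all those distances" described above.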

Hope my explanation is clear. If you need any further clarification, I’m here.


Thanks, @jimmiemunyi @manju-dev. It makes sense now: it’s needed in the broadcasting case, when we want a separate distance for every image in the complete validation set.

@Ashik_Shafi ,

It is equivalent to computing the distance one image at a time:

diff = tensor([(valid3 - mean3).abs().mean() for valid3 in valid_3_tens])
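To see that the loop and the axis-based mean agree, here is a small check in NumPy, again with random data standing in for the real images (a batch of 5 is assumed just to keep it quick):

```python
import numpy as np

rng = np.random.default_rng(0)
valid_3 = rng.random((5, 28, 28))  # small stand-in batch of images
mean3 = rng.random((28, 28))       # stand-in for the ideal 3

# One distance per image, computed in an explicit Python loop.
looped = np.array([np.abs(img - mean3).mean() for img in valid_3])

# The same distances computed in one call via broadcasting.
vectorized = np.abs(valid_3 - mean3).mean(axis=(-1, -2))

print(looped.shape, vectorized.shape)
print(np.allclose(looped, vectorized))
```

The broadcast version avoids the Python-level loop, which is why the notebook prefers it for the full validation set.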