Calculating mean and std of images in a dataset in Fast AI

Hi all,

Is there a better way to compute the mean and standard deviation of all images in a dataset using Fast AI's convenience functions?

I am trying the following approach:

images = torch.stack([tensor( for o in])
mean = images.mean(axis=[0,1,2])

But it is incredibly slow.
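One way to avoid stacking the whole dataset into a single tensor is to stream over the images and keep running per-channel sums of the pixel values and of their squares. The sketch below does that in plain NumPy. It is not a Fast AI function. It assumes each image is an HWC array-like with the same number of channels (a PIL image, a NumPy array, or a CPU tensor such as `tensor(Image.open(o))`; anything `np.asarray` accepts). Only one image is held in memory at a time. With equal-sized images, the mean matches `images.mean(axis=[0,1,2])` from the stacked approach. The std is the population std, not the unbiased estimate that `torch.std` returns by default.

```python
import numpy as np

def channel_mean_std(images):
    """Per-channel mean and population std over an iterable of HWC images.

    Each image is processed on its own, so the dataset is never stacked
    into one big array. Images may differ in height/width, but must all
    have the same number of channels.
    """
    total = None      # running per-channel sum of pixel values
    total_sq = None   # running per-channel sum of squared pixel values
    count = 0         # number of pixels seen (per channel)
    for img in images:
        # float64 accumulation limits rounding error in the sum of squares
        arr = np.asarray(img, dtype=np.float64)
        arr = arr.reshape(-1, arr.shape[-1])  # (H*W, C)
        if total is None:
            total = np.zeros(arr.shape[1])
            total_sq = np.zeros(arr.shape[1])
        total += arr.sum(axis=0)
        total_sq += (arr ** 2).sum(axis=0)
        count += arr.shape[0]
    if count == 0:
        raise ValueError("no images given")
    mean = total / count
    var = total_sq / count - mean ** 2
    # clamp tiny negative variances caused by floating-point rounding
    std = np.sqrt(np.maximum(var, 0.0))
    return mean, std
```

With a list of image file paths (here a hypothetical `files`), a generator such as `(Image.open(f).convert("RGB") for f in files)` can be passed in directly, so images are opened lazily one by one. Decoding each file still costs time, so for a large dataset computing the statistics on a random subset is a common shortcut.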

I would appreciate your guidance.

Thanks and kind regards,