Image normalization in PyTorch

Hi,

Yes. You need to calculate the mean and std in advance. I did it the following way:

```python
import numpy as np
import torch
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor()  # float tensor of shape (C, H, W), scaled to [0, 1]
])

# torch_dataset: your own Dataset (using `transform`) whose samples are dicts with an 'image' key
dataloader = torch.utils.data.DataLoader(torch_dataset, batch_size=4096, shuffle=False, num_workers=4)

pop_mean = []
pop_std0 = []
pop_std1 = []
for data in dataloader:
    # shape (batch_size, 3, height, width)
    numpy_image = data['image'].numpy()

    # shape (3,): one value per channel
    batch_mean = np.mean(numpy_image, axis=(0, 2, 3))
    batch_std0 = np.std(numpy_image, axis=(0, 2, 3))          # population std (ddof=0)
    batch_std1 = np.std(numpy_image, axis=(0, 2, 3), ddof=1)  # sample std (ddof=1)

    pop_mean.append(batch_mean)
    pop_std0.append(batch_std0)
    pop_std1.append(batch_std1)

# shape (num_iterations, 3) -> (mean across 0th axis) -> shape (3,)
pop_mean = np.array(pop_mean).mean(axis=0)
pop_std0 = np.array(pop_std0).mean(axis=0)
pop_std1 = np.array(pop_std1).mean(axis=0)
```
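As a quick sanity check of the shape comments in the loop, here is a small NumPy sketch on a hypothetical random batch (the array sizes are arbitrary assumptions): reducing over axes 0, 2 and 3 keeps only the channel axis, which is the same as flattening each channel and averaging it.

```python
import numpy as np

# Hypothetical batch: 8 RGB images of 32x32 pixels, shape (N, C, H, W)
rng = np.random.default_rng(0)
batch = rng.random((8, 3, 32, 32))

# Reducing over batch, height and width leaves one value per channel
per_channel_mean = np.mean(batch, axis=(0, 2, 3))
print(per_channel_mean.shape)  # (3,)

# Equivalent: put channels first, flatten everything else, average each row
channels_first = batch.transpose(1, 0, 2, 3).reshape(3, -1)
assert np.allclose(per_channel_mean, channels_first.mean(axis=1))
```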

Note that, in theory, the standard deviation of the whole dataset is different from the mean of the per-minibatch standard deviations, which is what I computed (so try to make the batch size as large as possible; I used 4096). The mean has a similar but smaller issue: the average of the batch means is exact only when all batches have the same size, so a smaller final batch skews it slightly. The problem is that with a huge dataset like mine (>12 million images), you can't load the whole dataset into memory at once to compute the standard deviation directly. If your dataset is of reasonable size and fits in memory, you can compute both the mean and the std over the whole thing. In practice, though, using the mean of the minibatch standard deviations shouldn't be a problem.
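If you do want exact dataset-wide values without holding everything in memory, one option (an alternative to the batch averaging, not the method used in this post) is to accumulate per-channel counts, sums and sums of squares batch by batch. A minimal NumPy sketch, assuming each batch is an array of shape (N, C, H, W); `streaming_channel_stats` is a hypothetical helper name, and the E[x²] − mean² formula can lose precision when the mean is large relative to the std (Welford's algorithm is more robust there):

```python
import numpy as np

def streaming_channel_stats(batches):
    """Exact per-channel mean and population std (ddof=0) over all batches.

    Accumulates count, sum and sum of squares in float64, so only one
    batch needs to be in memory at a time.
    """
    count = 0
    total = None
    total_sq = None
    for batch in batches:  # each batch: shape (N, C, H, W)
        b = batch.astype(np.float64)
        s = b.sum(axis=(0, 2, 3))
        sq = (b ** 2).sum(axis=(0, 2, 3))
        total = s if total is None else total + s
        total_sq = sq if total_sq is None else total_sq + sq
        count += b.shape[0] * b.shape[2] * b.shape[3]  # values per channel
    mean = total / count
    var = total_sq / count - mean ** 2
    return mean, np.sqrt(np.maximum(var, 0.0))  # clip tiny negative round-off

# Hypothetical data: 10 images split into uneven batches of 4, 4 and 2
rng = np.random.default_rng(0)
data = rng.random((10, 3, 16, 16))
batches = [data[0:4], data[4:8], data[8:10]]

mean, std = streaming_channel_stats(batches)

# Matches the values computed on the full array at once
assert np.allclose(mean, data.mean(axis=(0, 2, 3)))
assert np.allclose(std, data.std(axis=(0, 2, 3)))

# The approximation from the post: mean of per-batch stds
approx_std = np.mean([b.std(axis=(0, 2, 3)) for b in batches], axis=0)
print(std, approx_std)
```

The last lines compare the exact std with the averaged per-batch stds; with large, similarly distributed batches the two should be close.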

Also note that this runs on the CPU, not the GPU, so if you work in the cloud you can do it on a cheap instance and don't need a GPU instance.

Once you have the mean and std, just add the following line to the transforms.Compose list:

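```python
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    # your_calculated_mean / your_calculated_std: per-channel values (e.g. pop_mean and pop_std0 from above)
    transforms.Normalize(mean=your_calculated_mean, std=your_calculated_std)
])
```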
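For reference, `transforms.Normalize` computes `(input[channel] - mean[channel]) / std[channel]` for each channel. The NumPy sketch below mirrors that formula (`normalize` is a hypothetical helper, and the random images stand in for real data) and checks that, when mean and std come from the same data, every channel ends up with zero mean and unit std.

```python
import numpy as np

def normalize(image, mean, std):
    """Per-channel (x - mean) / std for an image of shape (C, H, W)."""
    mean = np.asarray(mean, dtype=image.dtype)[:, None, None]  # shape (C, 1, 1)
    std = np.asarray(std, dtype=image.dtype)[:, None, None]
    return (image - mean) / std

# Hypothetical dataset of 100 images with values in [0, 1], as after ToTensor()
rng = np.random.default_rng(0)
images = rng.random((100, 3, 8, 8))
mean = images.mean(axis=(0, 2, 3))
std = images.std(axis=(0, 2, 3))

normalized = np.stack([normalize(img, mean, std) for img in images])
# Each channel now has zero mean and unit std across the dataset (up to float error)
print(normalized.mean(axis=(0, 2, 3)), normalized.std(axis=(0, 2, 3)))
```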

Hope that helps.
