Normalizing your Dataset

Hi there!
If you are using torchvision.datasets.ImageFolder to prepare your dataset, the following snippet might help you obtain the per-channel mean and std of the dataset (I got this from the PyTorch forum):

import torch

def online_mean_and_sd(loader):
    """Compute the per-channel mean and std in an online fashion.

        Var[X] = E[X^2] - E^2[X]
    """
    cnt = 0
    # use zeros, not empty: the first update multiplies by cnt == 0,
    # and uninitialized memory could contain NaN/inf
    fst_moment = torch.zeros(3)
    snd_moment = torch.zeros(3)

    for images, _ in loader:
        b, c, h, w = images.shape
        nb_pixels = b * h * w
        # per-channel sums over batch, height and width
        sum_ = torch.sum(images, dim=[0, 2, 3])
        sum_of_square = torch.sum(images ** 2, dim=[0, 2, 3])
        # running first and second moments
        fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_pixels)
        snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_pixels)

        cnt += nb_pixels

    # std = sqrt(E[X^2] - E^2[X])
    return fst_moment, torch.sqrt(snd_moment - fst_moment ** 2)

Steps for normalizing your dataset (specific to the PyTorch DataLoader)

  1. Prepare the data without normalization and get the dataloader:
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader
TRAIN_PATH = "./food-101/train/"
bs = 64

train_tfms = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
            # don't use transforms.Normalize() for this first pass
        ])
train_ds = datasets.ImageFolder(root=TRAIN_PATH, transform=train_tfms)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
  2. Obtain the mean & standard deviation of the dataset:
food101_mean, food101_std = online_mean_and_sd(train_dl)
print(food101_mean, food101_std)

(this may take some time as it loops over the whole dataset in minibatches)
output:

(tensor([0.5567, 0.4381, 0.3198]), tensor([0.2591, 0.2623, 0.2633]))
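
If this pass is slow, one option (my addition, not part of the original snippet) is to use a separate DataLoader with several worker processes just for computing the statistics; shuffling is unnecessary because we only accumulate per-channel sums. A minimal sketch reusing train_ds and bs from step 1:

# loader used only for the statistics pass:
# shuffle=False is fine since we just accumulate sums,
# num_workers=4 decodes images in parallel (tune for your machine)
stat_dl = DataLoader(train_ds, batch_size=bs, shuffle=False, num_workers=4)
food101_mean, food101_std = online_mean_and_sd(stat_dl)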
  3. Normalize the dataset by passing these values (mean and std) to transforms.Normalize():
train_tfms = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(food101_mean, food101_std)  # normalizing here
        ])

valid_tfms = transforms.Compose([
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(food101_mean, food101_std)  # normalizing here
        ])

VALID_PATH = "./food-101/valid/"  # assuming your validation split lives here
train_ds = datasets.ImageFolder(root=TRAIN_PATH, transform=train_tfms)
valid_ds = datasets.ImageFolder(root=VALID_PATH, transform=valid_tfms)

train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=bs*2, shuffle=False)

  4. Verify your mean and std again:
print(online_mean_and_sd(train_dl))

output (mean & std for each channel R, G & B):
(tensor([-6.4650e-04, -3.8490e-04, 1.6878e-05]),
tensor([0.9996, 0.9997, 0.9990]))

Now our dataset has mean ≈ 0 and std ≈ 1 per channel, which reduces the chances of vanishing or exploding gradients.
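
For a quicker spot check than looping over the whole dataset again, you can also look at a single batch (a sketch; the exact numbers will vary from batch to batch):

images, _ = next(iter(train_dl))
# per-channel mean/std of one normalized batch; should be roughly 0 and 1
print(images.mean(dim=[0, 2, 3]), images.std(dim=[0, 2, 3]))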

N.B.: There are many ways to normalize your data; the approach above is specific to the PyTorch DataLoader.
However,
if you already have x_train and x_valid as tensors, normalization can be done easily with the following snippet (which Jeremy taught in the course):

def normalize(x, m, s): return (x-m)/s

def normalize_to(train, valid):
    # the validation set is normalized with the *training* set's statistics
    m,s = train.mean(),train.std()
    return normalize(train, m, s), normalize(valid, m, s)

x_train, x_valid = normalize_to(x_train,x_valid)
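
For example, with made-up tensors just to illustrate (the real x_train/x_valid come from the course notebooks):

import torch

# fake data with a non-zero mean and non-unit std, only to sanity-check the helpers
fake_train = torch.randn(1000, 784) * 3 + 5
fake_valid = torch.randn(200, 784) * 3 + 5

fake_train, fake_valid = normalize_to(fake_train, fake_valid)
print(fake_train.mean(), fake_train.std())  # close to 0 and 1
print(fake_valid.mean(), fake_valid.std())  # close, but not exactly 0/1, since train's stats were used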