Sharing information between x and y transforms using fastai2

I am implementing a post from Epoching’s Blog using fastai2. In this post I am focusing on getting a dataset (DataLoaders) with randomly rotated images.

I need to take an image, randomly rotate it (choosing from a fixed list of values), write the rotation angle to the label, and feed it to a learner. For the validation set, always use the same rotation for each image. For training, it would be best if the rotation angle varied from one iteration to the next. show_batch should work too!

So I need to share information between the transform on y (the angle of rotation) and the transform on x (which actually rotates the image by that angle). Also, the behavior on the validation set is different.

I went from top to bottom with respect to the API level.

I took three approaches, each with its own drawbacks.

Take 1

High-level DataBlock API. The label matches the rotation, but there is no variation from batch to batch. Of course, the validation set is fixed. A lot of the fastai2 magic is kept in place; we only supply the getter functions.

+ Two getter functions and two lines of code build the actual DataLoaders
- Does not do per-iteration data augmentation

One idea to expand this is to have some state object that mixes a seed into the file-name hash, plus a callback in the learning process that changes the seed as training progresses (see the sketch below). But I do not like it: too much coupling between the data preprocessing and the learning phase!
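Just to illustrate that idea (and the coupling I dislike), here is a minimal sketch. SeedState and ReseedCallback are hypothetical names, not fastai2 API, and the sketch is not used below; it reseeds per epoch for simplicity:

rot_angles = [0, 45, 90, 135, 180, 225, 270, 315]

class SeedState:
    seed = 0  # shared mutable state: the current epoch seed

def rot_label_seeded(img_fname):
    # The label now depends on the file name AND the current seed, so the
    # rotations change whenever the seed changes (validation set included!)
    return rot_angles[hash((img_fname.name, SeedState.seed)) % len(rot_angles)]

class ReseedCallback(Callback):
    # Bump the seed at the end of every epoch so the next pass over the
    # data produces different rotations
    def after_epoch(self): SeedState.seed += 1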

# Rely on the file name being the same when it enters both the X pipeline
# and the Y pipeline, so we can deduce the angle from it.

rot_angles = [0, 45, 90, 135, 180, 225, 270, 315]

def rot_label(img_fname):
    # Gets the actual angle of rotation for this image.
    # The angle is constant for the same file (i.e. the hash doesn't change
    # during the execution).
    angle = rot_angles[hash(img_fname.name) % len(rot_angles)]
    return angle

def rot_image(img_fname):
    # Opens the image, rotates it and returns a PILImageBW.
    # One image will get the same rotation angle regardless of which set
    # it is in or how many times it is read.
    angle = rot_label(img_fname)
    img = load_image(img_fname)
    img = img.rotate(angle)
    img = img.resize((28,28))
    img = PILImageBW(img)
    return img

mnist_rot_block = DataBlock(
    (ImageBlock(cls=PILImageBW), CategoryBlock),
    get_items=get_image_files,
    get_x = rot_image,
    get_y = rot_label,
    splitter = RandomSplitter(valid_pct=0.2))

# Image list reading, batching, ToTensor(), int2float, scaling, categorizing the rotation angles:
# everything is taken care of by the DataBlock, because we specified the block types.
mnist_dls_mid_level_v1 = mnist_rot_block.dataloaders(data_path / "training", bs=16)

print(mnist_dls_mid_level_v1.one_batch())
mnist_dls_mid_level_v1.show_batch()
print(f"Train size: {len(mnist_dls_mid_level_v1.train_ds)}, validation size: {len(mnist_dls_mid_level_v1.valid_ds)}")

Take 2
(My) first dip into the mid-level API. I exploited the fact that ItemTransform and TfmdLists are meant for situations where a single entity contains both the x and the y sample. All “transformations” are applied to this tuple, and one unitary dataset is built.

+ I get the data augmentation (a ~0.1-0.2 jump in the final digit classification task!)
± Some of the magic has to be handled by hand.
- I don’t get validation set “stability”

# Version 2. Based on the idea that if the pipeline generates one tuple, use TfmdLists.

# Maybe include this in the library?
class TitledImage(Tuple):
    def show(self, ctx=None, **kwargs): show_titled_image(self, ctx=ctx, **kwargs)
        
class RotationTfm(ItemTransform):
    angles = [0, 45, 90, 135, 180, 225, 270, 315]
    
    def __init__(self):
        super().__init__()
        self.c = len(RotationTfm.angles) # The c is the number of classes. Needed when computing various metrics. 

    def encodes(self, img):
        # encodes takes an image, opened at a previous step in the pipeline.
        # It generates a new rotation for the same image every time, even if it is in the validation set.
        angle_idx = np.random.randint(len(RotationTfm.angles))
        img = PILImageBW(img.rotate(self.angles[angle_idx]))
        return (img, torch.Tensor([angle_idx]).long().squeeze()) # label MUST be tensor and squeezed so the collation will get a vector out
    
    def decodes(self, obj):
        return TitledImage(obj[0], obj[1].item())

# Must handle the file list reading by hand
img_list = get_image_files(data_path/"training")
# One has to manually specify all the steps to get a collate-able tensor.
loading_pipeline = Pipeline([PILImageBW.create, RotationTfm(), ToTensor(), IntToFloatTensor(), Resize(28)])
tfmd_lists = TfmdLists(img_list, loading_pipeline, splits=RandomSplitter(0.2)(img_list))
mnist_dls_mid_level_v2 = tfmd_lists.dataloaders(bs=16)

print(mnist_dls_mid_level_v2.one_batch())
mnist_dls_mid_level_v2.show_batch()
print(f"Train size: {len(mnist_dls_mid_level_v2.train_ds)}, validation size: {len(mnist_dls_mid_level_v2.valid_ds)}")

Take 3

Went deeper, to the Siamese advanced tutorial. If one needs different behavior for the train/valid sets, well, create your own! Like in v2, we treat the (x, y) pair as a standalone tuple.

++ Gets the job done!
+ OOP, no callbacks
- No magic at all. All the ops must be specified (including .cuda() and the train/valid splitting)!
- Verbose; it was the hardest to develop, jumping from error to error (well, working on v2 also helped)
- Felt like hacking, sending irrelevant IDs as samples to the DataLoaders

# Version 3: distinct behavior for the train and validation sets, as shown in the Siamese tutorial.
# That is: http://dev.fast.ai/tutorial.siamese#Using-the-mid-level-API
# There, the items that "walk" into the DataLoaders are just indexes, and both the fastai objects
# and our custom transform get the same original file list and other info (like the splitting).

class TitledImage(Tuple):
    def show(self, ctx=None, **kwargs): show_titled_image(self, ctx=ctx, **kwargs)
        
class RotationTfm(ItemTransform):
    angles = [0, 45, 90, 135, 180, 225, 270, 315]
    
    def __init__(self, files, is_valid=False):
        super().__init__()
        self.files = files
        self.is_valid = is_valid
        self.c = len(RotationTfm.angles) # The "c" is the number of classes. Needed when computing various metrics. 
        # We fix the validation set
        if is_valid:
            self.labels = np.random.randint(0, high=self.c, size=len(files))

    def encodes(self, img_id):
        if self.is_valid:
            angle_idx = self.labels[img_id]
        else:
            angle_idx = np.random.randint(self.c)
        img_name = self.files[img_id]
        img = PIL.Image.open(img_name).convert('L') # Grayscale image
        img = img.rotate(self.angles[angle_idx])
        img = PILImageBW(img) # So we can collate the batch and be able to display the image
        img = ToTensor()(img)
        label = torch.Tensor([angle_idx]).long().squeeze()  # label MUST be tensor and squeezed so the collation will get a vector out
        return (img, label)
    
    def decodes(self, obj):
        return TitledImage(obj[0], obj[1].item())
    
# Must handle the file list reading and splitting by hand
img_list = get_image_files(data_path/"training")
train_idx, val_idx = RandomSplitter(valid_pct=0.2)(img_list)
train_file_list = img_list[train_idx]
val_file_list = img_list[val_idx]

# Construct two TfmdLists independently. Note the duplicate code. Ugh . . .
tfmd_lists_train = TfmdLists(range(len(train_file_list)), RotationTfm(train_file_list))
tfmd_lists_val = TfmdLists(range(len(val_file_list)), RotationTfm(val_file_list, is_valid=True))

mnist_dls_mid_level_v3 = DataLoaders.from_dsets(tfmd_lists_train,tfmd_lists_val, bs=16,
                                               after_batch=[IntToFloatTensor(), Resize(28)])
mnist_dls_mid_level_v3 = mnist_dls_mid_level_v3.cuda() # You won't get CUDA without this line! Why? Is this the standard behavior for from_dsets()?

print(mnist_dls_mid_level_v3.one_batch())
mnist_dls_mid_level_v3.show_batch()
print(f"Train size: {len(mnist_dls_mid_level_v3.train_ds)}, validation size: {len(mnist_dls_mid_level_v3.valid_ds)}")

Take 4
Hack ItemTransform.__call__() and read the split_idx there (see the sketch below)? Well, soon I would end up back at the original solution, where all the operations were handled inside a custom-made dataset.
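For the record, a rough, untested sketch of that hack (it assumes the pipeline forwards split_idx to Transform.__call__(), which fastcore does; SplitAwareRotTfm is a hypothetical name):

class SplitAwareRotTfm(ItemTransform):
    def __call__(self, x, split_idx=None, **kwargs):
        # Peek at split_idx (0 = train, 1 = valid) before the normal dispatch
        self.is_valid = (split_idx == 1)
        return super().__call__(x, split_idx=split_idx, **kwargs)

    def encodes(self, input_tuple):
        # ...branch on self.is_valid here, like in v3...
        return input_tuple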

Discussions

I like how neat v1 is! Everything we don’t care about is handled by fastai2 and we code only what we care about (rotating the image and labeling it with that angle). However, the image reading has to be done by us, we rely on a hack (hashing the image name), and we don’t get the job done. So I moved deeper.

Version 2 is maybe just a stepping stone, showing that the performance gain was worth the effort, but it is still not enough to get all the features. Moving to version 3, I realized that the API assumes one would rarely want different behavior between the validation and training sets, beyond the ON/OFF kind illustrated below (e.g. split_idx; thank you @farid for pointing that out!).
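A minimal example of that ON/OFF switch, with a hypothetical transform (assuming it runs after IntToFloatTensor(), i.e. on float tensors):

class TrainOnlyNoise(Transform):
    # split_idx = 0 -> encodes runs only on the training set;
    # split_idx = 1 -> only on the validation set; None -> everywhere
    split_idx = 0
    def encodes(self, x: TensorImage): return x + 0.05 * torch.randn_like(x)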

So, any thoughts on how I could improve on v2? Or augment v1 without hacks?

I think I will move along with v3. The hardest part there was to nail all the magic down. From setting the proper number of classes (the .c property) and making the data inside the batch look the same as the data from v1 (that is, squeezing the labels, calling ToTensor()) to making the .cuda() call explicit, it was a lot of trial and error. But it is a one-time-only effort and it gets the job done!

Hope it helps!
I’d love to hear some thoughts and suggestions!

Afterthought: Is it possible to apply the ItemTransform to a whole batch, that is, inside the after_batch pipeline? Put some irrelevant labels in the original dataset and rotate the images once they are batched. However, one would still need to override the __call__() function, I guess.


Yes, it is possible. The name is perhaps misleading, but ItemTransform means the transform won’t be applied over each part of the tuple separately; it takes the tuple as a whole. This works for the items of your dataset as well as for batches.
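For example, a minimal sketch of an ItemTransform running at the batch level (it rotates the whole batch by one shared multiple of 90 degrees and does not touch the labels, so it only shows the mechanics; BatchRot90 is a hypothetical name):

class BatchRot90(ItemTransform):
    def encodes(self, batch):
        x, y = batch  # the whole collated (x, y) batch tuple arrives as one object
        k = np.random.randint(4)           # one shared rotation for the entire batch
        return torch.rot90(x, k, dims=(2, 3)), y  # x is a (bs, ch, h, w) tensor

# e.g. slotted in along the other batch transforms:
# dls = tfmd_lists.dataloaders(bs=16, after_batch=[IntToFloatTensor(), BatchRot90()])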


version 4

OK, this is it, I guess. fastai2 is both a functional and an OOP library, so why not exploit split_idx a bit?
It is a property set at class level, so to “change” it we need inheritance. Also, the behavior when split_idx == 1 is different. Obvious choices!

Here, I also exploited the fact that Categorize() and image reading are applied BEFORE our custom transform. So show_batch() works almost out of the box with actual angles displayed in titles.

In short: I created two classes. The base class, for each image, generates a random angle and applies it to the PILImage. The other, inheriting from the first, changes this behavior by directly applying the original angle. The initial labels are just a set of random angles (not indexes!).

+ It does the job!
+ OOP
+ Magic is back!

Imho this is the best way! I can’t see any room for improvement! (Or at least, I am satisfied!)

P.S. I did about 5 minutes of googling and couldn’t find a rotation function that takes tensors on CUDA, only ones for PIL images or numpy arrays. So no angle processing at the batch level.

P.P.S. The performance is off the charts (~0.61 with v4, when v3 is ~0.32). It might be a random fluke though . . .

# Version 4

class TitledImage(Tuple):
    def show(self, ctx=None, **kwargs): show_titled_image(self, ctx=ctx, **kwargs)
        
def get_some_random_id(nr_options):
    # Just get a number in the half-open interval [0, nr_options)
    return np.random.randint(0, high=nr_options)
        
def random_label_y(options):
    def fn(item):
        return options[get_some_random_id(len(options))]
    return fn


class RotationTfm_train(ItemTransform):
    split_idx = 0 # Apply it only to training data
    
    def __init__(self,angles):
        self.angles=angles
        self.no_angles = len(angles) #minor speedup
    
    def encodes(self, input_tuple):
        # Splits the tuple into the image and the label index, generates a "new" label index
        # if needed, and then rotates the image.
        # At this point Categorize() has run, so we work with angle indexes, not with actual angles.
        img, angle_idx = input_tuple
        rot_angle, angle_idx = self._get_rotation(angle_idx)
        img = img.rotate(rot_angle)
        img = PILImageBW(img)
        return img, angle_idx
        
    def _get_rotation(self, angle_idx):
        # Contains the logic for generating a new label/angle.
        # Overridden in the derived class.
        new_angle_idx = get_some_random_id(self.no_angles)
        new_angle = self.angles[new_angle_idx]
        # print("_get_rotation from TRAIN")  # Just a check
        return new_angle, new_angle_idx
    
    def decodes(self, obj):
        # We delegate to show_titled_image. Note that obj[1].item() is an angle index and will be decoded
        # by Categorize(), so show_batch() will work nicely!
        return TitledImage(obj[0], obj[1].item())


class RotationTfm_test(RotationTfm_train):
    split_idx = 1 # Apply it only to validation data

    def _get_rotation(self, angle_idx):
        # Keeps the index the same, and gets the actual angle out
        angle = self.angles[angle_idx]
        # print("_get_rotation from TEST")   # Just a check
        return angle, angle_idx
    
angles = [0, 45, 90, 135, 180, 225, 270, 315]

rot_tfm_train = RotationTfm_train(angles)
rot_tfm_test = RotationTfm_test(angles)

mnist_rot_block_adv = DataBlock(
    (ImageBlock(cls=PILImageBW), CategoryBlock),
    get_items=get_image_files,
    get_y = random_label_y(angles),
    item_tfms = [rot_tfm_train, rot_tfm_test],
    splitter = RandomSplitter(valid_pct=0.2))

# mnist_rot_block_adv.summary(data_path / "training")

# Image list reading, batching, ToTensor(), int2float, scaling, categorizing the rotation angles:
# everything is taken care of by the DataBlock, because we specified the block types.
mnist_dls_mid_level_v4 = mnist_rot_block_adv.dataloaders(data_path / "training", bs=16)

print(mnist_dls_mid_level_v4.one_batch())
print(mnist_dls_mid_level_v4.valid.one_batch())
mnist_dls_mid_level_v4.show_batch()
print(f"Train size: {len(mnist_dls_mid_level_v4.train_ds)}, validation size: {len(mnist_dls_mid_level_v4.valid_ds)}")

Note that if you subclass RandTransform, you get access to the split_idx in before_call. See RandomCrop for an example of its use.


Thanks! Here, fixed:

# Version 4a

class TitledImage(Tuple):
    def show(self, ctx=None, **kwargs): show_titled_image(self, ctx=ctx, **kwargs)
        
def get_some_random_id(nr_options):
    # Just get a number in the half-open interval [0, nr_options)
    return np.random.randint(0, high=nr_options)
        
def random_label_y(options):
    def fn(item):
        return options[get_some_random_id(len(options))]
    return fn

class RotationTfm(RandTransform, ItemTransform):
    split_idx = None # Apply it to all data
    
    def __init__(self, angles):
        super().__init__()
        self.angles = angles
        self.no_angles = len(angles)  # minor speedup
        self._get_rotation = self._get_rotation_random
    
    def before_call(self, b, split_idx):
        if split_idx == 0:
            # We are on the training set
            self._get_rotation = self._get_rotation_random
        else:
            # We run on validation data
            self._get_rotation = self._get_rotation_fixed
                
    def encodes(self, input_tuple):
        # Splits the tuple into the image and the label index, generates a "new" label index
        # if needed, and then rotates the image.
        # At this point Categorize() has run, so we work with angle indexes, not with actual angles.
        img, angle_idx = input_tuple
        rot_angle, angle_idx = self._get_rotation(angle_idx)
        img = img.rotate(rot_angle)
        img = PILImageBW(img)
        return img, angle_idx
        
    def _get_rotation_random(self, angle_idx):
        # Contains the logic for generating a new label/angle (training behavior).
        new_angle_idx = get_some_random_id(self.no_angles)
        new_angle = self.angles[new_angle_idx]
#         print("_get_rotation from TRAIN")  # Just a check
        return new_angle, new_angle_idx

    def _get_rotation_fixed(self, angle_idx):
        # Keeps the index the same, and gets the actual angle out
        angle = self.angles[angle_idx]
#         print("_get_rotation from TEST")   # Just a check
        return angle, angle_idx
    
    def decodes(self, obj):
        # We delegate to show_titled_image. Note that obj[1].item() is an angle index and will be decoded
        # by Categorize(), so show_batch() will work nicely!
        return TitledImage(obj[0], obj[1].item())

angles = [0, 45, 90, 135, 180, 225, 270, 315]

rot_tfm = RotationTfm(angles)

mnist_rot_block_adv = DataBlock(
    (ImageBlock(cls=PILImageBW), CategoryBlock),
    get_items=get_image_files,
    get_y = random_label_y(angles),
    item_tfms = [rot_tfm],
    splitter = RandomSplitter(valid_pct=0.2))

# mnist_rot_block_adv.summary(data_path / "training")

# Image list reading, batching, ToTensor(), int2float, scaling, categorizing the rotation angles:
# everything is taken care of by the DataBlock, because we specified the block types.
mnist_dls_mid_level_v4a = mnist_rot_block_adv.dataloaders(data_path / "training", bs=16)

print(mnist_dls_mid_level_v4a.one_batch())
print(mnist_dls_mid_level_v4a.valid.one_batch())
mnist_dls_mid_level_v4a.show_batch()
print(f"Train size: {len(mnist_dls_mid_level_v4a.train_ds)}, validation size: {len(mnist_dls_mid_level_v4a.valid_ds)}")

Looks pretty neat! I got a ~15% performance (running speed) penalty, but I’m too lazy to profile. Maybe because of before_call?

Not sure how future-proof the double inheritance is … Also too lazy to google, but in Python, method resolution goes left to right: the methods of the first class written in the inheritance list override the methods of the later classes (see the example below).
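A quick check of that resolution order, in plain Python:

class A:
    def who(self): return "A"

class B:
    def who(self): return "B"

class C(A, B): pass

print(C().who())  # prints "A": the leftmost base wins
print(C.__mro__)  # (C, A, B, object)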
