How to do augmentations in the Siamese Example

I have been reading the new fastai book on github and came across an example of using a siamese network. However I can’t seem to figure out how to add augmentations into the process.

from fastai.vision.all import *
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

class SiameseImage(Tuple):
    def show(self, ctx=None, **kwargs): 
        img1,img2,same_breed = self
        if not isinstance(img1, Tensor):
            if img2.size != img1.size: img2 = img2.resize(img1.size)
            t1,t2 = tensor(img1),tensor(img2)
            t1,t2 = t1.permute(2,0,1),t2.permute(2,0,1)
        else: t1,t2 = img1,img2
        line = t1.new_zeros(t1.shape[0], t1.shape[1], 10)
        return show_image(torch.cat([t1,line,t2], dim=2), 
                          title=same_breed, ctx=ctx)
def label_func(fname):
    return re.match(r'^(.*)_\d+.jpg$', fname.name).groups()[0]

class SiameseTransform(Transform):
    def __init__(self, files, label_func, splits):
        self.labels = files.map(label_func).unique()
        self.lbl2files = {l: L(f for f in files if label_func(f) == l) for l in self.labels}
        self.label_func = label_func
        self.valid = {f: self._draw(f) for f in files[splits[1]]}
    def encodes(self, f):
        f2,t = self.valid.get(f, self._draw(f))
        img1,img2 = PILImage.create(f),PILImage.create(f2)
        return SiameseImage(img1, img2, t)
    def _draw(self, f):
        same = random.random() < 0.5
        cls = self.label_func(f)
        if not same: cls = random.choice(L(l for l in self.labels if l != cls)) 
        return random.choice(self.lbl2files[cls]),same
splits = RandomSplitter()(files)
tfm = SiameseTransform(files, label_func, splits)
tls = TfmdLists(files, tfm, splits=splits)
dls = tls.dataloaders(after_item=[Resize(224), ToTensor], 
    after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)])

Elsewhere in the book it shows adding augmentations and it looks like this:

bears = bears.new(batch_tfms=aug_transforms(mult=2))
dls = bears.dataloaders(path)

How would I add aug_transforms into the siamese example?

I discussed it briefly with Sylvain:


cool. Thanks for the update.

No problem :slight_smile: if you run into issues feel free to ping me, I plan on working on what I described here in the next few days

Can I ask what you’re using Siamese networks for? I presume semi-supervised learning? I’m doing some work on image similarity with Siamese networks and wondered if either of you has done anything with triplet mining?
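For anyone unfamiliar, triplet mining trains on (anchor, positive, negative) triples with a margin loss. A minimal pure-Python sketch of the standard loss (the function names, vectors, and margin value are illustrative, not from any particular library):

```python
import math

def dist(u, v):
    # Euclidean distance between two embedding vectors
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def triplet_loss(anchor, positive, negative, margin=1.0):
    # Margin-based triplet loss: the positive should be closer to the
    # anchor than the negative by at least `margin`, else we pay a penalty.
    return max(dist(anchor, positive) - dist(anchor, negative) + margin, 0.0)

# The negative is already more than `margin` farther away, so the loss is zero
print(triplet_loss([0.0, 0.0], [0.1, 0.0], [2.0, 0.0]))  # 0.0
```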

I was just playing around with various architectures to learn more about them.

I am currently doing this… It used to be a bit easier, but TupleTransform was removed recently. I can post what I am doing after lesson 1 is over today.

You should be able to do

dls = tls.dataloaders(after_item=[Resize(224), ToTensor], 
    after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats), *aug_transforms()])

aug_transforms returns a list, so you use * to split it out into its components. I can look at it some more if that does not work.
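For reference, that `*` is plain Python list unpacking, nothing fastai-specific. A quick standalone illustration (the strings and the fake function below are stand-ins for the actual transform objects returned by aug_transforms):

```python
def fake_aug_transforms():
    # stand-in for fastai's aug_transforms(), which returns a list of transforms
    return ["Flip", "Brightness", "Contrast"]

base = ["IntToFloatTensor", "Normalize"]
# *-unpacking splats the returned list into the surrounding list literal,
# so each transform becomes a separate element rather than a nested list
after_batch = [*base, *fake_aug_transforms()]
print(after_batch)
# ['IntToFloatTensor', 'Normalize', 'Flip', 'Brightness', 'Contrast']
```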