Tutorial's SiameseTransform training always picks the second image from the valid set?

Greetings,

While going through FastAI’s Siamese tutorial, I was trying to understand this transform:

# label_func, labels, lbl2files and SiameseImage are defined earlier in the tutorial
class SiameseTransform(Transform):
    def __init__(self, files, splits):
        self.valid = {f: self._draw(f) for f in files[splits[1]]}

    def encodes(self, f):
        f2,t = self.valid.get(f, self._draw(f))
        img1,img2 = PILImage.create(f),PILImage.create(f2)
        return SiameseImage(img1, img2, int(t))

    def _draw(self, f):
        same = random.random() < 0.5
        cls = label_func(f)
        if not same: cls = random.choice(L(l for l in labels if l != cls))
        return random.choice(lbl2files[cls]),same

This transform is then passed to TfmdLists, which builds the DataLoaders used to train the network:

tfm = SiameseTransform(files, splits)
tls = TfmdLists(files, tfm, splits=splits)
dls = tls.dataloaders(after_item=[Resize(224), ToTensor], 
                      after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)])

In the transform’s encodes() function, it only looks up f2 in the validation set (splits[1]).

This means that even during training, the second image would come from the validation set.

Am I misunderstanding something, or is this by design? I was thinking that the training set should get its own set of image pairs, and the validation set should get its own.

Please advise.

Cheers
Gary

Hi Gary,
I think you are misunderstanding the first line of encodes:

f2,t = self.valid.get(f, self._draw(f))

Calling self.valid.get(f, default) checks whether the key f is in the dictionary. If it is, it returns the value stored for f; if not, it returns the second argument, which here is self._draw(f). Here’s another way you could write that same line that may help clarify things:

if f in self.valid: # f is in the validation set
    f2, t = self.valid[f]
else: # f is in the training set
    f2, t = self._draw(f)

So this line is checking if f is in the validation set. If it is, it returns the image and label that f maps to. If f is not in the validation set, it must instead be in the training set. In that case, it calls self._draw(f) and returns a new random image and label.
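One subtlety of the `.get` form: Python evaluates the default argument before `.get` runs, so `self._draw(f)` is still called for validation files and its result is simply thrown away. The if/else version above only draws when the key is missing. A minimal pure-Python sketch, using a hypothetical `draw` stand-in and made-up file names:

```python
calls = []

def draw(f):
    # Hypothetical stand-in for SiameseTransform._draw; it just records each call.
    calls.append(f)
    return (f + "_partner", True)

# Hypothetical cached validation pair: file -> (partner_file, same_label)
valid = {"a.jpg": ("b.jpg", False)}

# .get form: the default draw("a.jpg") is evaluated before .get runs,
# even though "a.jpg" is already a key; its result is discarded.
pair_get = valid.get("a.jpg", draw("a.jpg"))
calls_get = list(calls)

# if/else form: draw() only runs for keys missing from the dict.
calls.clear()
pair_if = valid["a.jpg"] if "a.jpg" in valid else draw("a.jpg")
calls_if = list(calls)

print(pair_get, calls_get)  # ('b.jpg', False) ['a.jpg']
print(pair_if, calls_if)    # ('b.jpg', False) []
```

In the tutorial this does not change which pair is returned for a validation file; it only costs an extra random draw (and advances the `random` state).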

By doing this, the validation set will always be the same, but during training, each image will get randomly paired up with another image! That way the network always gets to see new image pairs during training, not just the same pairs it’s seen over and over again.
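A toy, image-free sketch of that caching behaviour, assuming a hypothetical six-file dataset and a `ToySiamese` class that only mimics the tutorial's dict lookup (it is not fastai code): the validation pairs come out identical every "epoch", while the training pairs are redrawn on each call.

```python
import random

# Hypothetical toy data: file name -> label, plus a train/valid split.
labels_of = {"cat1": "cat", "cat2": "cat", "cat3": "cat",
             "dog1": "dog", "dog2": "dog", "dog3": "dog"}
train_files = ["cat1", "cat2", "dog1", "dog2"]
valid_files = ["cat3", "dog3"]
labels = sorted(set(labels_of.values()))
lbl2files = {l: [f for f in labels_of if labels_of[f] == l] for l in labels}

class ToySiamese:
    """Pure-Python sketch of the tutorial's caching logic (no images)."""

    def __init__(self, valid_files):
        # Validation partners are drawn once here and cached.
        self.valid = {f: self._draw(f) for f in valid_files}

    def encodes(self, f):
        # Cached pair for validation files, fresh draw for everything else.
        return self.valid[f] if f in self.valid else self._draw(f)

    def _draw(self, f):
        same = random.random() < 0.5
        cls = labels_of[f]
        if not same:
            cls = random.choice([l for l in labels if l != cls])
        return random.choice(lbl2files[cls]), same

random.seed(0)
tfm = ToySiamese(valid_files)
epochs = 20
valid_pairs = [[tfm.encodes(f) for f in valid_files] for _ in range(epochs)]
train_pairs = [[tfm.encodes(f) for f in train_files] for _ in range(epochs)]

print(all(p == valid_pairs[0] for p in valid_pairs))  # True: validation never changes
print(len({tuple(p) for p in train_pairs}))           # number of distinct training epochs
```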

So your thinking is correct. During training, the network only sees training images, and during validation it always sees pairs consisting of one validation image and another image. One part that has confused me a little is that each validation example seems to consist of one validation image paired with one image that could come from either the validation set or the training set.

Let me know if this helps!

Thanks for your reply @GoofyMango,

So the _draw() method does not take the splits into account. It randomly picks either a file with the same label or a file with another label, irrespective of the training/validation split. This means that during both training and validation, the second file of a Siamese pair may come from any of the files (splits[0] and splits[1] taken together), since _draw() always picks it from the lbl2files dict, which covers all of the files regardless of split.
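To check this without fastai, here is a small simulation under assumed data (hypothetical `cat_i`/`dog_i` file names, a hypothetical `label_func`, an 80/20 split, and a global `lbl2files` built from all files the way the tutorial's is). It counts how many training files get a partner from the validation split:

```python
import random

labels = ["cat", "dog"]
files = [f"{l}_{i}" for l in labels for i in range(50)]

def label_func(f):
    # Hypothetical labelling rule: the class is the prefix before "_".
    return f.split("_")[0]

random.seed(42)
shuffled = random.sample(files, len(files))
train, valid = set(shuffled[:80]), set(shuffled[80:])

# Tutorial-style global lookup: built from *all* files, ignoring the split.
lbl2files = {l: [f for f in files if label_func(f) == l] for l in labels}

def draw(f):
    same = random.random() < 0.5
    cls = label_func(f)
    if not same:
        cls = random.choice([l for l in labels if l != cls])
    return random.choice(lbl2files[cls]), same

# Pair every training file once and count partners taken from the validation split.
partners = [draw(f)[0] for f in sorted(train)]
leaked = sum(p in valid for p in partners)
print(f"{leaked} of {len(partners)} training pairs use a validation image")
```

With about 20% of each class in the validation split, roughly one in five training pairs would be expected to borrow a validation image.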

However, you are right about the encodes() part of the code, and I did misunderstand it. But upon inspecting it a little more:

  1. In __init__(), we just store a dict mapping each validation file to a (file2, boolean) pair in self.valid.
  2. During encodes(), we check whether the file is in the self.valid dict; if not, we create a Siamese pair by calling _draw() on it, which returns another file that can come from the entire dataset (as mentioned above).
  3. Even if we changed __init__ and replaced self.valid = {f: self._draw(f) for f in files[splits[1]]} with self.train = {f: self._draw(f) for f in files[splits[0]]}, it would just mean that during training the file is found in the dict, and during validation its partner is drawn with _draw(), which is the opposite of what is done now.
  4. In either case, both the training AND validation stages receive the second file of the Siamese pair from the whole dataset (randomly, of course).

So I partially agree with you, since there is no clear separation of the training and validation sets in the example: the second file may randomly come from the entire file set.

I wonder how this impacts training, given that we usually keep the training and validation sets entirely apart.

Let me know what you think.

Cheers…
Happy Friday…!
-Gary

Hi Gary,

Looking into it more, it looks like mixing images between the train and validation sets was a mistake, and they just merged a fix for it like 6 hours ago!

class SiameseTransform(Transform):
    def __init__(self, files, splits):
        self.splbl2files = [{l: [f for f in files[splits[i]] if label_func(f) == l] for l in labels}
                          for i in range(2)]
        self.valid = {f: self._draw(f,1) for f in files[splits[1]]}
    def encodes(self, f):
        f2,same = self.valid.get(f, self._draw(f,0))
        img1,img2 = PILImage.create(f),PILImage.create(f2)
        return SiameseImage(img1, img2, same)
    
    def _draw(self, f, split=0):
        same = random.random() < 0.5
        cls = label_func(f)
        if not same: cls = random.choice(L(l for l in labels if l != cls)) 
        return random.choice(self.splbl2files[split][cls]),same

The updated code in the tutorial ensures that the train and validation sets stay totally separate. Wooo!
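The same kind of toy setup can sanity-check the per-split idea: with one lookup per split (as `splbl2files` does in the merged code), every partner stays inside its own split. File names, `label_func` and the 80/20 split are hypothetical; this is a sketch, not the tutorial's code.

```python
import random

labels = ["cat", "dog"]
files = [f"{l}_{i}" for l in labels for i in range(50)]

def label_func(f):
    # Hypothetical labelling rule: the class is the prefix before "_".
    return f.split("_")[0]

random.seed(42)
shuffled = random.sample(files, len(files))
splits = (shuffled[:80], shuffled[80:])  # (train, valid), hypothetical 80/20 split

# Per-split lookup mirroring the fix: splbl2files[i][label] -> files of that label in split i.
splbl2files = [{l: [f for f in splits[i] if label_func(f) == l] for l in labels}
               for i in range(2)]

def draw(f, split=0):
    same = random.random() < 0.5
    cls = label_func(f)
    if not same:
        cls = random.choice([l for l in labels if l != cls])
    return random.choice(splbl2files[split][cls]), same

train_set, valid_set = set(splits[0]), set(splits[1])
train_partners = [draw(f, 0)[0] for f in splits[0]]
valid_partners = [draw(f, 1)[0] for f in splits[1]]
print(all(p in train_set for p in train_partners))  # True
print(all(p in valid_set for p in valid_partners))  # True
```

One consequence of per-split pools: every label has to appear in each split, otherwise `random.choice` receives an empty list and raises `IndexError`.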

Hi Brandon,

thanks for checking the solution out.

Quote from @ucohen on the fix:

My suggestion is to set up both dictionaries for the train and valid splits during __init__:

class SiameseTransform(Transform):
    def __init__(self, files, splits):
        self.lbl2files = [{l: [f for f in files[splits[i]] if label_func(f) == l] for l in labels}
                          for i in range(2)]
        self.valid = {f: self._draw(f,1) for f in files[splits[1]]}

Thanks again, I will mark this post as solved.

Cheers
Gary
