Creating a Siamese ImageDataBunch in fastai v1

Hello!

I’m trying to use fastai's new ImageDataBunch class to create a Siamese dataset of ((img1, img2), target) items, where target is similar/dissimilar (0/1). The issue is that the loss function calculates its dimensions from the input, which in this case is the pair (img1, img2). My Siamese network looks like this:

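import torch.nn as nn
import torch.nn.functional as F

class SiameseNet(nn.Module):
    def __init__(self, embedding_net):
        super(SiameseNet, self).__init__()
        self.embedding_net = embedding_net   # shared weights for both branches
        self.pdist = nn.PairwiseDistance()
        # PairwiseDistance returns one value per pair, so the classifier takes 1 input feature
        self.ffc = nn.Linear(1, 2)

    def forward(self, x1, x2):
        output1 = self.embedding_net(x1)
        output2 = self.embedding_net(x2)
        out = self.pdist(output1, output2).unsqueeze(1)   # (batch,) -> (batch, 1)
        out = F.log_softmax(self.ffc(out), dim=-1)        # (batch, 2) log-probabilities for NLLLoss
        return out

    def get_embedding(self, x):
        return self.embedding_net(x)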

My loss function is NLLLoss, but it doesn’t work because the dimension of the input is 1, as you can see in the code above. What are the best practices for training a Siamese network (multiple inputs) with fastai v1?

Ideally, I would like ImageDataBunch to handle multiple inputs and apply transforms to each of them.

PS: Thank you for your effort on this library and the courses; I’ve learned a huge amount!

For a problem like this, you need to build your own custom Dataset class. Look at the various datasets in vision.data for ideas!
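As a starting point, here is a minimal sketch of such a Dataset in plain Python, so it does not depend on fastai internals. The class name SiamesePairs and its parameters are hypothetical; it assumes a base dataset whose items are (image, label) tuples, returns ((img1, img2), target) with target 0 for a similar pair and 1 for a dissimilar pair (the convention from the question), and applies an optional transform to each image separately.

```python
import random
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple


class SiamesePairs:
    """Hypothetical map-style dataset returning ((img1, img2), target).

    target is 0 for a similar pair (same label) and 1 for a dissimilar pair.
    `base` is any indexable collection of (image, label) tuples.
    """

    def __init__(self, base: Sequence[Tuple[Any, Any]],
                 transform: Optional[Callable[[Any], Any]] = None):
        self.base = base
        self.transform = transform
        # Group indices by label so partners can be drawn without scanning the whole dataset.
        self.by_label: Dict[Any, List[int]] = {}
        for idx, (_, label) in enumerate(base):
            self.by_label.setdefault(label, []).append(idx)

    def __len__(self) -> int:
        return len(self.base)

    def _tfm(self, img: Any) -> Any:
        # Apply the optional transform to each image independently.
        return self.transform(img) if self.transform is not None else img

    def __getitem__(self, i: int) -> Tuple[Tuple[Any, Any], int]:
        img1, label1 = self.base[i]
        same = [j for j in self.by_label[label1] if j != i]
        other_labels = [lab for lab in self.by_label if lab != label1]
        # Draw a similar pair about half the time; fall back when one side is impossible.
        similar = bool(same) and (not other_labels or random.random() < 0.5)
        if similar:
            j, target = random.choice(same), 0
        elif other_labels:
            j, target = random.choice(self.by_label[random.choice(other_labels)]), 1
        else:
            j, target = i, 0  # single-item dataset: pair the image with itself
        return (self._tfm(img1), self._tfm(self.base[j][0])), target
```

In practice base would return image tensors, so PyTorch's default collate turns each batch into [[x1_batch, x2_batch], targets]. As far as I understand, fastai v1's training loop calls the model as model(*xb) when the input batch is a list, so SiameseNet.forward(x1, x2) should receive both halves; you could then wrap the train and validation pair datasets with DataBunch.create(train_ds, valid_ds, bs=...). The module-level random functions are used on purpose, since PyTorch reseeds Python's random module in each DataLoader worker.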

Is it possible to use a custom batch sampler instead of the default one with the library? I’m trying to generate pairs/triplets within a minibatch, as suggested in the FaceNet paper (online triplets), so I need a custom batch sampler. Currently the DataBunch object doesn’t accept batch_sampler as a parameter. Is there any way to handle this?

You can either build your DataLoader yourself (then use DataBunch.__init__) or use the new method of DeviceDataLoader, which re-creates the underlying DataLoader with new arguments:

data.train_dl = data.train_dl.new(shuffle=False, batch_sampler=...)

(shuffle=False is necessary to override the shuffle=True of the training DataLoader, because PyTorch will throw an error otherwise.)
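For the online-triplet case, a batch sampler only has to yield lists of dataset indices. Below is a minimal sketch, assuming an integer class label is available for every training item; PKBatchSampler and batch_triplets are hypothetical names, and the "P classes x K samples per class" layout is one common way to guarantee that every batch contains valid positives and negatives, not something fastai provides.

```python
import random
from typing import Dict, Iterator, List, Sequence, Tuple


class PKBatchSampler:
    """Hypothetical batch sampler: each batch holds p classes x k samples per class.

    Yields lists of dataset indices, which is what a DataLoader's batch_sampler
    is iterated for. Classes with fewer than k samples are skipped.
    """

    def __init__(self, labels: Sequence[int], p: int, k: int, seed: int = 0):
        self.p, self.k = p, k
        self.rng = random.Random(seed)
        by_label: Dict[int, List[int]] = {}
        for idx, label in enumerate(labels):
            by_label.setdefault(label, []).append(idx)
        # Keep only classes that can supply k distinct samples.
        self.by_label = {lab: ix for lab, ix in by_label.items() if len(ix) >= k}
        if len(self.by_label) < p:
            raise ValueError("need at least p classes with k samples each")

    def __len__(self) -> int:
        # Batches per epoch: roughly one pass over the eligible samples.
        n = sum(len(ix) for ix in self.by_label.values())
        return max(1, n // (self.p * self.k))

    def __iter__(self) -> Iterator[List[int]]:
        classes = list(self.by_label)
        for _ in range(len(self)):
            batch: List[int] = []
            for label in self.rng.sample(classes, self.p):
                batch.extend(self.rng.sample(self.by_label[label], self.k))
            yield batch


def batch_triplets(labels: Sequence[int]) -> List[Tuple[int, int, int]]:
    """All (anchor, positive, negative) position triples within one batch."""
    triplets = []
    for a, la in enumerate(labels):
        for pos, lp in enumerate(labels):
            if pos == a or lp != la:
                continue  # a positive must be a different item with the same label
            for neg, ln in enumerate(labels):
                if ln != la:
                    triplets.append((a, pos, neg))
    return triplets
```

It would plug into the call above as data.train_dl.new(shuffle=False, batch_sampler=PKBatchSampler(labels, p=8, k=4)), where labels is your list of per-item classes in dataset order. PyTorch's DataLoader treats batch_sampler as mutually exclusive with batch_size, shuffle, sampler and drop_last, so depending on which arguments new carries over from the original loader, you may also need to pass batch_size=1 and drop_last=False. Inside the loss, batch_triplets can be applied to the batch's labels (converted to a Python list) to enumerate candidate triplets before selecting hard ones.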