Loading weights from a trained siamese network to a "single network" for inference

I have trained a siamese network which looks like this:

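import torch.nn as nn
import torch.nn.functional as F
from fastai.vision import create_body, create_head, models  # fastai v1

class SiameseNetwork(nn.Module):
    def __init__(self, arch):
        super().__init__()
        # One shared encoder: both images go through the same body and head
        self.body = create_body(arch)
        self.head = create_head(nf=512*2, nc=8, lin_ftrs=[256])

    def forward(self, im_A, im_B):
        x1 = self.body(im_A)
        x1 = self.head(x1)

        x2 = self.body(im_B)
        x2 = self.head(x2)
        # Euclidean (p=2) distance between the two embeddings
        return F.pairwise_distance(x1, x2)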

After training it, I would like to copy its weights to its equivalent “single network”:

class SingleNetwork(nn.Module):
    def __init__(self, arch):
        super().__init__()
        # Same attribute names as SiameseNetwork, so the state_dict keys match
        self.body = create_body(arch)
        self.head = create_head(nf=512*2, nc=8, lin_ftrs=[256])

    def forward(self, img):
        x = self.body(img)
        return self.head(x)

How do I do that? I have tried every variant of the following that I could think of, and none of them seems to work:

model = SiameseNetwork(arch=models.resnet34).cuda()
loss_func = ContrastiveLoss(margin=margin)
siam_learner = Learner(data, model, loss_func=loss_func, model_dir=PATH, metrics=[acc_1, acc_2, acc_3])

…
…

# Create an instance of SingleNetwork 
single_model = SingleNetwork(arch=models.resnet34).cuda()
# Load the weights
single_model.body.load_state_dict(siam_learner.model.body.state_dict())
single_model.head.load_state_dict(siam_learner.model.head.state_dict())
# Create a Learner with a SingleNetwork instance
single_learner = Learner(single_data, single_model, model_dir=PATH)
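Since both classes name their submodules `body` and `head`, their `state_dict` keys should be identical, so in principle the whole model could also be copied in one call (`single_model.load_state_dict(siam_learner.model.state_dict())`) rather than body and head separately. A minimal sketch of that copy plus a check that every parameter and buffer really matches is shown here. It uses tiny hypothetical stand-ins (`TinySiamese`, `TinySingle`) instead of `create_body`/`create_head` so it runs without fastai; `copy_shared_weights` and `weights_match` are likewise illustrative helpers, not library functions.

```python
import torch
import torch.nn as nn


def copy_shared_weights(src: nn.Module, dst: nn.Module) -> None:
    # Works when both models register submodules under the same attribute
    # names (here "body" and "head"), so keys such as "body.0.weight" line up.
    # load_state_dict is strict by default: it raises on missing/extra keys.
    dst.load_state_dict(src.state_dict())


def weights_match(a: nn.Module, b: nn.Module) -> bool:
    # Compare every parameter AND buffer (e.g. BatchNorm running statistics).
    sa, sb = a.state_dict(), b.state_dict()
    return sa.keys() == sb.keys() and all(torch.equal(sa[k], sb[k]) for k in sa)


class TinySiamese(nn.Module):
    # Hypothetical stand-in for SiameseNetwork (no fastai needed).
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU())
        self.head = nn.Linear(8, 2)


class TinySingle(nn.Module):
    # Hypothetical stand-in for SingleNetwork: same attribute names.
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU())
        self.head = nn.Linear(8, 2)


torch.manual_seed(0)
siamese, single = TinySiamese(), TinySingle()
before = weights_match(siamese, single)  # False: independent random init
copy_shared_weights(siamese, single)
after = weights_match(siamese, single)   # True: every tensor was copied
print(before, after)  # False True
```

If a check like this passes but the outputs still disagree, the difference is more likely in how the two models are run than in the weights themselves.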

Any clues, anyone?

OK, I just did some further checks. It looks like the weights are being copied, but the output of the siamese network for a pair of images differs from the Euclidean distance between the outputs of the same two images passed through the “single network”. If anyone has tried this before, I’d be grateful for any hints.
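One common cause of this kind of mismatch (a guess, not the fix confirmed in this thread) is comparing the models in different modes. `create_head` adds Dropout and BatchNorm layers and the ResNet body contains BatchNorm, so in training mode Dropout zeroes activations at random and BatchNorm normalizes with the current batch's statistics (also updating its running averages). Calling `.eval()` on both models makes the outputs deterministic. Also, `F.pairwise_distance` adds a small `eps` (1e-6 by default) to the difference before taking the norm, so a hand-computed `torch.norm(x1 - x2, dim=1)` can differ slightly; applying `F.pairwise_distance` to the single network's outputs avoids that. The following PyTorch-only sketch uses hypothetical stand-ins (`make_body`, `make_head`) in place of fastai's `create_body`/`create_head` so it runs without a trained model or dataset.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def make_body():
    # Hypothetical stand-in for create_body(arch): tiny conv net with
    # BatchNorm, so train/eval mode actually matters.
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
    )


def make_head():
    # Hypothetical stand-in for create_head(...): Dropout is random in
    # train mode and a no-op in eval mode.
    return nn.Sequential(nn.Dropout(0.5), nn.Linear(8, 4))


class SiameseNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = make_body()
        self.head = make_head()

    def forward(self, im_A, im_B):
        x1 = self.head(self.body(im_A))
        x2 = self.head(self.body(im_B))
        return F.pairwise_distance(x1, x2)


class SingleNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = make_body()
        self.head = make_head()

    def forward(self, img):
        return self.head(self.body(img))


torch.manual_seed(0)
siamese, single = SiameseNetwork(), SingleNetwork()
single.load_state_dict(siamese.state_dict())  # identical "body.*"/"head.*" keys

# Put BOTH models in eval mode: Dropout off, BatchNorm uses running stats.
siamese.eval()
single.eval()

im_A = torch.randn(2, 3, 16, 16)
im_B = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    d_siamese = siamese(im_A, im_B)
    # Same distance function (same eps) applied to the single-network outputs
    d_single = F.pairwise_distance(single(im_A), single(im_B))

print(torch.allclose(d_siamese, d_single))  # True
```

With a real fastai `Learner`, the equivalent step would be making sure both `siam_learner.model` and `single_model` are in eval mode and fed identically preprocessed (e.g. normalized) images before comparing.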

Hi, gautam_e! Great!

Could you share how you solved the DataLoader part? How do you load two images at a time?

Thanks,
Fabio.

I used the Custom ItemList tutorial in the docs, as well as a kernel on Kaggle, to guide me. That should help you too.
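The Custom ItemList tutorial is the fastai way to do this. As a plain PyTorch alternative (a different technique, not what was described above), a `torch.utils.data.Dataset` can return two images plus a similarity target per item, and the default `DataLoader` collation then stacks them into `(im_A, im_B, target)` batches. `PairDataset`, the 50/50 same-class sampling rule, and the "1.0 = same class" convention are all assumptions for illustration; check which convention your `ContrastiveLoss` expects. Random tensors stand in for real images.

```python
import random

import torch
from torch.utils.data import DataLoader, Dataset


class PairDataset(Dataset):
    """Hypothetical pair dataset yielding (img_a, img_b, target), where
    target is 1.0 if both images share a label and 0.0 otherwise."""

    def __init__(self, images, labels, seed=0):
        self.images = images  # any indexable of image tensors
        self.labels = labels  # parallel list of class labels
        self.rng = random.Random(seed)
        # Index images by label so same-class partners can be sampled.
        self.by_label = {}
        for i, y in enumerate(labels):
            self.by_label.setdefault(y, []).append(i)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        y = self.labels[i]
        if self.rng.random() < 0.5:
            j = self.rng.choice(self.by_label[y])  # same-class partner
        else:
            j = self.rng.randrange(len(self.images))  # any partner
        # Target comes from the actual labels, so it is always consistent.
        target = torch.tensor(float(self.labels[j] == y))
        return self.images[i], self.images[j], target


images = [torch.randn(3, 16, 16) for _ in range(10)]
labels = [k % 3 for k in range(10)]
loader = DataLoader(PairDataset(images, labels), batch_size=4)
im_A, im_B, target = next(iter(loader))
print(im_A.shape, im_B.shape, target.shape)
# torch.Size([4, 3, 16, 16]) torch.Size([4, 3, 16, 16]) torch.Size([4])
```

Each batch can then be passed straight to the siamese model as `model(im_A, im_B)`, with `target` going to the loss.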

Thank you very much!