Loading a pre-trained model architecture for another model type (Siamese)

I’m trying to use a model that is not defined in the torchvision.models module as the architecture for a custom Siamese network.

# Imports are assumed here (they weren't in my original snippet):
# Resnet50_ft_dag comes from the resnet50_ft_dag.py file linked below,
# create_body / num_features_model are fastai helpers (fastai.vision.all
# in v2, fastai.vision in v1), and seq is presumably PyFunctional's.
import torch
from torch import nn
from fastai.vision.all import create_body, num_features_model
from functional import seq
from resnet50_ft_dag import Resnet50_ft_dag

model = Resnet50_ft_dag()
archi = torch.load("./resnet50_ft_dag.pth")  # state dict of the pre-trained weights
model.load_state_dict(archi)

class SiameseNetwork(nn.Module):
    def __init__(self, arch=model):
        super().__init__()
        self.cnn = create_body(arch)
        self.head = nn.Linear(num_features_model(self.cnn), 1)

    def forward(self, im_A, im_B):
        # dl - distance layer
        x1, x2 = seq(im_A, im_B).map(self.cnn).map(self.process_features)
        dl = self.calculate_distance(x1, x2)
        out = self.head(dl)
        return out

    # global max-pool over the spatial dimensions of each feature map
    def process_features(self, x): return x.reshape(*x.shape[:2], -1).max(-1)[0]
    # element-wise L1 distance between the two feature vectors
    def calculate_distance(self, x1, x2): return (x1 - x2).abs_()

That is the code for the network I’m using. The model I want is resnet50_ft_dag; both the code defining the architecture and the weights (saved in a .pth file) are available at http://www.robots.ox.ac.uk/~albanie/pytorch-models.html.
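
For what it’s worth, the .py files on that page also seem to ship a small loader function that builds the architecture and loads the weights in one call, so the three loading lines above could presumably be replaced with something like this (a sketch, assuming resnet50_ft_dag.py follows the same convention as the other models on that page):

from resnet50_ft_dag import resnet50_ft_dag  # assumed loader in the downloaded file

# builds Resnet50_ft_dag and loads the state dict in one step
model = resnet50_ft_dag(weights_path="./resnet50_ft_dag.pth")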

Basically, my question is: if I have a model that I can already load, can I then create an architecture that will work with create_body(arch)?

Thanks in advance! 🙂

arch should be a Callable, but here you are passing a PyTorch model instance. create_body just instantiates the model by calling the Callable arch and then cuts off the final layers. Since you are already loading the model yourself, I think what you should do is create your own custom_create_body function:

from fastai.vision.all import has_pool_type  # fastai v2 helper

def custom_create_body(model):
    # cut just before the last top-level child that contains a pooling layer,
    # mirroring fastai's default create_body behaviour
    ll = list(enumerate(model.children()))
    cut = next(i for i, o in reversed(ll) if has_pool_type(o))
    # or you can set cut=-2 if you are using just resnet50_ft_dag
    return nn.Sequential(*list(model.children())[:cut])
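
As a quick sanity check you can run a dummy batch through the body (a sketch; the 2048 feature count assumes resnet50_ft_dag has a standard ResNet-50 trunk and takes 224x224 inputs):

body = custom_create_body(model)
x = torch.randn(1, 3, 224, 224)   # dummy batch
print(body(x).shape)              # expected: torch.Size([1, 2048, 7, 7])
print(num_features_model(body))   # expected: 2048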

class SiameseNetwork(nn.Module):
    def __init__(self, arch=model):
        super().__init__() 
        self.cnn = custom_create_body(arch)
        self.head = nn.Linear(num_features_model(self.cnn), 1)
    ...
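
Alternatively, if you want to keep using fastai’s create_body, you can wrap the already-loaded model in a Callable so that create_body’s internal arch(pretrained=...) call simply returns your instance (a sketch; the explicit cut=-2 assumes, as above, that the last two children of resnet50_ft_dag are the pooling and classifier layers):

# the lambda ignores the pretrained flag and returns the loaded model
body = create_body(lambda pretrained=True: model, cut=-2)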

Amazing, thank you!