I’ve been experimenting with that. Below is how I managed to get a model that takes two inputs, x1 and x2.
The tricky part for me was getting the data block API to return a pair of images. I don’t know if there is an easy way of doing that.
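One possible workaround, sketched here rather than tested against fastai, is to bypass the data block API and build the pairs yourself. This is a minimal sketch assuming a plain list of items (e.g. image paths) with one label per item; `make_pairs` and `PairDataset` are hypothetical names, not fastai API. In real use `__getitem__` would open and transform the two images instead of returning the raw items.

```python
import random
from collections import defaultdict


def make_pairs(labels, n_pairs, seed=0):
    """Return (i, j, same) index triples; same is 1 if the labels match, else 0.
    Positive and negative pairs alternate so the two classes stay balanced."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_label[lab].append(idx)
    # Only labels with at least two items can form a positive pair
    pos_labels = [lab for lab, idxs in by_label.items() if len(idxs) >= 2]
    all_labels = list(by_label)
    if not pos_labels or len(all_labels) < 2:
        raise ValueError("need a label with 2+ items and at least 2 distinct labels")
    pairs = []
    for k in range(n_pairs):
        if k % 2 == 0:
            # Positive pair: two distinct items sharing a label
            lab = rng.choice(pos_labels)
            i, j = rng.sample(by_label[lab], 2)
            pairs.append((i, j, 1))
        else:
            # Negative pair: one item from each of two different labels
            la, lb = rng.sample(all_labels, 2)
            pairs.append((rng.choice(by_label[la]), rng.choice(by_label[lb]), 0))
    return pairs


class PairDataset:
    """Indexable like a torch Dataset: item k -> ((x_i, x_j), same)."""

    def __init__(self, items, labels, n_pairs, seed=0):
        self.items = items
        self.pairs = make_pairs(labels, n_pairs, seed)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, k):
        i, j, same = self.pairs[k]
        return (self.items[i], self.items[j]), same
```

If the items are image tensors, wrapping this in a PyTorch `DataLoader` should collate each batch into `[x1_batch, x2_batch]` plus a label batch. As far as I can tell, fastai v1's training loop calls `model(*xb)` when the input is a list, which is what would route both images into `forward(self, x1, x2)`.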
I think it would be useful to have a way to combine multiple data block objects. Several applications need multiple inputs, such as Siamese networks, data distillation, or more complex networks that combine different kinds of input.
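Outside the data block API, combining sources can be as simple as zipping datasets that share an index. A minimal sketch, assuming every source is already an indexable dataset of the same length and in the same order; `ZipDataset` is a hypothetical helper, not part of fastai.

```python
class ZipDataset:
    """Hypothetical helper (not a fastai API): combines several equal-length,
    indexable datasets so that item k is (ds0[k], ds1[k], ...)."""

    def __init__(self, *datasets):
        if not datasets:
            raise ValueError("need at least one dataset")
        n = len(datasets[0])
        if any(len(d) != n for d in datasets):
            raise ValueError("all datasets must have the same length")
        self.datasets = datasets

    def __len__(self):
        return len(self.datasets[0])

    def __getitem__(self, k):
        # One element from each source, in the order the sources were given
        return tuple(d[k] for d in self.datasets)


# e.g. image inputs alongside tabular features for a mixed-input model
combined = ZipDataset(['img0.png', 'img1.png'], [[0.5, 1.0], [0.2, 3.0]])
print(combined[1])  # ('img1.png', [0.2, 3.0])
```

The length check is deliberate: silently truncating to the shortest source would misalign inputs that are supposed to describe the same example.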
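```python
from fastai.vision import *  # fastai v1: provides torch, nn, models, create_body, create_head

class SiameseResnet34(nn.Module):
    def __init__(self):
        super().__init__()
        # Pretrained ResNet-34 body with the final pooling and fc layers cut off
        self.body = create_body(models.resnet34(True), cut=-2)
        # 512 channels x 2 branches x 2 (concat pooling) = 2048 features in, 1 score out
        self.head = create_head(2048, 1, [512])

    def forward(self, x1, x2):
        # The same body (shared weights) encodes both images
        out1 = self.body(x1)
        out2 = self.body(x2)
        out = torch.cat((out1, out2), dim=1)  # stack the feature maps along channels
        out = self.head(out)
        return out.view(-1)  # (bs, 1) -> (bs,)
```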
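The `2048` passed to `create_head` follows from the shapes involved. ResNet-34's last block outputs 512 channels, `torch.cat(..., dim=1)` stacks the two branches along the channel dimension, and concat pooling (average + max) doubles the count again. As far as I know, fastai v1's `create_head` expects the caller to pass this already-doubled number, while fastai v2's `create_head` doubles it internally, so there you would pass 1024 instead. The helper name below is hypothetical; it only spells out the arithmetic.

```python
def siamese_head_in_features(body_channels=512, n_branches=2, concat_pool=True):
    """Number of features entering the head's first linear layer."""
    # torch.cat(dim=1) concatenates the branch feature maps channel-wise
    channels = body_channels * n_branches
    # Concat pooling keeps both the average- and max-pooled values
    return channels * 2 if concat_pool else channels


print(siamese_head_in_features())  # 2048
```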