Hello,
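I am having trouble understanding something. I am trying to build a Siamese-style model that takes two inputs:
```python
import torch
from fastai.torch_core import Module

class Model(Module):
    def __init__(self, encoder, head):
        # fastai's Module calls nn.Module.__init__ automatically
        self.encoder, self.head = encoder, head
    def forward(self, x1, x2):
        # encode both inputs with the shared encoder, then concatenate the features
        ftrs = torch.cat([self.encoder(x1), self.encoder(x2)], dim=1)
        return self.head(ftrs)
```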
If I use the DataBlock approach proposed in the Siamese tutorial, it does not work, because each batch it creates has length 2 and takes the form:
((img1, img2), y)
instead of the form:
(img1, img2, y)
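The difference between the two layouts comes down to how the batch is sliced into inputs and targets and then unpacked into the model. Below is a minimal pure-Python sketch, assuming the slicing is `b[:n_inp]` / `b[n_inp:]` followed by a call like `model(*xb)`. `split_batch`, `forward` and the string "images" are hypothetical stand-ins, not fastai code:

```python
# Hypothetical re-creation of splitting a batch into inputs and targets;
# `split_batch` is NOT fastai's API, only a sketch of xb = b[:n_inp], yb = b[n_inp:].
def split_batch(batch, n_inp):
    """Return (xb, yb): the first n_inp items are inputs, the rest are targets."""
    return batch[:n_inp], batch[n_inp:]


def forward(x1, x2):
    """Stand-in for a two-input forward(self, x1, x2)."""
    return (x1, x2)


img1, img2, y = "img1", "img2", 0

# Nested layout: ((img1, img2), y) with one input
nested = ((img1, img2), y)
xb, yb = split_batch(nested, n_inp=1)
# xb == ((img1, img2),), so forward(*xb) receives a single argument
try:
    forward(*xb)
except TypeError as err:
    print("nested layout fails:", err)

# Flat layout with two inputs: (img1, img2, y)
flat = (img1, img2, y)
xb, yb = split_batch(flat, n_inp=2)
print(forward(*xb))  # ('img1', 'img2')
```

With the nested layout the pair is wrapped one level too deep, so unpacking hands `forward` one tuple instead of two tensors.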
More generally, I want to understand where I should modify my code to split the batch into two input items.
What I found is that the Learner does this in its `_split` function, with `self.xb = b[:i]`; if I use `self.xb[0]` instead, I get the desired behavior.
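The `xb[0]` workaround amounts to removing one extra level of nesting before the model is called. Instead of editing the Learner, the same effect can be sketched as a thin adapter around the model, so the library code stays untouched. This is a plain-Python sketch; `UnpackFirstInput` and `two_input_model` are hypothetical names, and in a real PyTorch setup the adapter would be an `nn.Module` whose `forward` does the same unpacking:

```python
# Hypothetical adapter: accepts the single nested input (x1, x2) and
# forwards it as two separate arguments, mirroring the xb[0] workaround.
class UnpackFirstInput:
    def __init__(self, model):
        self.model = model

    def __call__(self, pair):
        # pair is the single input tuple (x1, x2); spread it into two args
        x1, x2 = pair
        return self.model(x1, x2)


def two_input_model(x1, x2):
    """Stand-in for a model whose forward takes (x1, x2)."""
    return f"{x1}+{x2}"


xb = (("img1", "img2"),)  # what slicing ((img1, img2), y) with one input gives
wrapped = UnpackFirstInput(two_input_model)
print(wrapped(*xb))  # img1+img2
```

The trade-off is that the model now depends on the nested layout; producing the flat `(img1, img2, y)` layout upstream avoids needing the adapter at all.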
I want to be able to create custom classes that inherit from “Tuple”, and to clean up my models that use multiple inputs.
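A custom tuple type for the pair can be sketched with plain `tuple` subclassing. This does not reproduce fastai's “Tuple” class, which adds library-specific behavior not shown here; `ImagePair` and `siamese_forward` are hypothetical names. Because the subclass is still a tuple, it unpacks directly into a two-input `forward`:

```python
class ImagePair(tuple):
    """Hypothetical two-item tuple type for paired inputs."""

    def __new__(cls, img1, img2):
        # build the underlying tuple from exactly two items
        return super().__new__(cls, (img1, img2))

    @property
    def first(self):
        return self[0]

    @property
    def second(self):
        return self[1]


def siamese_forward(x1, x2):
    """Stand-in for a two-input forward."""
    return (x1, x2)


pair = ImagePair("img1", "img2")
print(isinstance(pair, tuple), len(pair))  # True 2
print(siamese_forward(*pair))  # ('img1', 'img2')
```

Note that slicing an `ImagePair` returns a plain `tuple`, so the custom type is only preserved when the pair is passed around whole.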