Hi,
As I understand it, forward is a special method that PyTorch calls to do the forward-pass computation. I am confused about the signature of this method.
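Here is a minimal sketch of how I believe the call reaches forward. Calling a module instance, as in model(u, m), goes through nn.Module.__call__, which runs any registered hooks and then calls forward with the same positional arguments. RowDot, record_args and the tensors are hypothetical toy names and data. The forward pre-hook (register_forward_pre_hook) records the arguments that __call__ hands over, and calling forward directly skips the hook.

```python
import torch
import torch.nn as nn

class RowDot(nn.Module):
    # Hypothetical toy module: row-wise dot product of two (batch, k) tensors
    def forward(self, u, m):
        return (u * m).sum(1)

seen = []

def record_args(module, args):
    # Forward pre-hook: __call__ passes it the positional arguments
    # right before it invokes forward(*args)
    seen.append(len(args))

model = RowDot()
model.register_forward_pre_hook(record_args)

u = torch.tensor([[1., 2.], [3., 4.]])
m = torch.tensor([[2., 0.], [1., 1.]])

out = model(u, m)           # nn.Module.__call__ -> hook -> forward(u, m)
print(out)                  # tensor([2., 7.])
print(seen)                 # [2]: the hook saw both arguments

direct = model.forward(u, m)  # same result, but the hook does not run
print(seen)                 # still [2]
```

So forward can declare whatever positional arguments you like. It only has to match how the module ends up being called.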
In the lesson 5 notebook we see two different implementations of this method:
```python
import torch.nn as nn

class DotProduct(nn.Module):
    def forward(self, u, m): return (u*m).sum(1)

class EmbeddingDot(nn.Module):
    ...
    ...
    def forward(self, cats, conts):
        users,movies = cats[:,0],cats[:,1]
        u,m = self.u(users),self.m(movies)
        return (u*m).sum(1).view(-1, 1)
```
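To check my reading of the second version, here is a sketch with toy data. The constructor and its parameters (n_users, n_movies, n_factors) are my hypothetical stand-in for the part elided above; only forward mirrors the notebook. I am assuming cats is an integer tensor of shape (batch, 2) whose columns are user and movie indices, and that conts holds the continuous columns, of which this dataset appears to have none.

```python
import torch
import torch.nn as nn

class EmbeddingDot(nn.Module):
    # Hypothetical constructor: the notebook's actual __init__ is elided.
    def __init__(self, n_users, n_movies, n_factors=5):
        super().__init__()
        self.u = nn.Embedding(n_users, n_factors)   # one vector per user id
        self.m = nn.Embedding(n_movies, n_factors)  # one vector per movie id

    def forward(self, cats, conts):
        # cats: (batch, 2) integer tensor; column 0 = user ids, column 1 = movie ids
        # conts: continuous features, not used by this model
        users, movies = cats[:, 0], cats[:, 1]
        u, m = self.u(users), self.m(movies)        # each (batch, n_factors)
        return (u * m).sum(1).view(-1, 1)           # (batch, 1)

torch.manual_seed(0)
model = EmbeddingDot(n_users=10, n_movies=20)
cats = torch.tensor([[0, 3], [4, 19], [9, 7]])  # three (user, movie) pairs
conts = torch.zeros(3, 0)                        # zero continuous columns
preds = model(cats, conts)
print(preds.shape)  # torch.Size([3, 1])
```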
I have the following questions:
- In the second implementation, what do the arguments `cats` and `conts` mean?
- How are they passed to the forward method?
- Why can't we pass the users and movies vectors directly, as in the first implementation?
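A rough sketch of how I imagine the two arguments get there. This is an assumption, not confirmed library behaviour: I am guessing the training loop gets each batch as [cats, conts, y] and calls the model as model(*xs) with xs = [cats, conts]. UserMovieDot, CatsContsAdapter and fake_batches are hypothetical names, and fake_batches stands in for the real data loader.

```python
import torch
import torch.nn as nn

class UserMovieDot(nn.Module):
    # Hypothetical model: forward takes the two index vectors directly.
    def __init__(self, n_users, n_movies, n_factors=5):
        super().__init__()
        self.u = nn.Embedding(n_users, n_factors)
        self.m = nn.Embedding(n_movies, n_factors)

    def forward(self, users, movies):
        return (self.u(users) * self.m(movies)).sum(1).view(-1, 1)

class CatsContsAdapter(nn.Module):
    # Hypothetical wrapper: exposes the (cats, conts) signature that a loop
    # calling model(*xs) would need, then splits cats into users and movies.
    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, cats, conts):
        return self.inner(cats[:, 0], cats[:, 1])

# Hypothetical stand-in for the data loader: each batch is [cats, conts, y].
cats = torch.tensor([[0, 3], [4, 19], [9, 7]])  # (user id, movie id) per row
conts = torch.zeros(3, 0)                        # no continuous columns
y = torch.tensor([[4.0], [3.5], [5.0]])
fake_batches = [[cats, conts, y]]

torch.manual_seed(0)
inner = UserMovieDot(n_users=10, n_movies=20)
model = CatsContsAdapter(inner)

for *xs, target in fake_batches:
    preds_loop = model(*xs)                          # assumed loop-style call: model(cats, conts)
    preds_direct = inner(xs[0][:, 0], xs[0][:, 1])   # calling it yourself with users, movies
```

If this assumption holds, forward(self, users, movies) works when you make the call yourself. A loop that passes every input tensor positionally would instead hand over (cats, conts), so the model needs a forward that accepts those two arguments, as EmbeddingDot does.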
Thanks.