Forward Function

In Chapter 8 (Collaborative Filtering), where the model is implemented from scratch: in the `forward` function of the `DotProduct` class, how is the parameter `x` passed in, and what data does `x` contain?

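```python
from fastai.tabular.all import *  # brings fastai's Module and Embedding into scope

class DotProduct(Module):
    def __init__(self, n_users, n_movies, n_factors):
        # One learnable factor vector per user and one per movie
        self.user_factors = Embedding(n_users, n_factors)
        self.movie_factors = Embedding(n_movies, n_factors)

    def forward(self, x):
        users = self.user_factors(x[:,0])    # column 0: user indices
        movies = self.movie_factors(x[:,1])  # column 1: movie indices
        print(users, movies)                 # debugging: show the looked-up factors
        return (users * movies).sum(dim=1)   # one dot product per (user, movie) row
```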
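To see where `x` comes from, here is a minimal sketch in plain PyTorch. It assumes `torch.nn.Module` and `torch.nn.Embedding` as stand-ins for fastai's `Module` and `Embedding` (so `super().__init__()` is needed), and the sizes, indices and ratings are made up. Each batch the `DataLoader` yields is a pair `(xb, yb)`; calling `model(xb)` runs `forward(xb)`, so `x` inside `forward` is a tensor with one `(user index, movie index)` row per example. fastai's `Learner` does the equivalent of this loop for you during training.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

class DotProduct(nn.Module):
    def __init__(self, n_users, n_movies, n_factors):
        super().__init__()  # required for plain torch.nn.Module (fastai's Module handles this)
        self.user_factors = nn.Embedding(n_users, n_factors)
        self.movie_factors = nn.Embedding(n_movies, n_factors)

    def forward(self, x):
        # x: integer tensor of shape (batch_size, 2)
        users = self.user_factors(x[:, 0])    # (batch_size, n_factors)
        movies = self.movie_factors(x[:, 1])  # (batch_size, n_factors)
        return (users * movies).sum(dim=1)    # (batch_size,)

torch.manual_seed(0)
model = DotProduct(n_users=10, n_movies=20, n_factors=5)

# Made-up data: each row of `pairs` is (user index, movie index)
pairs = torch.tensor([[0, 3], [1, 7], [2, 3], [9, 19], [4, 0], [5, 12]])
ratings = torch.tensor([4.0, 3.5, 5.0, 2.0, 1.0, 4.5])
dl = DataLoader(TensorDataset(pairs, ratings), batch_size=4)

for xb, yb in dl:
    preds = model(xb)  # nn.Module.__call__ calls forward(xb), so x == xb
    print(xb.shape, preds.shape)
# torch.Size([4, 2]) torch.Size([4])
# torch.Size([2, 2]) torch.Size([2])
```

The last batch is smaller because 6 rows do not divide evenly into batches of 4; `forward` does not care, since it only indexes the two columns.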

Hey Syed!

In PyTorch (and in deep learning frameworks generally), `forward` is the method that runs when you call the model on some input: `nn.Module` defines `__call__`, so `model(x)` ends up calling `model.forward(x)` (plus any registered hooks). That means the `x` in `forward` is always whatever input the model was called with; during training, that is each batch of independent variables coming out of the `DataLoader`. As you can see from the function, `x` is expected to have two columns: the first holds the user indices and the second holds the movie indices.
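The `model(x)` → `model.forward(x)` dispatch can be shown without PyTorch at all. This is a toy sketch: `TinyModule` is a hypothetical stand-in for `nn.Module` that keeps only the call-to-forward step (the real `__call__` also runs hooks), and the factor values are invented so the dot products are easy to check by hand.

```python
class TinyModule:
    """Hypothetical stand-in for torch.nn.Module: only the __call__ -> forward dispatch."""
    def __call__(self, *args, **kwargs):
        # The real nn.Module.__call__ also runs hooks; here we just pass the input through.
        return self.forward(*args, **kwargs)

class DotProduct(TinyModule):
    def __init__(self, user_factors, movie_factors):
        self.user_factors = user_factors    # list of factor vectors, indexed by user id
        self.movie_factors = movie_factors  # list of factor vectors, indexed by movie id

    def forward(self, x):
        # x is a list of [user_id, movie_id] rows, mirroring the two tensor columns
        return [sum(u * m for u, m in zip(self.user_factors[user], self.movie_factors[movie]))
                for user, movie in x]

model = DotProduct(user_factors=[[1.0, 2.0], [0.5, -1.0]],
                   movie_factors=[[3.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
x = [[0, 1], [1, 2]]
print(model(x))  # [3.0, -1.0]
```

Calling `model(x)` never names `forward`, yet `forward` receives exactly the `x` that was passed in.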
