Thought about collab filter for classification

Dear All,

I am interested in using the library's collaborative filtering for a classification problem. This modification was discussed in v2 of the course, and I see that the option to pass a loss_fn to fastai.collab.get_collab_learner has in fact been included, which allows one to pass (say) cross-entropy loss instead of RMSE.

However, as far as I can tell from reading and trying it (forgive me if I’m missing something), ‘get_collab_learner’ currently fetches the same embedding_dot_bias model regardless of the loss, and that model is not suitable for classification.

One option is to adjust the dimension of the output to equal the number of classes; currently the forward method returns an object of shape (batch_size, 1). However, it is not enough to add an nn.Linear(1, num_classes) to the end of the forward, since (ignoring the bias term) this forces the architecture to predict only one of two classes regardless of input, for any fixed set of parameters to the linear layer. Say the weights are [a_1 … a_num_classes] and x is the scalar output of the dot-product model.
Because softmax and log preserve the ordering of the logits, the predicted class argmax_i log(softmax(a*x))_i equals argmax_i (a_i*x). That is the class with the largest a_i for all x > 0 and the class with the smallest (most negative) a_i for all x < 0. This behavior seems undesirable.
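The two-class argument above can be checked numerically. Here is a minimal sketch in plain Python, assuming a bias-free linear layer so that the logits are simply a_i*x; the weight values and the helper name `predicted_class` are made up for illustration.

```python
def predicted_class(a, x):
    """Return argmax_i of a[i] * x, the class that softmax over a*x would pick."""
    logits = [a_i * x for a_i in a]
    return max(range(len(a)), key=lambda i: logits[i])

# Arbitrary example weights for 4 classes (illustrative values only).
a = [0.3, -1.2, 2.0, 0.7]

# Sweep the scalar input over [-5, 5], skipping x = 0 where all logits tie.
xs = [k / 10 for k in range(-50, 51) if k != 0]
predicted = {predicted_class(a, x) for x in xs}

# Only two classes ever win: argmin of a (for x < 0) and argmax of a (for x > 0).
print(sorted(predicted))
```

Note that nn.Linear includes a bias by default. With biases the logits become a_i*x + b_i, so more than two classes can win, but each class can still only win on a single interval of x, which remains quite restrictive.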

Instead, we could take the approach of the “embedding net” from the v2 lesson on collaborative filtering. Briefly, the embeddings are set up the same way, but the dot product is replaced by concatenation of the two embedded vectors, followed by one or more linear layers separated by activation/regularization, with the chain ending in an output of dimension num_classes. I have been using this approach for some of my collaborative filtering classification problems with reasonable success, and would like to implement it in this library so I can take advantage of all the nice functionality fastai provides. My thought is to add a ‘regress=True’ parameter to get_collab_learner that selects which of the two models (the embed-dot or the embed-net classifier) get_collab_learner fetches, along with perhaps some parameters on the structure of the net for the classifier case.
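To make the proposal concrete, here is a rough sketch of such an embed-net classifier in plain PyTorch. It is not the course's or the library's implementation: the class name `EmbedNetClassifier`, its parameter names, the layer sizes, and the `forward(users, items)` signature are all assumptions for illustration, and wiring it into get_collab_learner is not shown.

```python
import torch
import torch.nn as nn

class EmbedNetClassifier(nn.Module):
    """Hypothetical embed-net: concatenate user/item embeddings, then an MLP to class logits."""

    def __init__(self, n_users, n_items, n_factors=50, n_hidden=100,
                 n_classes=5, p_drop=0.1):
        super().__init__()
        self.u = nn.Embedding(n_users, n_factors)  # user embeddings
        self.i = nn.Embedding(n_items, n_factors)  # item embeddings
        self.layers = nn.Sequential(
            nn.Linear(2 * n_factors, n_hidden),    # input is the concatenated pair
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(n_hidden, n_classes),        # raw logits, one per class
        )

    def forward(self, users, items):
        # Concatenate instead of taking the dot product.
        x = torch.cat([self.u(users), self.i(items)], dim=1)
        return self.layers(x)

# Tiny smoke test with made-up sizes and indices.
model = EmbedNetClassifier(n_users=10, n_items=20, n_factors=8,
                           n_hidden=16, n_classes=3)
users = torch.tensor([0, 1, 2, 3])
items = torch.tensor([5, 6, 7, 8])
targets = torch.tensor([0, 2, 1, 1])

logits = model(users, items)                    # shape (batch_size, n_classes)
loss = nn.CrossEntropyLoss()(logits, targets)   # expects raw logits, applies log-softmax itself
```

Because the final layer sees a vector rather than a single scalar, the predicted class is no longer tied to the sign of one number, which avoids the two-class limitation described earlier.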

I would appreciate any thoughts on the matter, as well as some advice on how best to proceed. I am new to the community (I searched this page for collab filter questions and did not see anything) and am not sure what the protocol is for implementing an idea and sharing it. I look forward to hearing back, thanks!
