Extracting embeddings from collab_learner with use_nn=True?

Hi all. This is my first post - apologies if it’s in the wrong place.

I have used collab_learner to find embeddings for a collaborative filtering task. I want to try out the use_nn option and have trained a learner like this:

learn = collab_learner(data, use_nn=True, emb_szs={'user': 50, 'item':50}, layers=[256, 128], y_range=(-1, 1))

I cannot work out how to extract the embeddings. The weight method that works for the dot-product version of this learner fails:

item_w = learn.weight(allitems, is_item=True)

where allitems is an array of the unique item ids going into the learner. It raises:

AttributeError: 'EmbeddingNN' object has no attribute 'i_weight'

Looking at the source code, it seems the weight method is only set up for the use_nn=False case. But the collab_learner itself does not expose the embeds attribute that the underlying TabularModel possesses.

Any ideas?!?

P.S. The docstring for the weight method of collab_learner is incorrect; it's a copy-paste of the bias method's.


I have the same problem. Did you find a way to solve it?

Hello, yep.

The collab_learner is just a special case of tabular_learner, so it inherits the general learner methods. Similarly, the collab model is just a special case of the tabular model.

To get the weights you do:
item_w = learn.model.embeds[1].weight[1:]
user_w = learn.model.embeds[0].weight[1:]

This gives a PyTorch tensor rather than a numpy array (call .detach().cpu().numpy() on it if you need one). To be sure which item and user ids correspond to which rows, get the item and user class lists from here:

allitems = learn.data.train_ds.x.classes['item'][1:]
allusers = learn.data.train_ds.x.classes['user'][1:]

Thus the 0th row of item_w is the embedding for the item id in the 0th element of allitems, and so on.

The [1:] is because the 0th row is for '#na#' (i.e. for item/user ids which aren't in the training set).
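To illustrate the id-to-row alignment, here is a standalone sketch with toy numpy stand-ins for learn.model.embeds[1].weight and learn.data.train_ds.x.classes['item'] (the values are made up, and no fastai import is needed):

```python
import numpy as np

# Hypothetical stand-in for the item embedding weight matrix:
# 4 rows ('#na#' plus 3 items), embedding size 2.
item_weight = np.arange(8.0).reshape(4, 2)
# Hypothetical stand-in for the item class list; row 0 is '#na#'.
item_classes = ['#na#', 'item_a', 'item_b', 'item_c']

# Drop row 0 ('#na#', used for ids unseen in training) from both.
item_w = item_weight[1:]
allitems = item_classes[1:]

# Row i of item_w is the embedding for allitems[i]; zip them into a lookup.
item_emb = dict(zip(allitems, item_w))
print(item_emb['item_b'])  # -> [4. 5.]
```

With the real learner you would slice the tensor the same way and, if needed, convert it with .detach().cpu().numpy() before building the dict.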

Happy to be corrected if there’s a better way!


Hi Ponty! Thanks for your answer.

Do you know what happens to the users when we only specify an embedding size for the items in EmbeddingNN? Will the users be implicitly embedded with the same number of latent factors, or will it be an aggregate of the user's purchases?

Is there a similar way to extract the biases and their key?

The train dataset will be missing users that appear only in the validation set; is there a way to reconcile that if you're trying to get every user/rating combination possible?

Thank you!