Extracting categorical embeddings from a tabular learner

Hi there

I am looking to extract the embeddings for the categorical variables in a tabular learner so I can transfer them to XGBoost.

It’s easy enough to get the embedding matrices with something like this:
embedding_matrices = [embed.weight.detach().cpu().numpy() for embed in learn.model.embeds]  # one (n_levels, emb_size) array per categorical column

and easy enough to get the classes with something like this:
category_names = list(learn.dls.classes.keys())  # classes is a dict, so iterate its keys rather than unpacking (k, v) pairs
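Pairing each categorical column with its embedding matrix, and checking that the number of levels equals the number of rows, is a cheap first sanity check. Below is a minimal sketch in plain Python. It assumes `classes` behaves like `learn.dls.classes` (an ordered mapping of column name to its list of levels) and that `matrices` holds one matrix per column in the same order, for example the weights converted with `.detach().cpu().numpy()`. The helper `pair_embeddings` is hypothetical, and the core assumption, that row `i` belongs to `levels[i]`, is exactly what the question asks about. A matching size does not prove the order is right, but a mismatch proves it is wrong.

```python
def pair_embeddings(classes, matrices):
    """Map each categorical column to {level: embedding row}.

    Assumption (not verified here): row i of each matrix belongs to levels[i],
    and `classes` and `matrices` list the columns in the same order.
    """
    if len(classes) != len(matrices):
        raise ValueError(f"{len(classes)} columns but {len(matrices)} matrices")
    paired = {}
    for (col, levels), matrix in zip(classes.items(), matrices):
        # A size mismatch means levels and rows cannot line up one-to-one.
        if len(levels) != len(matrix):
            raise ValueError(f"{col!r}: {len(levels)} levels vs {len(matrix)} rows")
        paired[col] = {level: list(row) for level, row in zip(levels, matrix)}
    return paired


# Toy data standing in for learn.dls.classes and the embedding weights.
classes = {"color": ["#na#", "blue", "red"], "size": ["#na#", "L", "S"]}
matrices = [
    [[0.0, 0.0], [0.1, 0.2], [0.3, 0.4]],
    [[0.0], [1.5], [-1.5]],
]
lookup = pair_embeddings(classes, matrices)
print(lookup["color"]["red"])  # [0.3, 0.4]
```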

The order of the category names matches the order of the embedding matrices. But how can I be sure that, within each categorical variable, the order of the levels (classes) is the same as the order of the rows in the corresponding embedding matrix? I tried reading the source code but wasn't completely sure.
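If that ordering does hold (row `i` of each matrix belongs to the `i`-th level of that column's vocab), turning records into numeric features for XGBoost is a lookup plus a concatenation. The sketch below is plain Python under that assumption. It also assumes unseen values fall back to row 0, where fastai appears to place the `#na#` level; treat both as assumptions to confirm against your own learner. `embed_records` and the toy data are hypothetical.

```python
def embed_records(records, classes, matrices, na_index=0):
    """Expand each categorical value into its embedding row.

    Assumptions: row i of each matrix belongs to the i-th level in
    classes[col]; values missing from the vocab use row `na_index`.
    Fields not listed in `classes` are treated as continuous and appended.
    """
    # Build value -> row-index maps once, in the (assumed) vocab order.
    index_maps = [{lvl: i for i, lvl in enumerate(levels)} for levels in classes.values()]
    rows = []
    for rec in records:
        feats = []
        for col, o2i, matrix in zip(classes, index_maps, matrices):
            idx = o2i.get(rec[col], na_index)
            feats.extend(float(x) for x in matrix[idx])
        # Continuous fields keep the order in which they appear in the record.
        feats.extend(float(v) for k, v in rec.items() if k not in classes)
        rows.append(feats)
    return rows


classes = {"color": ["#na#", "blue", "red"]}
matrices = [[[0.0, 0.0], [0.1, 0.2], [0.3, 0.4]]]
records = [{"color": "red", "age": 30}, {"color": "green", "age": 41}]
X = embed_records(records, classes, matrices)
print(X)  # [[0.3, 0.4, 30.0], [0.0, 0.0, 41.0]]
```

The resulting rows can be converted with `numpy.asarray(X)` before being passed to XGBoost. Building the index maps once keeps each lookup constant-time instead of scanning the level list for every record.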

Any thoughts/comments/tips would be greatly appreciated.

Thanks in advance

Jeff
