Unable to compute intermediate layer's output to get embeddings

Hi all, I was going through 09_tabular.ipynb. In it, I want to use neural network embeddings in a Random Forest. I tried various ways but was unable to get the embeddings out of the NN.

The NN has layers [500, 250, 20], and I was trying to compute the 20-dimensional embeddings.

I tried getting embeddings from a simple NN with only 1 continuous and 0 categorical variables by:
embs = learn.model.layers[0:2](dls.train.one_batch()[1])
(0:2 because I want the output of the 2nd layer)

The issue with computing embeddings is that there are both continuous and categorical variables, so dls.train.one_batch() gives 3 tensors rather than 2. I tried concatenating them, but that didn't work either. I tried a lot of ways but got stuck anyway :wink:
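To be concrete, one_batch() unpacks like this (the names are just for illustration):

cat, cont, y = dls.train.one_batch()
cat.shape, cont.shape, y.shape  # three tensors, not two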

I'll be very thankful if someone could guide me, or better yet share a code snippet.

Thanks in advance.

Hi! I also tried this but ran into problems as well.

I did get the embeddings though, like this:
embeds = list(learn.model.embeds.parameters())
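If it helps, each entry there is the weight matrix of one embedding layer. A quick way to inspect them (a rough sketch):

# Each embedding weight matrix has shape [n_categories, embedding_dim]
for i, emb in enumerate(learn.model.embeds):
    print(i, emb.weight.shape)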

Maybe this helps? Let me know if you have any success with this!

We should be doing this with hooks instead, similar to GradCAM. Let me play with this a bit today.

We have our answer! :smiley:

Here's what I did. Presume this is your fastai tabular model for the Adult Sample dataset, and I want to grab the intermediate layer output of size 100 when our FC layers are [200, 100]. I do the following:

from fastai.callback.hook import hook_outputs

with hook_outputs(learn.model.layers[:-1]) as hooks:
    _ = learn.model(cat, cont)

I can then access these intermediate layer activations via hooks.stored[-1], which will be of shape [64, 100] in this particular example :slight_smile:

For another example, if we want to get the raw embeddings to use with a Random Forest, we do:

# Hook each embedding layer, then concatenate their outputs column-wise
with hook_outputs(learn.model.embeds) as hooks:
    _ = learn.model(cat, cont)
embeds = torch.cat(list(hooks.stored), 1)
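From there, handing them to a Random Forest could look something like this (a minimal sketch, assuming scikit-learn and that you also grabbed y from the same batch; in practice you'd collect embeddings for the whole training set first):

from sklearn.ensemble import RandomForestClassifier

# The hooks detach the activations by default, so .numpy() is safe here
rf = RandomForestClassifier(n_estimators=40)
rf.fit(embeds.cpu().numpy(), y.cpu().numpy().ravel())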

Hope this helps @tsm_tau and @johannesstutz


This is so elegant… thanks!

Thanks for this, it worked very well and easily, and I also got to learn about hooks. Thanks again.

Oh yes, so the chapter talks about using 'categorical embeddings' trained from a NN, rather than the raw categorical columns. I was thinking that if, say, the NN has layers [500, 250, 10], then we want to use the final layer's 10-dim output as an extra feature, along with the raw categorical columns. My bad, thanks!

Oh this looks great. Good opportunity to finally learn about hooks. Thanks Zach!

Not sure if that would work or not, but sounds very interesting!

Hi, thank you so much for the solution and your time. It might be a stupid question, but would you please clarify what cat and cont exactly are in _ = learn.model(cat, cont)?

As I understand it, they're x_cat and x_cont in forward(self, x_cat, x_cont=None).

I put learn.dls.train.xs[cat_nn], learn.dls.train.xs[cont_nn] and got errors. Basically, whatever I put there, I got errors. Maybe I'm missing something conceptually.
I'm working on 09_tabular.

Thank you in advance.

Sure!

They're just a batch of data from our DataLoader. Fully fleshed out, this pipeline would look something like:

# Grab one batch of (categorical, continuous, target) tensors
cat, cont, y = dl.one_batch()
with hook_outputs(learn.model.embeds) as hooks:
    _ = learn.model(cat, cont)
embeds = torch.cat(list(hooks.stored), 1)

We need to get it from the DataLoader as fastai’s model is expecting tensors.

What you did there was pull from the Dataset, or TabularPandas, which aren't tensors. That's just a Pandas object with some more fun fastai things on top.
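If you really did want to go straight from the TabularPandas, you'd have to build the tensors yourself, something like this (a rough sketch; cat_cols and cont_cols are placeholders for your column name lists):

import torch

# .xs is a plain processed DataFrame, so convert its values manually
cat = torch.tensor(learn.dls.train.xs[cat_cols].values, dtype=torch.long)
cont = torch.tensor(learn.dls.train.xs[cont_cols].values, dtype=torch.float)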

What is cat_nn and cont_nn in your case?

Thank you so much, everything is clear now. My bad, I passed the TabularPandas to it. A strange mistake; in a bunch of code I was sure those were tensors.
I have been trying to get one big embedding matrix (412_698 rows in my case, i.e. all the training data) for all categorical and continuous columns and use it as the training dataset for the Random Forest model.

It remains to figure out how to get everything at once (not just the cat and cont tensors from one_batch(), which are limited to the batch size, but for all the data at once).

Thank you for your time.

Hi, is there a method to get the cat and cont tensors from an entire DataLoader?
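I was imagining something like looping over the DataLoader and concatenating the batches (a rough, untested sketch; it assumes the usual (cat, cont, y) batch layout and the hooked model from above):

import torch

# Note: the training DataLoader shuffles and can drop the last partial
# batch, so rows won't line up with the original DataFrame order
cats, conts, embs = [], [], []
for cat, cont, _ in learn.dls.train:
    with hook_outputs(learn.model.embeds) as hooks:
        _ = learn.model(cat, cont)
    cats.append(cat)
    conts.append(cont)
    embs.append(torch.cat(list(hooks.stored), 1))
cat_all, cont_all, embeds_all = torch.cat(cats), torch.cat(conts), torch.cat(embs)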