Understanding Tabular architecture

Hi, I’m going through part 1 and I am struggling to understand entity embeddings.

For a simple fully connected network on a 28x28 image (e.g. MNIST), the first layer has 784 inputs, one for each pixel (I’m skipping the bias for the sake of the discussion).
So the weight matrix would be n by 784, where n is the hidden layer size.
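A minimal framework-free sketch of that first layer, assuming an arbitrary hidden size of n = 8 and a random stand-in image in place of real pixel data. The weight matrix has n rows and 784 columns, so multiplying it by the 784 pixel values gives n hidden activations.

```python
import random

random.seed(0)

N_PIXELS = 28 * 28   # 784 inputs, one per pixel
N_HIDDEN = 8         # n, the hidden layer size (arbitrary choice for illustration)

# Weight matrix of shape n x 784 (bias skipped, as in the question)
weights = [[random.gauss(0, 0.01) for _ in range(N_PIXELS)] for _ in range(N_HIDDEN)]

# A stand-in "image": 784 pixel values flattened into one list
image = [random.random() for _ in range(N_PIXELS)]

def linear(w, x):
    """Matrix-vector product: one output per row of w."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

hidden = linear(weights, image)
print(len(weights), len(weights[0]), len(hidden))  # 8 784 8
```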

For collaborative filtering, say we have 5 users and 15 movies. Each user and each movie gets an ‘embedding’ of, say, 10 numbers, and the dot product between a user’s embedding and a movie’s embedding gives the predicted rating for that user/movie pair. This is also clear to me.
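A minimal sketch of that dot-product model, assuming randomly initialised embeddings and no bias terms or output range (real collaborative filtering models usually add those). Each of the 5 users and 15 movies owns a vector of 10 numbers, and a rating is just the dot product of one user vector with one movie vector.

```python
import random

random.seed(0)

N_USERS, N_MOVIES, EMB_DIM = 5, 15, 10

# One learnable vector of EMB_DIM numbers per user and per movie
user_emb = [[random.gauss(0, 0.1) for _ in range(EMB_DIM)] for _ in range(N_USERS)]
movie_emb = [[random.gauss(0, 0.1) for _ in range(EMB_DIM)] for _ in range(N_MOVIES)]

def predict(user_id, movie_id):
    """Predicted rating = dot product of the user's and the movie's vectors."""
    return sum(u * m for u, m in zip(user_emb[user_id], movie_emb[movie_id]))

# A 5 x 15 grid of predictions, one per user/movie pair
ratings = [[predict(u, m) for m in range(N_MOVIES)] for u in range(N_USERS)]
print(len(ratings), len(ratings[0]))  # 5 15
```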

But now suppose we combine the two in a tabular architecture’s input layer, with 3 categorical variables (embedding size 10 each) and 70 continuous variables.

What does the weight matrix look like now, to get to the first hidden layer? I’m not able to visualise how the ‘numbers’ get to the hidden layer. Could someone point me to a resource that could help?

The tabular architecture uses a simple fully connected neural network (ANN). In your example there are 70 continuous variables, so imagine a neural network input layer with 70 nodes for these.
Since there are also 3 categorical variables with an embedding size of 10 each, you have 3 lists of 10 numbers; concatenate them into a single list of 30 numbers, each of which occupies an input node.
That gives 70 nodes for the continuous variables plus 30 nodes for the 3 categorical variables, so in total you have a 100-node input layer, and the weight matrix to the first hidden layer is n by 100.
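A minimal framework-free sketch of that input layer, assuming hypothetical cardinalities of 4, 7 and 12 levels for the three categorical variables and an arbitrary hidden size of n = 8. Each categorical code selects one row of its embedding matrix, the three 10-number rows are concatenated with the 70 continuous values, and the resulting 100 numbers are multiplied by an n by 100 weight matrix.

```python
import random

random.seed(0)

N_CONT = 70                  # continuous variables
EMB_DIM = 10                 # embedding size for each categorical variable
CARDINALITIES = [4, 7, 12]   # hypothetical number of levels per categorical variable
N_HIDDEN = 8                 # n, hidden layer size (arbitrary)

def rand_matrix(rows, cols):
    return [[random.gauss(0, 0.1) for _ in range(cols)] for _ in range(rows)]

# One embedding matrix per categorical variable, shape (levels x 10)
embeddings = [rand_matrix(card, EMB_DIM) for card in CARDINALITIES]

# First hidden layer weights: n x (3*10 + 70) = n x 100
W1 = rand_matrix(N_HIDDEN, len(CARDINALITIES) * EMB_DIM + N_CONT)

def first_layer(cat_codes, cont_values):
    # Look up one row (10 numbers) per categorical variable
    emb_out = [embeddings[i][code] for i, code in enumerate(cat_codes)]
    # Concatenate the 3 x 10 embedding outputs with the 70 continuous values
    x = [v for vec in emb_out for v in vec] + list(cont_values)
    # Ordinary matrix-vector product followed by ReLU
    h = [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in W1]
    return x, h

x, h = first_layer([1, 5, 11], [random.random() for _ in range(N_CONT)])
print(len(x), len(W1), len(W1[0]), len(h))  # 100 8 100 8
```

Real implementations such as fastai’s tabular model also add extras like dropout and batch normalization of the continuous variables; they are left out here so the shapes stay easy to follow.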


Thanks @vijayabhaskar for the swift response.

So, as I understand it, the input layer will have 100 nodes.

But Jeremy emphasises the difference between ‘activations’ and ‘parameters’, and that only ‘parameters’ get updated with each backward pass. So how do the embeddings get updated if they’re part of the input layer?

I suggest you read this blog post, which has a diagram that will help you visualize it: https://medium.com/@george.drakos62/decoded-entity-embeddings-of-categorical-variables-in-neural-networks-1d2468311635 (open it in incognito mode if you can’t view the article). Yes, inputs are not trained, but these 30 numbers are not inputs; they are the outputs of the embedding layers. The embedding matrices themselves are parameters, so they get updated in the backward pass like any other weights.
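A minimal sketch of why the embeddings still get trained, assuming a single categorical variable with 4 hypothetical levels, an embedding size of 3, a fixed stand-in weight vector for the rest of the network, squared-error loss and plain SGD. The embedding matrix is the parameter; looking up a row gives the same numbers as multiplying a one-hot input by that matrix, and those numbers are activations. The backward pass then sends a gradient only to the row that was used.

```python
import random

random.seed(0)

N_LEVELS, EMB_DIM = 4, 3   # hypothetical sizes, kept tiny for readability
LR = 0.1

# The embedding matrix is a *parameter*: one trainable row per category level
emb = [[random.gauss(0, 1) for _ in range(EMB_DIM)] for _ in range(N_LEVELS)]
# A fixed downstream weight vector, standing in for the rest of the network
w = [0.5, -0.2, 0.3]

def lookup(code):
    """Embedding lookup: select row `code` of the embedding matrix."""
    return emb[code]

def lookup_one_hot(code):
    """Same result written as (one-hot vector) x (embedding matrix)."""
    one_hot = [1.0 if i == code else 0.0 for i in range(N_LEVELS)]
    return [sum(one_hot[i] * emb[i][j] for i in range(N_LEVELS)) for j in range(EMB_DIM)]

def forward(code):
    act = lookup(code)                          # activations: the embedding output
    return sum(a * wj for a, wj in zip(act, w))  # prediction from the "rest of the net"

def sgd_step(code, target):
    pred = forward(code)
    dpred = 2 * (pred - target)  # d(loss)/d(pred) for squared error
    # d(pred)/d(emb[code][j]) = w[j]; every other row gets zero gradient
    for j in range(EMB_DIM):
        emb[code][j] -= LR * dpred * w[j]
    return (pred - target) ** 2

before = [row[:] for row in emb]
loss0 = sgd_step(code=2, target=1.0)
loss1 = (forward(2) - 1.0) ** 2
changed = [i for i in range(N_LEVELS) if emb[i] != before[i]]
print(changed)        # [2]: only the looked-up row moved
print(loss1 < loss0)  # True: the update reduced the loss
```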
