Size of embedding for categorical variables

How is the embedding size for categorical variables chosen for tabular-data NN inputs?


So it looks like it has changed since I last looked. Previously, it was:

def emb_sz_rule(n_cat:int)->int: return min(50, (n_cat//2)+1)

Now it looks like it’s the following:

def emb_sz_rule(n_cat:int)->int: return min(600, round(1.6 * n_cat**0.56))
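To see how the two rules differ in practice, here is a small side-by-side sketch. The two functions are just the one-liners quoted above, renamed `emb_sz_rule_old` and `emb_sz_rule_new` (my names, for illustration only) so both can live in one file.

```python
def emb_sz_rule_old(n_cat: int) -> int:
    # Earlier rule: half the cardinality (+1), capped at 50
    return min(50, (n_cat // 2) + 1)


def emb_sz_rule_new(n_cat: int) -> int:
    # Current rule: power law in the cardinality, capped at 600
    return min(600, round(1.6 * n_cat ** 0.56))


# Print both sizes for a few cardinalities spanning several orders of magnitude
for n_cat in (2, 10, 100, 1_000, 100_000):
    print(f"{n_cat:>7}  old={emb_sz_rule_old(n_cat):>3}  new={emb_sz_rule_new(n_cat):>3}")
```

Both rules agree for very small cardinalities (both give 6 at 10 categories). For mid-range cardinalities the new rule is noticeably smaller (21 vs 50 at 100 categories), because the old rule already hits its cap of 50 at just under 100 categories. The new rule keeps growing roughly as n_cat**0.56, overtakes 50 only at a few hundred categories, and reaches its own cap of 600 at roughly 40,000 categories.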

I’m not sure why it was changed, but you can dig into the source code here:


I’d also be interested in knowing where this number comes from.
@jeremy or @sgugger maybe?

Empirical testing (grid search): Jeremy mentions this in a couple of lessons.

For the 2018 course, take a look at @hiromi’s post on that:

The rule of thumb for determining the embedding size is the cardinality size divided by 2, but no bigger than 50.
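A minimal sketch of applying a rule of thumb like this to every categorical column of a table. Assumptions: the helper names (`rule_2018`, `embedding_sizes`) are hypothetical; one extra index is reserved for missing/unseen values, similar in spirit to fastai’s `#na#` category; and per-column overrides let you replace the rule for columns you know better. Each `(n_rows, dim)` pair is what you would pass to an embedding layer such as PyTorch’s `nn.Embedding(num_embeddings, embedding_dim)`.

```python
from typing import Callable, Dict, Hashable, Iterable, Optional, Tuple


def rule_2018(n_cat: int) -> int:
    # Rule of thumb quoted above: cardinality // 2 (+1), but no bigger than 50
    return min(50, (n_cat // 2) + 1)


def embedding_sizes(
    columns: Dict[str, Iterable[Hashable]],
    rule: Callable[[int], int] = rule_2018,
    overrides: Optional[Dict[str, int]] = None,
    reserve_unknown: bool = True,
) -> Dict[str, Tuple[int, int]]:
    """Return {column: (rows_in_embedding_table, embedding_dim)} (hypothetical helper)."""
    overrides = overrides or {}
    sizes = {}
    for name, values in columns.items():
        # Cardinality = number of distinct levels seen in the training data
        n_cat = len(set(values))
        # Assumption: one extra row for missing/unseen levels
        if reserve_unknown:
            n_cat += 1
        # A manual override wins over the rule of thumb
        dim = overrides.get(name, rule(n_cat))
        sizes[name] = (n_cat, dim)
    return sizes


# Usage: two toy columns, then the same columns with a manual override for "store"
cols = {"store": ["a", "b", "c", "a"], "day_of_week": list(range(7)) * 3}
print(embedding_sizes(cols))
print(embedding_sizes(cols, overrides={"store": 10}))
```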



Thanks @ste

Yeah, I remember that, but the 1.6 * n_cat**0.56 formula seems so specific that I wonder whether there is a research paper behind it.
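Nobody in this thread knows where the constants come from, so this is only a guess: constants like 1.6 and 0.56 are what you would get by fitting a power law, size ≈ a * n_cat**b, to (cardinality, best embedding size) pairs found by grid search, using least squares on a log-log scale. The sketch below shows that fitting technique. It has no real grid-search data; the points are generated from the formula itself (unrounded, uncapped) purely to check that the fit recovers a = 1.6 and b = 0.56. `fit_power_law` is a hypothetical helper.

```python
import math
from typing import Sequence, Tuple


def fit_power_law(ns: Sequence[float], sizes: Sequence[float]) -> Tuple[float, float]:
    """Least-squares fit of size ≈ a * n**b on a log-log scale; returns (a, b)."""
    xs = [math.log(n) for n in ns]
    ys = [math.log(s) for s in sizes]
    k = len(xs)
    mx = sum(xs) / k
    my = sum(ys) / k
    # Slope of the regression line log(size) = log(a) + b * log(n)
    b = sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / sum((x - mx) ** 2 for x in xs)
    # Intercept gives log(a)
    a = math.exp(my - b * mx)
    return a, b


# Sanity check on synthetic points generated from the formula itself
ns = [2, 10, 100, 1_000, 10_000]
sizes = [1.6 * n ** 0.56 for n in ns]
a, b = fit_power_law(ns, sizes)
print(f"a={a:.3f}, b={b:.3f}")
```

With real search results the points would be noisy, so the fitted constants would only be approximate, and a cap such as 600 would have to be chosen separately.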


I don’t know of a paper on that.
I usually start with the “Jeremy Defaults” (or values from previous/similar projects) and adjust them, increasing or decreasing the size according to validation accuracy.
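The “start from a default, then adjust by validation accuracy” approach can be sketched as a small search around the rule-of-thumb value. `tune_embedding_size` and the `validate` callback are hypothetical: in a real project `validate` would train the model with the given embedding size and return validation accuracy. The scaling factors are an arbitrary choice, and `toy_validate` is a made-up stand-in that peaks at 12, used only so the sketch runs; it is not a real result.

```python
from typing import Callable, Dict, Sequence


def emb_sz_rule(n_cat: int) -> int:
    # Current rule quoted earlier in the thread
    return min(600, round(1.6 * n_cat ** 0.56))


def tune_embedding_size(
    n_cat: int,
    default_rule: Callable[[int], int],
    validate: Callable[[int], float],
    factors: Sequence[float] = (0.5, 0.75, 1.0, 1.5, 2.0),
) -> Dict[str, object]:
    """Try sizes around the rule-of-thumb default; keep the best on validation."""
    default = default_rule(n_cat)
    # Scale the default up and down; dedupe and keep sizes >= 1
    candidates = sorted({max(1, round(default * f)) for f in factors})
    # validate(dim) should return validation accuracy (higher is better)
    scores = {dim: validate(dim) for dim in candidates}
    best = max(scores, key=scores.get)
    return {"default": default, "best": best, "scores": scores}


def toy_validate(dim: int) -> float:
    # Made-up objective for illustration only: pretends 12 is the ideal size
    return -abs(dim - 12)


result = tune_embedding_size(100, emb_sz_rule, toy_validate)
print(result["default"], result["best"], result["scores"])
```

In practice each `validate` call is a full training run, so it is usually worth tuning only the few high-cardinality columns where the embedding size actually matters.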
