vladgets
(Vlad Getselevich)
April 2, 2019, 3:04am
How is the embedding size for categorical variables chosen for tabular-data NN inputs?
knesgood
(Kyle Nesgood)
April 2, 2019, 1:51pm
So it looks like it has changed since I last looked. Previously, it was:
def emb_sz_rule(n_cat:int)->int: return min(50, (n_cat//2)+1)
Now it looks like it’s the following:
def emb_sz_rule(n_cat:int)->int: return min(600, round(1.6 * n_cat**0.56))
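To see how the two rules differ in practice, here is a small standalone sketch that evaluates both formulas (copied from the two lines above) for a few cardinalities. The names `old_emb_sz_rule` and `new_emb_sz_rule` are just labels for this comparison, not fastai functions.

```python
def old_emb_sz_rule(n_cat: int) -> int:
    # Older rule: roughly half the number of categories, capped at 50
    return min(50, (n_cat // 2) + 1)

def new_emb_sz_rule(n_cat: int) -> int:
    # Newer rule: a power law in the number of categories, capped at 600
    return min(600, round(1.6 * n_cat ** 0.56))

if __name__ == "__main__":
    for n in (2, 10, 100, 1000, 10_000, 100_000):
        print(f"{n:>7} categories -> old: {old_emb_sz_rule(n):>3}, new: {new_emb_sz_rule(n):>3}")
```

For example, a 100-level variable gets 50 dimensions under the old rule but 21 under the new one, while a 1,000-level variable gets 50 versus 77: the new rule is smaller for mid-sized cardinalities and keeps growing past the old cap of 50.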
I’m not sure why it changed, but you can dig into the source code here:
from ..basic_data import *
from ..data_block import *
from ..basic_train import *
from .models import *
from pandas.api.types import is_numeric_dtype, is_categorical_dtype
__all__ = ['TabularDataBunch', 'TabularLine', 'TabularList', 'TabularProcessor', 'tabular_learner']
OptTabTfms = Optional[Collection[TabularProc]]
#def emb_sz_rule(n_cat:int)->int: return min(50, (n_cat//2)+1)
def emb_sz_rule(n_cat:int)->int: return min(600, round(1.6 * n_cat**0.56))
def def_emb_sz(classes, n, sz_dict=None):
    "Pick an embedding size for `n` depending on `classes` if not given in `sz_dict`."
    sz_dict = ifnone(sz_dict, {})
    n_cat = len(classes[n])
    sz = sz_dict.get(n, int(emb_sz_rule(n_cat))) # rule of thumb
    return n_cat,sz

class TabularLine(ItemBase):
    ...  # remainder of the file not shown
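For anyone who wants to experiment outside fastai, here is a minimal plain-Python sketch of what `def_emb_sz` above does. It assumes `classes` maps a column name to its list of category values (as the signature suggests) and replaces fastai's `ifnone` helper with an explicit `None` check; the column names in the demo are hypothetical.

```python
from typing import Dict, Optional, Sequence, Tuple

def emb_sz_rule(n_cat: int) -> int:
    # Same rule of thumb as in the excerpt above
    return min(600, round(1.6 * n_cat ** 0.56))

def def_emb_sz(classes: Dict[str, Sequence], n: str,
               sz_dict: Optional[Dict[str, int]] = None) -> Tuple[int, int]:
    # Plain-Python stand-in for fastai's `ifnone(sz_dict, {})`
    sz_dict = sz_dict if sz_dict is not None else {}
    n_cat = len(classes[n])
    # A size given explicitly in sz_dict wins; otherwise fall back to the rule
    sz = sz_dict.get(n, int(emb_sz_rule(n_cat)))
    return n_cat, sz

if __name__ == "__main__":
    # Hypothetical columns: 7 weekdays and 1,000 store IDs
    classes = {"day_of_week": list(range(7)), "store_id": list(range(1000))}
    print([def_emb_sz(classes, col) for col in classes])
    print(def_emb_sz(classes, "store_id", sz_dict={"store_id": 32}))
```

The returned pair is (number of categories, embedding width), and an entry in `sz_dict` overrides the rule of thumb for that one column only.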
zipp
(zipp)
May 24, 2019, 3:02pm
I’d also be interested in knowing where this number comes from.
@jeremy or @sgugger maybe?
ste
(Stefano Giomo)
May 24, 2019, 3:32pm
Empirical testing (grid search): Jeremy mentions this in a couple of lessons.
For the 2018 course, take a look at @hiromi’s post on that:
The rule of thumb for determining the embedding size is the cardinality size divided by 2, but no bigger than 50.
https://medium.com/@hiromi_suenaga/deep-learning-2-part-1-lesson-4-2048a26d58aa
zipp
(zipp)
May 24, 2019, 3:39pm
Thanks @ste
Yeah, I remember that, but the 1.6 * n_cat**0.56 formula seems so specific that I wonder whether there is a research paper behind it.
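The thread doesn't say where 1.6 and 0.56 come from, but it is easy to work out what they imply. Inverting size = 1.6 · n^0.56 gives n = (size / 1.6)^(1/0.56). The sketch below uses that to estimate where the new rule overtakes the old cap of 50 and where it saturates at 600; `cardinality_for_size` is a hypothetical helper, and the thresholds ignore the final `round()`.

```python
def new_emb_sz_rule(n_cat: int) -> int:
    # Current fastai rule of thumb, as quoted earlier in the thread
    return min(600, round(1.6 * n_cat ** 0.56))

def cardinality_for_size(size: float, a: float = 1.6, b: float = 0.56) -> float:
    # Invert size = a * n**b  =>  n = (size / a) ** (1 / b)
    return (size / a) ** (1 / b)

if __name__ == "__main__":
    print(f"exceeds the old cap of 50 above ~{cardinality_for_size(50):,.0f} categories")
    print(f"reaches the 600 cap at ~{cardinality_for_size(600):,.0f} categories")
```

This works out to roughly 470 categories for the first threshold and roughly 39,500 for the second, so the 600 cap only matters for very high-cardinality columns.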
ste
(Stefano Giomo)
May 24, 2019, 4:01pm
I don’t know of a paper on that.
I usually start with the “Jeremy defaults” (or values from previous or similar projects) and adjust them, increasing or decreasing the size according to validation accuracy.
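The approach described above (start from the default, then nudge sizes up or down based on validation accuracy) can be sketched as a simple greedy search, one column at a time, over a few scale factors. This is not an exhaustive grid search, and it is not fastai code: `tune_emb_szs` and the `validation_score` callback are hypothetical names, and the callback is assumed to train a model with the given sizes and return validation accuracy.

```python
from typing import Callable, Dict, Sequence

def emb_sz_rule(n_cat: int) -> int:
    # Starting point: the rule of thumb from the fastai source quoted above
    return min(600, round(1.6 * n_cat ** 0.56))

def tune_emb_szs(cardinalities: Dict[str, int],
                 validation_score: Callable[[Dict[str, int]], float],
                 scales: Sequence[float] = (0.5, 1.0, 2.0)) -> Dict[str, int]:
    """Adjust one column at a time around the default sizes.

    `validation_score` is a hypothetical callback: it should train a model with
    the given {column: embedding size} mapping and return validation accuracy
    (higher is better).
    """
    best = {col: emb_sz_rule(n) for col, n in cardinalities.items()}
    best_score = validation_score(best)
    for col, n in cardinalities.items():
        for scale in scales:
            candidate = dict(best)
            # Shrink or grow this column's size relative to its default
            candidate[col] = max(1, round(emb_sz_rule(n) * scale))
            score = validation_score(candidate)
            if score > best_score:
                best, best_score = candidate, score
    return best

def toy_score(szs: Dict[str, int]) -> float:
    # Toy stand-in for "train and measure validation accuracy": it simply
    # prefers a store_id embedding of about 40 dimensions
    return -abs(szs["store_id"] - 40)

if __name__ == "__main__":
    print(tune_emb_szs({"store_id": 1000}, toy_score))
```

Each call to `validation_score` means a full training run, so this costs 1 + (number of columns × number of scales) runs; in practice it is worth tuning only the few high-cardinality columns.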