Tabular learner embedding input size

I am looking at the Embedding layers in the Rossmann notebook. I noticed that the input sizes of the Embeddings are one bigger than I expected. Here are the embeddings and their associated column names:

{('Assortment', Embedding(4, 3)),
 ('CloudCover_na', Embedding(3, 3)),
 ('CompetitionDistance_na', Embedding(3, 3)),
 ('CompetitionMonthsOpen', Embedding(26, 10)),
 ('CompetitionOpenSinceYear', Embedding(24, 9)),
 ('Day', Embedding(32, 11)),
 ('DayOfWeek', Embedding(8, 5)),
 ('Events', Embedding(22, 9)),
 ('Month', Embedding(13, 7)),
 ('Promo2SinceYear', Embedding(9, 5)),
 ('Promo2Weeks', Embedding(27, 10)),
 ('PromoInterval', Embedding(4, 3)),
 ('Promo_bw', Embedding(7, 5)),
 ('Promo_fw', Embedding(7, 5)),
 ('SchoolHoliday_bw', Embedding(9, 5)),
 ('SchoolHoliday_fw', Embedding(9, 5)),
 ('State', Embedding(13, 7)),
 ('StateHoliday', Embedding(3, 3)),
 ('StateHoliday_bw', Embedding(4, 3)),
 ('StateHoliday_fw', Embedding(4, 3)),
 ('Store', Embedding(1116, 81)),
 ('StoreType', Embedding(5, 4)),
 ('Week', Embedding(53, 15)),
 ('Year', Embedding(4, 3))}

You can see that the Week input is 53 (not 52), Month is 13 (not 12), and DayOfWeek is 8 (not 7).

Am I missing something here? Is there a bias term?

Ah, I think I have figured out what’s going on. fastai adds an additional category for missing values (‘#na#’) to each categorical variable. If you look at the classes in the TabularList, data.x.classes:

OrderedDict([('Store',
              array(['#na#', '1', '2', '3', ..., '1112', '1113', '1114', '1115'], dtype='<U21')),
             ('DayOfWeek',
              array(['#na#', '1', '2', '3', '4', '5', '6', '7'], dtype='<U21')),
             ('Year', array(['#na#', '2013', '2014', '2015'], dtype='<U21')),
             ('Month',
              array(['#na#', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12'], dtype='<U21')),
             ('Day',
              array(['#na#', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19',
                     '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31'], dtype='<U21')),
             ('StateHoliday', array(['#na#', False, True], dtype=object)),
             ('CompetitionMonthsOpen',
              array(['#na#', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18',
                     '19', '20', '21', '22', '23', '24'], dtype='<U21')),
             ('Promo2Weeks',
              array(['#na#', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18',
                     '19', '20', '21', '22', '23', '24', '25'], dtype='<U21')),
             ('StoreType', array(['#na#', 'a', 'b', 'c', 'd'], dtype=object)),
             ('Assortment', array(['#na#', 'a', 'b', 'c'], dtype=object)),
             ('PromoInterval',
              array(['#na#', 'Feb,May,Aug,Nov', 'Jan,Apr,Jul,Oct', 'Mar,Jun,Sept,Dec'], dtype=object)),
             ('CompetitionOpenSinceYear',
              array(['#na#', '1900', '1961', '1990', '1994', '1995', '1998', '1999', '2000', '2001', '2002', '2003', '2004', '2005',
                     '2006', '2007', '2008', '2009', '2010', '2011', '2012', '2013', '2014', '2015'], dtype='<U21')),
             ('Promo2SinceYear',
              array(['#na#', '1900', '2009', '2010', '2011', '2012', '2013', '2014', '2015'], dtype='<U21')),
             ('State',
              array(['#na#', 'BE', 'BW', 'BY', 'HB,NI', 'HE', 'HH', 'NW', 'RP', 'SH', 'SN', 'ST', 'TH'], dtype=object)),
             ('Week',
              array(['#na#', '1', '2', '3', ..., '49', '50', '51', '52'], dtype='<U21')),
...
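
A quick way to see it (a minimal sketch, reusing the notebook’s data object and the data.x.classes dict shown above): each Embedding’s input size is just the length of that column’s class list, i.e. the original number of categories plus one extra slot for ‘#na#’.

# each Embedding's input size equals len(classes) for that column:
# the raw category count plus one extra slot for '#na#'
for name, classes in data.x.classes.items():
    print(name, len(classes))
# e.g. DayOfWeek -> 8 (7 weekdays + '#na#'), Month -> 13, Week -> 53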

Great find! I’m trying to learn what I can about the embeddings now, as a few of my professors want to know what is happening there. Do you know what a ‘<U21’ datatype is in Python? I’ve yet to learn about that.

Sure, this is the low-level dtype NumPy uses to store the data.
‘U’ means Unicode, so it’s a string datatype in Python 3.
‘<’ means ‘little endian’: the convention for how the bytes of each element are ordered, which is CPU dependent. Almost all CPUs are little endian these days.
The ‘21’ is the maximum string length: each element is stored as a fixed-width string of up to 21 characters.

Some of the arrays are of type ‘object’ even though they are also just strings. I don’t know why that is.
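
A small standalone illustration (plain NumPy, not from the notebook):

import numpy as np

# '<U21' = little-endian Unicode string, fixed width of up to 21 characters
a = np.array(['1', '2', '1115'])
print(a.dtype)                      # <U4 -- the width grows to fit the longest string
print(np.dtype('<U21').itemsize)    # 84 bytes: 21 chars * 4 bytes per code point

# an 'object' array just holds references to arbitrary Python objects,
# e.g. a mix of strings and booleans like the StateHoliday classes above
b = np.array(['#na#', False, True], dtype=object)
print(b.dtype)                      # object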


This is a nice page on numpy types: https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.dtypes.html

Thanks for the explanation! It clarified a lot for me. Do we know how it determines the second size? I get that the first number is just the number of categories + 1, but how is it figuring out the second one? When I have time I want to explore the documentation to see, but I’m curious whether you had discovered that yet.


Yeah, I had to dig into the source code for that one. My impression is that there isn’t an exact science to determining embedding sizes. In the Rossmann Kaggle paper on entity embeddings, the authors said they experimented until they found values that worked. There are some rules of thumb that people have arrived at through experimentation, based on the number of classes. In the fast.ai source code, in tabular/data.py, the rule used is:

#def emb_sz_rule(n_cat:int)->int: return min(50, (n_cat//2)+1)
def emb_sz_rule(n_cat:int)->int: return min(600, round(1.6 * n_cat**0.56))

You see there are two rules, and one is commented out (an old version?). I don’t know where either of these came from though. :slightly_smiling_face:
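
As a rough sanity check (my own sketch, assuming n_cat is the class count including ‘#na#’), the second rule reproduces the widths in the notebook:

def emb_sz_rule(n_cat: int) -> int: return min(600, round(1.6 * n_cat ** 0.56))

# class counts taken from the Embedding input sizes listed above
for name, n_cat in [('Store', 1116), ('Week', 53), ('Day', 32), ('DayOfWeek', 8)]:
    print(name, emb_sz_rule(n_cat))
# Store 81, Week 15, Day 11, DayOfWeek 5 -- matching Embedding(1116, 81), (53, 15), (32, 11), (8, 5)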


Interesting! I may experiment with this on various datasets and see, but I’m glad we have a general rule we can follow!
