TabNet with fastai v2

There is an interesting paper on using an attention-based network for tabular data.

A kind person created a PyTorch implementation here.

I’ve adapted the code from the repo to fastai v2.0. It seems to work; I hope I didn’t break the code in the process. Feel free to try it and give feedback.

Also, is there a benchmark for tabular data in fastai?


Awesome!!! I was eventually going to do this :wink: I’ll take a look at it :slight_smile:


On a bit of a side note, one thing I’d like to do is take all their baselines and run them through fastai’s tabular model for a fair comparison too. Have you tried this yet for the datasets they looked at? I mention this because ADULTs is a particularly challenging dataset that hasn’t seen much improvement (and may not have much room for it) in recent studies. (The highest I’ve seen was GBT with 89%, I think?)

I haven’t tested it rigorously on datasets yet.


Cool, it’d be nice if we could compile a list to try. I plan on doing TabNet, NODE, and DeepGBM in the study group and that would certainly help give us a true picture :slight_smile:


Mostly what I’ve been looking at is A: accuracy, B: overall training time, and C: model complexity. The nice thing with tabular models is that they’re (usually) fast at training and at inference. Let me know your thoughts :slight_smile:

The paper mentions two advantages of TabNet. One of them is accuracy for hard / special cases. The other is interpretability (I have not adapted the code for interpretability).


Yes, the interpretability is another bit. I’d certainly be interested in seeing this, as it was a bit hard to wrap my head around when I read the paper initially (it’s been a little bit, been focused on vision for the last few months). Could you try to basically sum up what they described? :slight_smile: (I’ll also revisit the paper here once all the vision notebooks are done)

Interested to see where this goes. Hi again, Zach :wink:


He’s alive! Welcome back other Jeremy :wink:


There is now :wink:


@grankin I believe that something may be wrong. Nope, they just train for a very long time.

learning rate of 0.02 (decayed 0.9 every 10k iterations with an exponential decay) for 71k iterations.
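
For reference, the schedule quoted above can be sketched in plain Python. This assumes the usual exponential-decay formula, `base_lr * decay_rate ** (step / decay_steps)`. The function name `decayed_lr` and the `staircase` switch are illustrative only, since the quote does not say whether the decay is applied smoothly or in 10k-step jumps.

```python
def decayed_lr(step, base_lr=0.02, decay_rate=0.9, decay_steps=10_000, staircase=False):
    """Learning rate after `step` iterations under exponential decay.

    staircase=False decays a little on every step;
    staircase=True (an assumption, not stated in the quote) only drops
    the rate once per full block of `decay_steps` iterations.
    """
    exponent = step // decay_steps if staircase else step / decay_steps
    return base_lr * decay_rate ** exponent


# Print the rate at a few points along the 71k-iteration run.
for step in (0, 10_000, 35_000, 71_000):
    print(step, decayed_lr(step), decayed_lr(step, staircase=True))
```

Either way the rate roughly halves over the full run, so most of the 71k iterations are spent at a fairly small step size.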

Sweet holy moley. 71,000 batches?! Let me try again…

By the way, if my math is right, that is the equivalent of roughly 11,600 epochs.

(They have a 4096 batch size, 4096 * 71,000 iterations / 25000 rows)
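
Spelling that arithmetic out as a quick sanity check, using only the numbers from the post (4096 batch size, 71,000 iterations, 25,000 rows). The second estimate is an added assumption: it counts whole batches per epoch, where the last batch of each epoch is only partly full.

```python
import math

batch_size = 4096
iterations = 71_000
rows = 25_000

# Estimate 1: total samples seen divided by dataset size (the formula in the post).
samples_seen = batch_size * iterations
epochs_by_rows = samples_seen / rows

# Estimate 2 (assumption): each epoch takes ceil(rows / batch_size) batches,
# so the final, partly filled batch still counts as one iteration.
batches_per_epoch = math.ceil(rows / batch_size)
epochs_by_batches = iterations / batches_per_epoch

print(f"samples seen:        {samples_seen:,}")
print(f"epochs (by rows):    {epochs_by_rows:,.0f}")
print(f"batches per epoch:   {batches_per_epoch}")
print(f"epochs (by batches): {epochs_by_batches:,.0f}")
```

Both ways of counting land above 10,000 epochs, so the conclusion holds either way.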

Running those now

Surpassed fastai at ~1600 epochs

Just had a quick thought too, let me see…

Well, it won’t let me post so, let me post an exciting update here:

Alright, so halfway through training I got a harebrained idea. This arch doesn’t change much, right? Can I use transfer learning here? I did a quick experiment: 50 epochs with the Poker dataset, then 50 epochs transferred onto Adults. Here is what I found.

To reach above 80% accuracy:

  • Non-Transfer learning: epoch 31
  • Transfer learning: epoch 11

Finishing accuracy:

  • Non-Transfer: 78.5-79%
  • Transfer: 81-82.5%

There is certainly promise here.

Next is to try the reverse

Reverse didn’t make much of a difference, unsure why that’s the case but it is!

Still need to test: freezing layers (Tabular Transfer Learning) and/or retraining with fastai.

So I got the code for freezing and whatnot figured out. I’m training the Poker model tonight, then I’ll go from there tomorrow morning. (Or this morning? I think it’s like 3:30 am)
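
For anyone following along, here is a minimal sketch of what transferring and freezing could look like in plain PyTorch. It is not the code used in the experiments above and not TabNet: `TinyTabular` is a hypothetical stand-in, and the column and class counts (10 and 10 for the source, 14 and 2 for the target) are illustrative. The idea is that the input and output layers depend on the dataset, so only the shared middle block is copied and frozen.

```python
import torch
from torch import nn


class TinyTabular(nn.Module):
    """Hypothetical stand-in for a tabular network (not TabNet itself)."""

    def __init__(self, n_in, n_out, hidden=32):
        super().__init__()
        self.stem = nn.Linear(n_in, hidden)      # depends on the number of input columns
        self.body = nn.Sequential(               # dataset-independent part
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.head = nn.Linear(hidden, n_out)     # depends on the number of classes

    def forward(self, x):
        return self.head(self.body(torch.relu(self.stem(x))))


source = TinyTabular(n_in=10, n_out=10)  # pretend this was trained on the first dataset
target = TinyTabular(n_in=14, n_out=2)   # fresh model for the second dataset

# Transfer: copy only the block whose shapes are identical in both models.
target.body.load_state_dict(source.body.state_dict())

# Freeze: stop gradient updates for the transferred block.
for param in target.body.parameters():
    param.requires_grad_(False)

# Only the still-trainable parameters (stem and head) go to the optimizer.
trainable = [p for p in target.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=0.02)

out = target(torch.randn(4, 14))
print(out.shape)
```

Unfreezing later is the same loop with `requires_grad_(True)`, followed by rebuilding the optimizer so that it sees the body parameters again.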

Hmmmm… I could not recreate the accuracy they achieved for Poker Hand… :frowning:

@grankin even running their source code (PyTorch) I was unable to match it. Early stopping kicked in at ~epoch 207, with an accuracy of 54%. (I still wasn’t able to achieve this on our version, but either way that is a far cry from 99%.)


@grankin, I was just informed that someone was able to achieve just 61% accuracy on the poker dataset using the original Keras implementation… so we may not be as wrong as I was thinking…

It’s a strange practice for Google to claim results without sharing reproducible code. I’ve submitted an issue.


Thanks for doing this! I added a bit more :slight_smile: hopefully this gets sorted out


Hello guys, I’m one of the authors of the pytorch-tabnet repo.

I think it’s a great idea to test different algorithms on different datasets. I’m eager to know about the results!

My only concern with your fastai wrapper is that it does not use our repo but a copy/paste of it, so, for example, the bugfix we released this week is not taken into account. Is there a way to make a fastai wrapper that uses our repo directly? Moreover, our implementation is scikit compatible, so I am not sure about the need for a fastai wrapper, but that’s another story :slight_smile: By the way, if you see some improvements that could be made to the code, please let us know! Feel free to contribute as well!

I will probably do some experiments myself with pytorch-tabnet. I’ll let you know if I can share some results!


A bit off topic but I just wanted to say thank you for porting the library to PyTorch. I have made some kind of simple CLI on top of the original TensorFlow implementation to run benchmarks and hyperparameter tuning on different datasets. The PyTorch version with sklearn interface will make this more straightforward. I’ll also be happy to report results to compare with other approaches including standard FastAI and FastAI TabNet. Nice work!


@Optimo I’m not sure this is possible, as fastai’s dataloaders build our inputs differently than what it normally is (I believe), and so our input is 2 different Tensors (one with categorical and one with numerical). You can see this in the forward() method, as it’s very different. Also thank you for porting it over for us :slight_smile: (Also, calling it a copy/paste kind of undersells the effort @grankin put into it, as I’m sure it’s a bit more than that, else it would just be the direct thing.)
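
To make the two-tensor point concrete, here is a rough sketch of an adapter whose `forward()` takes a categorical tensor and a numerical tensor and turns them into the single tensor that a one-input model expects. `CatContAdapter`, the cardinalities, and the embedding size are all made up for illustration; this is neither the fastai wrapper nor pytorch-tabnet code.

```python
import torch
from torch import nn


class CatContAdapter(nn.Module):
    """Hypothetical adapter: embeds each categorical column, appends the
    continuous columns, and passes the combined tensor to `body`."""

    def __init__(self, cardinalities, emb_dim, body):
        super().__init__()
        # One embedding table per categorical column.
        self.embeds = nn.ModuleList([nn.Embedding(n, emb_dim) for n in cardinalities])
        self.body = body

    def forward(self, x_cat, x_cont):
        # x_cat: (batch, n_cat) integer codes; x_cont: (batch, n_cont) floats.
        embedded = [emb(x_cat[:, i]) for i, emb in enumerate(self.embeds)]
        x = torch.cat(embedded + [x_cont], dim=1)
        return self.body(x)


cardinalities = [5, 3]   # illustrative: two categorical columns
emb_dim, n_cont = 4, 3   # illustrative sizes
body = nn.Linear(len(cardinalities) * emb_dim + n_cont, 2)  # stand-in for a single-input model
model = CatContAdapter(cardinalities, emb_dim, body)

# Fake batch: each categorical column gets codes within its own cardinality.
x_cat = torch.stack([torch.randint(0, n, (8,)) for n in cardinalities], dim=1)
x_cont = torch.randn(8, n_cont)
out = model(x_cat, x_cont)
print(out.shape)
```

A model that expects one pre-encoded matrix would sit where `body` is, which is why wrapping an external implementation takes more than just calling it.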


Hello, Optimo, thank you for the great work! I’ve updated the repo and package to include your bugfix.


@Optimo or @mikaelh you may be more familiar with this: do you know any other datasets that the literature uses for comparisons? :slight_smile: There’s not a lot of literature that uses fastai’s technique as a comparison (hence this thread), and I was wondering if there was more I missed (besides the salary dataset).

  • OpenML has a lot of datasets in general, and about 40 of those were used in an AutoML benchmark.

  • CatBoost has some reference datasets (it’s also shown in that repo how they are prepared)

  • Of course, Kaggle Datasets and the UCI Machine Learning Repository are good places to look.

  • I know there have been some mega-benchmarks with hundreds of datasets (for example one paper that concluded that random forests are the best classification algorithm; it’s a couple of years old though). The SeLU paper reports results on 121 UCI datasets.
