Training SHA-RNN with fastai

Hi,

I have tried to train the SHA-RNN from Stephen Merity's "Single Headed Attention RNN…" paper (https://github.com/Smerity/sha-rnn). I must admit it was neither easy nor successful (I now totally understand what Jeremy means when he calls the mid-level API difficult to use!), but I know there are some people on the forums who wanted to try this as well, so maybe they can share tips, or this could be a starting point for them (@edwardjross, @MicPie, and probably others who mentioned they are interested in SHA-RNN).

I took the path of least resistance and did something similar to what was described here. Training was quite resource-intensive (though nothing like BERT!). Due to OOM errors I was not able to use a large BPTT value, which I guess is exactly the condition under which this particular architecture should perform better. The results are not as good as AWD-LSTM's in terms of perplexity; tuning the other hyperparameters would probably help as well.
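
In case it helps anyone getting started, here is roughly what the setup boils down to. This is a simplified sketch, not the actual notebook code: the dataframe and the `sha_rnn` module are placeholders, and the model from Smerity's repo needs a small adapter because, as far as I remember, its forward also returns hidden state and attention memories rather than plain logits.

```python
from fastai.text.all import *

# Hypothetical corpus with a 'text' column; replace with your own data.
df = pd.read_csv('texts.csv')

# A small seq_len (BPTT) keeps GPU memory in check, at the cost of the
# long-range context the attention head is supposed to exploit.
dls = TextDataLoaders.from_df(df, text_col='text', is_lm=True,
                              seq_len=128, bs=16)

# `sha_rnn` stands for an nn.Module that returns vocabulary logits of shape
# (bs, seq_len, vocab_sz), i.e. the SHA-RNN wrapped so that fastai's loss
# and metrics can consume its output directly.
learn = Learner(dls, sha_rnn,
                loss_func=CrossEntropyLossFlat(),
                metrics=[accuracy, Perplexity()])
learn.fit_one_cycle(1, 3e-4)
```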

The notebooks with the experiment are here: https://github.com/noise-field/fastai-sharnn/

This looks very interesting!

Unfortunately, I had no time to look into it due to a job transition…

However, back then I tried training with the original scripts from the Smerity SHA-RNN repo to get a baseline. Did you try running those with the same hyperparameters you used in your fastai setup, to compare your fastai implementation against them?

PS: One thing I came across that I think could be quite interesting for making that setup less GPU-RAM hungry: Fixed Encoder Self-Attention Patterns in Transformer-Based Machine Translation (if those attention patterns also have a similar shape in LM training…).
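
My (possibly wrong) understanding of that paper is that some attention heads are replaced with fixed, predefined patterns such as "attend to the previous token", which removes the need to compute and store the full seq_len × seq_len score matrix. Something along these lines, just to illustrate the idea:

```python
import torch

def previous_token_attention(values: torch.Tensor) -> torch.Tensor:
    "Fixed 'previous token' pattern: position t simply receives the value at t-1."
    out = torch.zeros_like(values)
    out[:, 1:] = values[:, :-1]  # no query/key projections, no attention scores
    return out

x = torch.randn(2, 5, 8)                  # (batch, seq_len, dim)
print(previous_token_attention(x).shape)  # torch.Size([2, 5, 8])
```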