Training RNNs as fast as CNNs

An interesting adaptation of the recurrent unit that makes it far more parallelizable. By removing the gates' dependence on the hidden state at t-1, all of the heavy matrix multiplications can be computed for every time step at once, leaving only a cheap element-wise recurrence to run sequentially. This speeds up training by 5x+ over cuDNN LSTM implementations.
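
To make the idea concrete, here's a minimal PyTorch sketch of an SRU-style layer following the recurrence described in the paper. This is my own simplified illustration, not the authors' optimized CUDA implementation: the fused projection, layer name, and the use of the projected input in the highway term are assumptions made to keep the shapes simple.

```python
import torch
import torch.nn as nn

class SimpleSRULayer(nn.Module):
    """Sketch of an SRU-style layer: every matrix multiply depends only on
    x_t, so it is batched over the whole sequence up front; only a cheap
    element-wise recurrence is left to run step by step."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # One fused projection producing the candidate, forget gate, and reset gate.
        self.proj = nn.Linear(input_size, 3 * hidden_size)
        self.hidden_size = hidden_size

    def forward(self, x, c0=None):
        # x: (seq_len, batch, input_size)
        seq_len, batch, _ = x.shape
        u = self.proj(x)                      # computed in parallel over all time steps
        x_tilde, f_pre, r_pre = u.chunk(3, dim=-1)
        f = torch.sigmoid(f_pre)              # forget gate: no h_{t-1} term
        r = torch.sigmoid(r_pre)              # reset/highway gate
        c = x.new_zeros(batch, self.hidden_size) if c0 is None else c0
        outputs = []
        for t in range(seq_len):              # only element-wise ops remain sequential
            c = f[t] * c + (1 - f[t]) * x_tilde[t]
            # The paper's highway connection uses the raw input x_t when sizes
            # match; here we reuse the projected x_tilde to stay self-contained.
            h = r[t] * torch.tanh(c) + (1 - r[t]) * x_tilde[t]
            outputs.append(h)
        return torch.stack(outputs), c
```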

They present it as the SRU, or 'Simple Recurrent Unit', but the concept is applicable to any recurrent block, and they provide PyTorch source code for their work. If you work with RNNs it's probably worth checking out. They benchmark the unit on a variety of tasks.