I’m working with models for text generation. Previously I’ve used LSTM-based models to great effect. I’m interested in using Transformer models, but I’ve found standard transformers to be significantly less efficient for autoregressive text generation: they need more time and memory than RNNs with a similar number of parameters (see example below).

Does anyone know of a good transformer variant for efficient text generation? I know there are some transformer variants that are designed for autoregressive use, but I can’t find clean implementations on GitHub.

Generation example:

LSTM generation usually goes something like this:

```
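import torch
import torch.nn.functional as F

# model, vocab, n_samples and max_sequence_length are defined elsewhere
x = torch.zeros((n_samples, 1)) # LSTM input is (bs, sl)
x += vocab.bos # bos token
x = x.long().cuda()
model.reset() # reset hidden state
new_idx = []
for i in range(max_sequence_length):
    output = model(x)[0]
    probs = F.softmax(output[:,-1], dim=-1) # distribution for the last position
    new_sample = torch.distributions.Categorical(probs).sample() # (bs,)
    new_idx.append(new_sample)
    x = new_sample[:,None] # feed only the newly sampled token
preds = torch.stack(new_idx).T # (bs, sl)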
```

Here’s the same thing for transformers:

```
import torch
import torch.nn.functional as F

# model, vocab, n_samples and max_sequence_length are defined elsewhere
x = torch.zeros((1,n_samples)) # transformer input is (sl, bs)
x += vocab.bos # bos token
x = x.long().cuda()
new_idx = []
for i in range(max_sequence_length):
    output = model(x)[0]
    # assuming output is (sl, bs, vocab), the last position is output[-1], not output[:,-1]
    probs = F.softmax(output[-1], dim=-1)
    new_sample = torch.distributions.Categorical(probs).sample() # (bs,)
    new_idx.append(new_sample)
    x = torch.cat([x, new_sample.unsqueeze(0)], 0) # this line is the big difference, you need to concatenate the full sequence for the input to the next iteration
preds = torch.stack(new_idx).T # (bs, sl)
```

The big difference between LSTMs and transformers for autoregressive generation is that with LSTMs you can just feed each subsequent token (`x = new_sample[:,None]`), while transformers require you to feed the full sequence back in (`x = torch.cat([x, new_sample.unsqueeze(0)], 0)`). In practice I’ve found transformers require an order of magnitude more GPU memory for the same generation task.
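To put a rough number on that difference, here’s a minimal sketch that counts how many tokens each loop pushes through the model over one generation run. The function names are my own, made up for illustration, and the last function assumes a standard transformer with full self-attention and no caching of past keys/values, which is what the transformer loop above does.

```python
def tokens_fed_recurrent(max_len: int) -> int:
    """LSTM-style loop: one new token per step, history lives in the hidden state."""
    return max_len


def tokens_fed_full_recompute(max_len: int) -> int:
    """Transformer-style loop: step i re-feeds the bos token plus the i tokens sampled so far."""
    return sum(i + 1 for i in range(max_len))


def attention_scores_per_pass(sl: int, n_heads: int, n_layers: int) -> int:
    """Attention score entries computed in one forward pass for one sample.

    Assumes full self-attention: an (sl x sl) score matrix per head per layer.
    """
    return sl * sl * n_heads * n_layers


if __name__ == "__main__":
    for max_len in (100, 500):
        print(max_len, tokens_fed_recurrent(max_len), tokens_fed_full_recompute(max_len))
```

For 100 generated tokens that is 100 tokens fed to the LSTM versus 5050 fed to the transformer, and the attention score matrix at each step grows with the square of the current sequence length on top of that.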

I’m hoping to find a transformer variant that can do autoregressive generation with a lower compute overhead.