How to extract outputs and hidden states from awd-lstm

I am wrestling with how to do fast text generation with my AWD-LSTM model. I imagine the code would look something like this:

import torch
from fastai.text import *   # fastai v1

learner = language_model_learner(data, AWD_LSTM, drop_mult=0.5, pretrained=False)
learner.load('..path to model')
#---number of samples to generate
batch_sample = 100
#---max length of text for generation
max_seq_length = 100
#---array to store our actions
actions = torch.zeros((batch_sample, max_seq_length), dtype=torch.long).to(device='cuda')
#---create initial go array, shape (batch_sample, 1)
go_int = learner.data.train_ds.vocab.stoi['GO']
xb = torch.full((batch_sample, 1), go_int, dtype=torch.long, device='cuda')

#---set my hidden states to 0
learner.model.reset()
for i in range(0, max_seq_length):
    #---what I hoped for: next-token probabilities and the new hidden state
    output, hidden = learner.model(xb)
    m = torch.distributions.Multinomial(probs=output)
    action = m.sample()
    #---action is one-hot per row; recover the sampled token ids
    idx = action.nonzero()[:,1]
    actions[:,i] = idx
    #---feed the sampled token ids (not the one-hot vectors) back in
    xb = idx.unsqueeze(1)

However, learner.model(xb) does not seem to return the output (token probabilities) and the new hidden state. Looking at the code for the AWD_LSTM class, its forward function returns raw_outputs and outputs, and I cannot work out what those two are. In addition, what I actually get back from learner.model(xb) is a tuple of length 3 whose components I don't understand. Since I am doing text generation, it is critical that the hidden states are reset only once, at the beginning.
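To make the "reset only once" requirement concrete, here is a toy sketch in plain Python. The class ToyStatefulLM is hypothetical (it is not fastai code): like a stateful RNN it keeps its hidden state on the object between calls and only clears it in reset(), and it returns a 3-tuple (decoded, raw_outputs, outputs) that loosely mirrors the tuple the model hands back. The recurrence is a stand-in for the LSTM cells.

```python
class ToyStatefulLM:
    """Hypothetical stand-in for a stateful RNN language model (not fastai code).

    The hidden state lives on the object (self.hidden) and survives between
    calls; only reset() clears it. __call__ returns a 3-tuple
    (decoded, raw_outputs, outputs), loosely mirroring the model's output tuple.
    """

    def __init__(self, vocab_size, decay=0.5):
        self.vocab_size = vocab_size
        self.decay = decay
        self.hidden = None

    def reset(self):
        self.hidden = None                      # next call starts from a zero state

    def __call__(self, xb):
        # xb: batch of token-id rows, shape (bs, seq_len)
        if self.hidden is None:
            self.hidden = [0.0] * len(xb)
        raw_outputs = []
        for row, tokens in enumerate(xb):
            h = self.hidden[row]
            seq = []
            for t in tokens:
                h = self.decay * h + t          # toy recurrence in place of LSTM cells
                seq.append(h)
            self.hidden[row] = h                # carried over to the next call
            raw_outputs.append(seq)
        outputs = raw_outputs                   # no dropout in this toy
        # decoded: one score vector per (row, position), shape (bs, seq_len, vocab)
        decoded = [[[-(h - v) ** 2 for v in range(self.vocab_size)] for h in seq]
                   for seq in outputs]
        return decoded, raw_outputs, outputs
```

Feeding the tokens one call at a time after a single reset() ends in the same hidden state as feeding the whole sequence in one call; calling reset() inside the loop would throw that context away.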

I have considered calling the pred_batch function with an array of all actions up to iteration i. However, this is extremely compute-intensive, because the entire sequence of hidden states has to be recomputed from scratch on every iteration.
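The cost difference can be counted directly. In this sketch the recurrence is a hypothetical toy standing in for one LSTM time step, and recompute_prefix mimics the "re-run the whole prefix each iteration" approach described above. For a sequence of length n it performs n(n+1)/2 cell steps, versus n when the hidden state is carried over, so for max_seq_length = 100 that is 5050 steps versus 100, while both end in the same state.

```python
def rnn_steps(tokens, h, counter):
    """Run a toy recurrence over tokens from state h, counting cell evaluations."""
    for t in tokens:
        h = 0.5 * h + t              # stands in for one LSTM time step
        counter[0] += 1
    return h

def recompute_prefix(tokens):
    """Re-run the whole prefix from a zero state at every generation step."""
    counter, h = [0], 0.0
    for i in range(1, len(tokens) + 1):
        h = rnn_steps(tokens[:i], 0.0, counter)
    return h, counter[0]

def carry_hidden(tokens):
    """Keep h between steps and feed only the newest token."""
    counter, h = [0], 0.0
    for t in tokens:
        h = rnn_steps([t], h, counter)
    return h, counter[0]
```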

Could someone point out the best way for me to do this?

Note that language_model_learner returns a subclass of Learner (I can't remember the name, probably LMLearner) that has a custom predict method which looks exactly like what you want, so you can either use that method or grab its code and tune it to your needs.

Thank you! Unfortunately, the predict method (I think this is what you are referring to) does not scale well. To generate 64 different texts, you would have to call predict 64 times or use n_words = max_length*64, and both are inefficient. Here is my current solution:
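The gain from batching can be sketched with a toy model that counts its forward calls. CountingBigramLM is hypothetical (its next-token scores depend only on the last token, so it carries no hidden state); the point is only that one loop over time steps serves every row of the batch, so 64 texts of length 100 take 100 forward calls instead of 6400.

```python
import math
import random

class CountingBigramLM:
    """Hypothetical toy LM: next-token scores depend only on the last token.
    It counts forward calls so batched and one-at-a-time generation can be compared."""

    def __init__(self, vocab_size, seed=0):
        rng = random.Random(seed)
        self.table = [[rng.uniform(-1, 1) for _ in range(vocab_size)]
                      for _ in range(vocab_size)]
        self.calls = 0

    def __call__(self, last_tokens):
        # one token id per batch row in, one score row per batch row out
        self.calls += 1
        return [self.table[t] for t in last_tokens]

def softmax(scores):
    m = max(scores)                       # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scores]
    s = sum(exps)
    return [e / s for e in exps]

def generate(model, go_int, batch_size, max_seq_length, rng):
    xb = [go_int] * batch_size            # every row starts from the GO token
    actions = [[] for _ in range(batch_size)]
    for _ in range(max_seq_length):
        scores = model(xb)                # one forward call for the whole batch
        xb = [rng.choices(range(len(r)), weights=softmax(r), k=1)[0] for r in scores]
        for row, tok in enumerate(xb):
            actions[row].append(tok)
    return actions
```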

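import numpy as np
import torch
import torch.nn.functional as F

def sample(batch_sample=100, max_seq_length=100):
    learner.model.eval()    # turn off dropout for generation
    learner.model.reset()   # zero the hidden states once, before the first step
    go_int = learner.data.train_ds.vocab.stoi['GO']
    xb = np.array( [go_int]*batch_sample )
    xb = torch.from_numpy(xb).to(device='cuda').unsqueeze(1)   # (batch_sample, 1)
    actions = torch.zeros((batch_sample, max_seq_length), dtype=torch.long).to(device='cuda')
    with torch.no_grad():
        for i in range(0, max_seq_length):
            # model(xb) returns (decoded, raw_outputs, outputs); decoded is (batch_sample, 1, vocab)
            output = learner.model(xb)[0].squeeze(1)
            output_probs = F.softmax(output, dim=1)
            m = torch.distributions.Multinomial(probs=output_probs)
            action = m.sample()             # one-hot row per sample
            idx = action.nonzero()[:,1]     # sampled token ids
            actions[:,i] = idx
            xb = idx.unsqueeze(1)           # next input; the hidden state stays inside the model
    return actions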

It took me a while to figure out that model(xb) returns the output of the linear decoder's forward function, a tuple of three items: decoded, raw_outputs and outputs. My current understanding is that decoded is the last layer's output passed through the final linear layer (unnormalized scores over the vocabulary, hence the softmax), outputs are the sequential outputs of each RNN layer with dropout applied, and raw_outputs are the sequential outputs of each RNN layer without dropout. Hence, to get the probabilities for the next time step, I want to use decoded (index [0] of the returned tuple)! Please let me know if this looks incorrect.
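The softmax-then-sample step can be sketched without PyTorch. This plain-Python version assumes decoded_last holds one score row per batch element (the equivalent of decoded[:, -1]); it draws a one-hot count vector per row, as Multinomial with its default total_count=1 does, and then recovers the index, which is the role action.nonzero()[:,1] plays above. In PyTorch, torch.multinomial(output_probs, num_samples=1) returns the sampled indices directly with shape (batch, 1), which skips the one-hot detour.

```python
import math
import random

def softmax(scores):
    m = max(scores)                       # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scores]
    total = sum(exps)
    return [e / total for e in exps]

def sample_one_hot(probs, rng):
    """Like a single Multinomial draw with total_count=1: a count vector with one 1."""
    idx = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    counts = [0] * len(probs)
    counts[idx] = 1
    return counts

def sample_next_tokens(decoded_last, rng):
    """decoded_last: one score row per batch element; returns one token id per row."""
    idx = []
    for scores in decoded_last:
        counts = sample_one_hot(softmax(scores), rng)
        idx.append(counts.index(1))       # same role as action.nonzero()[:, 1]
    return idx
```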