Still trying to wrap my head around an LSTM implementation

Ok so I’ve spent entirely too many hours trying to get this to work. I’m appealing to smarter people than me!

I’m trying to get the LSTMCell from Chapter 12 to work in an LSTM module, but I keep running into issues where my tensors aren’t the right size to stack. I’ve tried changing the shapes of my hidden and cell state, but then when I stack my hidden state with my input, the tensor just keeps getting bigger. I’m hoping it’s obvious to someone else how to configure this?

Here are my classes:

bs=64

class LSTMCell(Module):
    def __init__(self, ni, nh):
        self.forget_gate = nn.Linear(ni + nh, nh)
        self.input_gate = nn.Linear(ni + nh, nh)
        self.cell_gate = nn.Linear(ni + nh, nh)
        self.output_gate = nn.Linear(ni + nh, nh)
        
    def forward(self, input, state):
        h,c = state
        # Stack our input with the previous hidden state
        h = torch.stack([h, input], dim=1)
        # Linear layer learns what to forget, then activated by sigmoid
        forget = torch.sigmoid(self.forget_gate(h))
        # Since forget consists of scalars between 0 and 1, we multiply this result by the cell state
        # to determine which information to keep and which to throw away. Values close to 0 are thrown away
        # Values close to 1 are kept
        c = c * forget
        # Input gate combines with the cell gate to update the cell
        inp = torch.sigmoid(self.input_gate(h))
        # Linear layer activated by tanh
        cell = torch.tanh(self.cell_gate(h))
        # Cell state is updated by the result of the input gate times the cell gate
        c = c + inp * cell
        # Output gate determines which information from the cell state to use to generate output
        out = torch.sigmoid(self.output_gate(h))
        # New hidden state is the result of the output gate combined with the tanh of the cell state
        h = out * torch.tanh(c)
        # Outputs the new hidden state along with the cell state. Seems redundant?
        return (h,c)

class LSTM_scratch(Module):
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = LSTMCell(n_hidden, n_hidden)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h = torch.zeros(bs, n_hidden)
        self.c = torch.zeros(bs, n_hidden)
        
    def forward(self, x):
        h,c = self.rnn.forward(self.i_h(x), (self.h, self.c))
        self.h = h.detach()
        self.c = c.detach()
        return self.h_o(self.c)
    
    def reset(self):
        self.h.detach()
        self.c.detach()

learn = Learner(dls, LSTM_scratch(len(vocab), 50),
                 loss_func=CrossEntropyLossFlat(),
                 metrics=accuracy, cbs=ModelResetter)

learn.fit_one_cycle(5, 1e-2)

which gets me the error “RuntimeError: stack expects each tensor to be equal size, but got [64, 50] at entry 0 and [64, 16, 50] at entry 1”. So I think I’m getting confused about the dimensions of my hidden state and how to stack that with my input? I’ve tried making the dimension [64, 16, 50] to match with my input but then the next time I’m stacking I get a similar error trying to stack [64, 16, 50] with [64, 30, 50], which I don’t really understand either… I would have thought it would be [64, 32, 50] if anything.

Is it obvious to anyone how I’m butchering this? I’ve been stuck trying to understand this for days. I just want to move on!

Thanks a lot

Hi Gannan. I really hate it when I ask a question and some jerk posts a reply that does not answer the actual question I asked. At the risk of being a jerk, I’ll suggest that you do move on! It does not serve you to get stuck here, and IMHO there’s little benefit at this point in implementing LSTMCell manually. After you finish the course, understand Python better, and have learned how to use a debugger, you can come back to implement LSTMCell with ease.

If you want to investigate RNNs right now, I suggest you drop in the PyTorch implementation of LSTMCell. It has already been thoroughly debugged and accelerated with CUDA. Copy someone else’s already working RNN module that uses LSTMCell, and learn from it. Then modify it bit by bit to fit your needs.
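
As a rough sketch (my own toy example, not the book’s code), the built-in cell has essentially the same interface as the one you are writing:

# Toy example of PyTorch's built-in cell; shapes chosen to match your bs=64, n_hidden=50 setup.
import torch
import torch.nn as nn

cell = nn.LSTMCell(input_size=50, hidden_size=50)

x_t = torch.randn(64, 50)      # embedded input for ONE timestep: [batch, input_size]
h = torch.zeros(64, 50)        # previous hidden state
c = torch.zeros(64, 50)        # previous cell state

h, c = cell(x_t, (h, c))       # returns the new hidden and cell state, both [64, 50]

Once that behaves, you can swap your own cell back in and compare the outputs.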

As for your code, I only skimmed it. The prominent issue is that you are not calling LSTMCell repeatedly on the elements of a single sequence returned by the DataLoader. As written, the whole batch of sequences from dls gets passed to the cell in one shot, as if it were a single sequence element. Probably not what you intended. :slightly_smiling_face:
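
Very roughly (an untested sketch, assuming x is a [bs, seq_len] batch of token indices and your cell returns (h, c)), the forward pass needs a loop over timesteps:

    def forward(self, x):
        outs = []
        for t in range(x.shape[1]):                    # one step per token in the sequence
            emb = self.i_h(x[:, t])                    # [bs, n_hidden] for this timestep only
            h, c = self.rnn(emb, (self.h, self.c))
            self.h, self.c = h.detach(), c.detach()    # keep the state, drop the gradient history
            outs.append(self.h_o(h))
        return torch.stack(outs, dim=1)                # [bs, seq_len, vocab_sz]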

Do you have the same sequence length for each example? Could this be causing your size mismatch issues?

For debugging tensor shapes, I’ve had good experience with torchsnooper. You can give it a shot to debug your code.

Another recent library, tensor-sensor, can do the same as well.
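
From memory (check each library’s docs), usage looks roughly like this:

# torchsnooper logs the shape/dtype of every tensor created in the decorated function, line by line.
import torch
import torchsnooper

@torchsnooper.snoop()
def toy_step(h, x):
    return torch.cat([h, x], dim=1)    # the logged shapes make size mismatches easy to spot

toy_step(torch.zeros(64, 50), torch.zeros(64, 50))

# tensor-sensor instead annotates the exception message with the offending shapes:
# import tsensor
# with tsensor.clarify():
#     learn.fit_one_cycle(5, 1e-2)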

It would also help if you can share a Colab or a GitHub link, so that others can try out your code.

Hi Gannon,

If you are motivated to try again: I think the implementation in the book has a bug; it should use torch.cat instead of torch.stack, as described in this post.
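
A quick shape check (my own toy numbers, batch of 64 and hidden size 50) shows the difference:

import torch

h = torch.zeros(64, 50)   # previous hidden state
x = torch.zeros(64, 50)   # embedded input for one timestep

print(torch.cat([h, x], dim=1).shape)    # torch.Size([64, 100]) -> what the ni+nh Linear layers expect
print(torch.stack([h, x], dim=1).shape)  # torch.Size([64, 2, 50]) -> adds a new dimension instead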

I hope this helps, should you try again :slight_smile:

That makes way more sense! Thanks for reaching out on such an old post. I had already given up, thinking I was the crazy one. Cheers

Got gonnan’s code running thanks to the hints from Pomo (loop over each word) and johannesstutz (cat instead of stack):

class LSTMCell(Module):
    def __init__(self, ni, nh):
        self.forget_gate = nn.Linear(ni + nh, nh)
        self.input_gate  = nn.Linear(ni + nh, nh)
        self.cell_gate   = nn.Linear(ni + nh, nh)
        self.output_gate = nn.Linear(ni + nh, nh)

    def forward(self, input, state):
        h,c = state
        # Concatenate (not stack) the previous hidden state with the input along the feature dimension
        h = torch.cat([h, input], dim=1)
        forget = torch.sigmoid(self.forget_gate(h))
        c = c * forget
        inp = torch.sigmoid(self.input_gate(h))
        cell = torch.tanh(self.cell_gate(h))
        c = c + inp * cell
        out = torch.sigmoid(self.output_gate(h))
        h = out * torch.tanh(c)
        return h, (h,c)

class LSTM_scratch(Module):
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = LSTMCell(n_hidden, n_hidden)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h = torch.zeros(bs, n_hidden)
        self.c = torch.zeros(bs, n_hidden)
        
    def forward(self, x):
        outs = []
        # Apply the cell once per token in the sequence, carrying the state forward
        for word in range(x.shape[1]):
            h, (h,c) = self.rnn(self.i_h(x[:,word]), (self.h, self.c))
            # Detach so gradients don't flow back through earlier batches
            self.h = h.detach()
            self.c = c.detach()
            # (Predicting from the cell state, as in the original post; self.h would be the more usual choice)
            outs.append(self.h_o(self.c))
        return torch.stack(outs, dim=1)
    
    def reset(self):
        # detach() alone is a no-op here; zero the state so each epoch starts fresh
        self.h = torch.zeros_like(self.h)
        self.c = torch.zeros_like(self.c)

learn = Learner(dls, LSTM_scratch(len(vocab), 50),
                 loss_func=CrossEntropyLossFlat(),
                 metrics=accuracy, cbs=ModelResetter)

learn.fit_one_cycle(15, 1e-2)