LSTM Cell Question

This is really not that important, but it’s driving me crazy. I am looking at the following code for the LSTM from lesson 12 (lecture 8):

In the original code, h is a torch.stack of the tensors h and input. Both h and input are two-dimensional tensors, so the resulting h should be three-dimensional, since stack adds a new dimension. In contrast, torch.cat joins along an existing dimension and does not add one.
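To make the difference concrete, here is a small sketch (the batch size and hidden size are made-up values just for illustration):

```python
import torch

h   = torch.rand(64, 30)  # hypothetical hidden state, shape (bs, nh)
inp = torch.rand(64, 30)  # hypothetical input activations, same shape

stacked = torch.stack([h, inp])        # adds a new leading dimension
catted  = torch.cat([h, inp], dim=1)   # joins along an existing dimension

print(stacked.shape)  # torch.Size([2, 64, 30])
print(catted.shape)   # torch.Size([64, 60])
```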

self.forget_gate is an nn.Linear layer, so its weight matrix is two-dimensional. How can the matrix multiplication in self.forget_gate(h) work? Aren’t we multiplying a three-dimensional tensor by a two-dimensional one?

In the refactored LSTM later in the notebook, I don’t see the same shape problem. What am I missing?


Indeed you are right. Here is a corrected version.

class LSTMCell(Module):
    def __init__(self, ni, nh):
        # Each gate maps the concatenated [hidden, input] vector (ni + nh)
        # down to nh units
        self.forget_gate = nn.Linear(ni + nh, nh)
        self.input_gate  = nn.Linear(ni + nh, nh)
        self.cell_gate   = nn.Linear(ni + nh, nh)
        self.output_gate = nn.Linear(ni + nh, nh)

    def forward(self, input, state):
        h,c = state
        # cat (not stack): join hidden state and input along the feature dim
        h = torch.cat([h, input], dim=1)
        # forget gate decides what to keep from the old cell state
        forget = torch.sigmoid(self.forget_gate(h))
        c = c * forget
        # input and cell gates decide what new information to add
        inp = torch.sigmoid(self.input_gate(h))
        cell = torch.tanh(self.cell_gate(h))
        c = c + inp * cell
        # output gate decides what part of the cell state becomes the new hidden state
        out = torch.sigmoid(self.output_gate(h))
        h = out * torch.tanh(c)
        return h, (h,c)
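For anyone who wants to check the shapes outside the notebook, here is a self-contained sketch. It subclasses nn.Module directly (fastai’s Module just saves you the explicit super().__init__() call), and the sizes are made-up values:

```python
import torch
import torch.nn as nn

class LSTMCell(nn.Module):
    def __init__(self, ni, nh):
        super().__init__()
        self.forget_gate = nn.Linear(ni + nh, nh)
        self.input_gate  = nn.Linear(ni + nh, nh)
        self.cell_gate   = nn.Linear(ni + nh, nh)
        self.output_gate = nn.Linear(ni + nh, nh)

    def forward(self, input, state):
        h, c = state
        h = torch.cat([h, input], dim=1)  # (bs, ni + nh)
        forget = torch.sigmoid(self.forget_gate(h))
        c = c * forget
        inp = torch.sigmoid(self.input_gate(h))
        cell = torch.tanh(self.cell_gate(h))
        c = c + inp * cell
        out = torch.sigmoid(self.output_gate(h))
        h = out * torch.tanh(c)
        return h, (h, c)

ni, nh, bs = 10, 30, 64  # hypothetical input size, hidden size, batch size
cell = LSTMCell(ni, nh)
x = torch.rand(bs, ni)
state = (torch.zeros(bs, nh), torch.zeros(bs, nh))
out, (h, c) = cell(x, state)
print(out.shape, h.shape, c.shape)  # all torch.Size([64, 30])
```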

By the way, you can use a Linear layer on higher-dimensional tensors too: it operates over the last dimension.

l = nn.Linear(2, 4)
l(torch.rand(3, 5, 2)).shape
# torch.Size([3, 5, 4])
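A quick way to convince yourself of this: applying the layer to a 3D tensor gives the same result as flattening the leading dimensions, applying it, and reshaping back.

```python
import torch
import torch.nn as nn

l = nn.Linear(2, 4)
x = torch.rand(3, 5, 2)

y1 = l(x)                            # Linear applied over the last dim
y2 = l(x.view(-1, 2)).view(3, 5, 4)  # flatten, apply, reshape back

print(y1.shape)              # torch.Size([3, 5, 4])
print(torch.allclose(y1, y2))  # True
```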

Thanks, that’s really helpful.