I’m seeking an intuitive explanation of one of the features of AWD-LSTM.

Notebook 12a includes a Python version of AWD-LSTM, in which Jeremy splits the input into 4 pieces and runs each one through a gate in the forward method of the LSTMCell class. That seems to match the equations given. However, in Colah's Understanding LSTM Networks, the equations seem to use the entire input for each of the gates.
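For reference, here are the standard LSTM equations in the notation of Colah's post, where $[h_{t-1}, x_t]$ is the previous hidden state concatenated with the current input. Every gate sees the full $[h_{t-1}, x_t]$, and the four weight matrices can be stacked so that all four pre-activations come out of one product. The symbols $a_i, a_f, a_o, a_C$ for the pre-activations are my own labels, and the order $i, f, o, C$ is chosen to match `ingate, forgetgate, outgate, cellgate` in the notebook code. The notebook also keeps the input-to-hidden and hidden-to-hidden parts as two layers (`self.ih`, `self.hh`), which is the same thing, since $W \cdot [h_{t-1}, x_t] = W^{h} h_{t-1} + W^{x} x_t$ when $W$ is split column-wise.

$$
\begin{aligned}
f_t &= \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \\
i_t &= \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \\
\tilde{C}_t &= \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) \\
C_t &= f_t * C_{t-1} + i_t * \tilde{C}_t \\
o_t &= \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \\
h_t &= o_t * \tanh(C_t) \\[6pt]
\begin{bmatrix} a_i \\ a_f \\ a_o \\ a_C \end{bmatrix}
&= \begin{bmatrix} W_i \\ W_f \\ W_o \\ W_C \end{bmatrix} [h_{t-1}, x_t]
 + \begin{bmatrix} b_i \\ b_f \\ b_o \\ b_C \end{bmatrix},
\quad i_t = \sigma(a_i),\ f_t = \sigma(a_f),\ o_t = \sigma(a_o),\ \tilde{C}_t = \tanh(a_C)
\end{aligned}
$$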

Splitting the input into 4 pieces doesn’t make sense to me intuitively, since I don’t see how you would decide which piece goes through which gate. Is there a trick I’m missing, like randomizing the split or perhaps pre-copying the matrix so that when you split it, you are back to a single copy for each gate? I don’t quite see that in the code.

```python
import torch
import torch.nn as nn

class LSTMCell(nn.Module):
    def __init__(self, ni, nh):
        super().__init__()
        # Each layer holds the weights of all 4 gates stacked together (4*nh outputs)
        self.ih = nn.Linear(ni, 4*nh)
        self.hh = nn.Linear(nh, 4*nh)

    def forward(self, input, state):
        h, c = state
        # One big multiplication for all the gates is better than 4 smaller ones;
        # .chunk(4, 1) then splits the 4*nh *outputs* into one nh-wide slice per gate
        gates = (self.ih(input) + self.hh(h)).chunk(4, 1)
        ingate, forgetgate, outgate = map(torch.sigmoid, gates[:3])
        cellgate = gates[3].tanh()
        c = (forgetgate*c) + (ingate*cellgate)
        h = outgate * c.tanh()
        return h, (h, c)
```
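To check whether the input really gets split, here is a minimal sketch comparing one stacked `nn.Linear(ni, 4*nh)` layer against four separate `nn.Linear(ni, nh)` layers whose weights are copied from row blocks of the stacked one. The sizes `ni`, `nh`, `bs` and the names `gate_layers`, `big`, `small` are arbitrary choices for illustration, not from the notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ni, nh, bs = 3, 5, 2  # hypothetical sizes, chosen only for illustration

# One stacked layer, as in LSTMCell.__init__: ni -> 4*nh
# (nn.Linear stores its weight with shape (out_features, in_features) = (4*nh, ni))
ih = nn.Linear(ni, 4 * nh)

# Four separate per-gate layers (ni -> nh), each copying one block of nh rows
# from the stacked weight matrix and the matching slice of the bias
gate_layers = []
for k in range(4):
    layer = nn.Linear(ni, nh)
    with torch.no_grad():
        layer.weight.copy_(ih.weight[k * nh:(k + 1) * nh])
        layer.bias.copy_(ih.bias[k * nh:(k + 1) * nh])
    gate_layers.append(layer)

x = torch.randn(bs, ni)

# One big multiplication, then split the *output* into 4 chunks along dim 1
big = ih(x).chunk(4, 1)
# Four smaller multiplications, each one seeing the *whole* input x
small = [layer(x) for layer in gate_layers]

for b, s in zip(big, small):
    print(b.shape, torch.allclose(b, s))
```

Every chunk matches the corresponding separate layer, so each gate gets the whole input; what gets split is the output of the multiplication. Which chunk counts as "the input gate" is only a convention fixed by the order in `gates[:3]` and `gates[3]`. Since the weights are learned, each block of rows ends up learning whatever its gate needs.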

Are you referring to this code? I think it comes down to how he initializes the layers in `__init__`: `self.ih` and `self.hh` are each created with `4*nh` outputs, i.e. the weights for four gates stacked into one layer. I'll go through the video again once I get there to check whether this explanation is right, but that's my understanding for now.

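So he is chunking the product, not the input: he does one big multiplication and then splits the result. `self.ih` and `self.hh` are bigger weight matrices than a single gate needs (`4*nh` outputs instead of `nh`), which is the same as four per-gate weight matrices stacked together. Normally you would have four separate layers and multiply the same input four times. The comment in the code says one big multiplication for all the gates is better than four smaller ones, which I take to be about speed: a single large matrix multiplication usually uses the GPU more efficiently than several small ones. For example, say the input is [2] and the four gates have weights 1, 2, 3 and 4. Normally you would do four separate multiplications (1×2, 2×2, 3×2, 4×2). Instead, stack the weights as [1,2,3,4] and do one multiplication with the input, giving [2,4,6,8], then chunk that into one number per gate. It's a really simple example, but hopefully it's clear enough.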
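The toy example above can be written out in PyTorch. This is a sketch with made-up 1×1 per-gate weights (the values 1 to 4 and the names `W`, `stacked`, `per_gate` are illustrative only), showing that one multiplication with the stacked weights gives the same numbers as four separate ones.

```python
import torch

x = torch.tensor([2.0])  # a single input feature

# Four hypothetical per-gate weights (1x1 each), stacked into one 4x1 matrix
W = torch.tensor([[1.0], [2.0], [3.0], [4.0]])

stacked = W @ x                          # one multiplication -> shape (4,)
per_gate = [W[k] @ x for k in range(4)]  # four separate multiplications
gate_values = stacked.chunk(4)           # one slice per gate, like .chunk(4, 1) in LSTMCell

print(stacked)  # tensor([2., 4., 6., 8.])
```

The input 2 is never copied or split; only the weights are stacked, and the result is chunked afterwards.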