How to increase the number of inputs to an AWD_LSTM model from 1 to 3?

Hi all, I’d like to use AWD_LSTM (shown below) to train a language model. We have three inputs: mRNA sequence, structure, and loop type, but this model can only take one kind of input (e.g. the mRNA sequence). The three inputs all have the same length, 107.

How can I modify the model code so that it takes in 3 inputs?

For a batch of size 64, my input has shape torch.Size([64, 107, 3]), but this model can only take torch.Size([64, 107]) and returns torch.Size([64, 107, 32]), where 32 is the vocab size.

I want to train the language model on all three kinds of input, instead of only one (the mRNA sequences). How should I modify the code? Thanks.
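To make the shapes concrete, here is roughly what a batch looks like (dummy tensors with random values, just to illustrate the sizes):

import torch

# Dummy batch, random values -- only the shapes matter here
batch = torch.randint(0, 32, (64, 107, 3))  # sequence, structure and loop type tokens stacked on the last dim
seq_only = batch[:, :, 0]                   # torch.Size([64, 107]) -- the only shape the model currently accepts

print(batch.shape)     # torch.Size([64, 107, 3])
print(seq_only.shape)  # torch.Size([64, 107])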

model structure for a vocab size of 32:

SequentialRNN(
  (0): AWD_LSTM(
    (encoder): Embedding(32, 400, padding_idx=1)
    (encoder_dp): EmbeddingDropout(
      (emb): Embedding(32, 400, padding_idx=1)
    )
    (rnns): ModuleList(
      (0): WeightDropout(
        (module): LSTM(400, 1152, batch_first=True, bidirectional=True)
      )
      (1): WeightDropout(
        (module): LSTM(1152, 1152, batch_first=True)
      )
      (2): WeightDropout(
        (module): LSTM(1152, 400, batch_first=True)
      )
    )
    (input_dp): RNNDropout()
    (hidden_dps): ModuleList(
      (0): RNNDropout()
      (1): RNNDropout()
      (2): RNNDropout()
    )
  )
  (1): LinearDecoder(
    (decoder): Linear(in_features=400, out_features=32, bias=True)
    (output_dp): RNNDropout()
  )
)


class AWD_LSTM(Module):
    initrange=0.1
    def __init__(self, vocab_sz, emb_sz, n_hid, n_layers, pad_token=1, hidden_p=0.2, input_p=0.6, embed_p=0.1,
                 weight_p=0.5, bidir=False):
        store_attr('emb_sz,n_hid,n_layers,pad_token')
        self.bs = 1
        self.n_dir = 2 if bidir else 1
        self.encoder = nn.Embedding(vocab_sz, emb_sz, padding_idx=pad_token)
        self.encoder_dp = EmbeddingDropout(self.encoder, embed_p)
        self.rnns = nn.ModuleList([self._one_rnn(emb_sz if l == 0 else n_hid, (n_hid if l != n_layers - 1 else emb_sz)//self.n_dir,
                                                 bidir, weight_p, l) for l in range(n_layers)])
        self.encoder.weight.data.uniform_(-self.initrange, self.initrange)
        self.input_dp = RNNDropout(input_p)
        self.hidden_dps = nn.ModuleList([RNNDropout(hidden_p) for l in range(n_layers)])
        self.reset()

    def forward(self, inp, from_embeds=False):
        bs,sl = inp.shape[:2] if from_embeds else inp.shape
        if bs!=self.bs: self._change_hidden(bs)

        output = self.input_dp(inp if from_embeds else self.encoder_dp(inp))
        new_hidden = []
        for l, (rnn,hid_dp) in enumerate(zip(self.rnns, self.hidden_dps)):
            output, new_h = rnn(output, self.hidden[l])
            new_hidden.append(new_h)
            if l != self.n_layers - 1: output = hid_dp(output)
        self.hidden = to_detach(new_hidden, cpu=False, gather=False)
        return output

    def _change_hidden(self, bs):
        self.hidden = [self._change_one_hidden(l, bs) for l in range(self.n_layers)]
        self.bs = bs

    def _one_rnn(self, n_in, n_out, bidir, weight_p, l):
        "Return one of the inner rnn"
        rnn = nn.LSTM(n_in, n_out, 1, batch_first=True, bidirectional=bidir)
        return WeightDropout(rnn, weight_p)

    def _one_hidden(self, l):
        "Return one hidden state"
        nh = (self.n_hid if l != self.n_layers - 1 else self.emb_sz) // self.n_dir
        return (one_param(self).new_zeros(self.n_dir, self.bs, nh), one_param(self).new_zeros(self.n_dir, self.bs, nh))

    def _change_one_hidden(self, l, bs):
        if self.bs < bs:
            nh = (self.n_hid if l != self.n_layers - 1 else self.emb_sz) // self.n_dir
            return tuple(torch.cat([h, h.new_zeros(self.n_dir, bs-self.bs, nh)], dim=1) for h in self.hidden[l])
        if self.bs > bs: return (self.hidden[l][0][:,:bs].contiguous(), self.hidden[l][1][:,:bs].contiguous())
        return self.hidden[l]

    def reset(self):
        "Reset the hidden states"
        [r.reset() for r in self.rnns if hasattr(r, 'reset')]
        self.hidden = [self._one_hidden(l) for l in range(self.n_layers)]

@sgugger
@muellerzr
@orendar


I shared a notebook on how to approach this Kaggle competition with fastai: https://www.kaggle.com/thedrcat/openvaccine-lstm-fastai/ I’m not using AWD_LSTM, but you should be able to modify AWD_LSTM based on my code example. Good luck!


Hi Darek, so nice to be in touch with you! I’ve been learning from your code on Kaggle for weeks. Thank you so much for sharing it!

My idea is, instead of directly predicting the 5 outputs with the LSTM model as shown in your code, I want the model to learn the language of mRNA, structure and loop type first, so I can save the encoder and load it into the final model for fine-tuning. I checked the AWD_LSTM language model: the input tensor is [batch_sz, sequence_length] and the output tensor is [batch_sz, sequence_length, vocab_sz]. So I’m curious how to get multiple inputs in, and what size the output should be, so that the language-model learner can work.
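To make it concrete, the rough direction I’m imagining is something like this (a hypothetical sketch with made-up names, not tested):

import torch
from torch import nn

class TripleEmbedding(nn.Module):
    "Hypothetical replacement for AWD_LSTM's single encoder: one embedding per token stream."
    def __init__(self, vocab_szs, emb_sz, pad_token=1):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(v, emb_sz, padding_idx=pad_token) for v in vocab_szs])

    def forward(self, x):  # x: [bs, sl, 3] of token ids
        # sum the three embeddings so the rest of the AWD_LSTM stack still sees [bs, sl, emb_sz]
        return sum(emb(x[:, :, i]) for i, emb in enumerate(self.embs))

But I’m not sure what the decoder and targets should look like then, since there are now three vocabularies - that’s where I’m stuck.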


I understand now - I’m actually trying to do the same myself, but haven’t solved it yet. I don’t think the default fastai language model learner will work here, because it is single-direction and I think we need a bidirectional model to capture the entire sequence. That means we need something like a masked-language-modeling objective, changes to the loss function, etc. I may be wrong though - if you want to try the default language model learner, try to modify my code based on this chapter: https://github.com/fastai/fastbook/blob/master/12_nlp_dive.ipynb
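For the loss-function part, the usual masked-LM trick is to set the target to an ignore index at the positions you didn’t mask, so only the masked tokens get scored. Roughly something like this (just a plain PyTorch sketch, not fastai’s API):

import torch
from torch import nn

loss = nn.CrossEntropyLoss(ignore_index=-100)   # targets equal to -100 don't contribute

logits  = torch.randn(64, 107, 32)              # [bs, sl, vocab] from the model
targets = torch.randint(0, 32, (64, 107))       # the original tokens
masked  = torch.rand(64, 107) < 0.15            # positions replaced by a mask token in the input
targets = targets.masked_fill(~masked, -100)    # ignore everything that wasn't masked

l = loss(logits.reshape(-1, 32), targets.reshape(-1))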

Could you get by with ensembling a backwards and forwards model together?


I’m no expert on mRNA, but my guess is that this wouldn’t work - the reason is that mRNA is a graph. For example, in the sequence AAABBBBCCCC, when I predict the reactivity of the first A, I need to know whether it is paired and what its pair is.


Not sure if this is going to work, but below is my language model learner code. Right now I’m statically masking certain tokens in the input; ideally that should be a transform. I also need to change the loss function to only consider the masked tokens. If it works, the weights should be transferable to the mRNA prediction model.

class OVModel(Module):
    def __init__(self, vocab1_sz, vocab2_sz, vocab3_sz, emb_sz, n_hidden, n_layers, p):
        # one embedding per token stream: sequence, structure, loop type
        self.i_h1 = nn.Embedding(vocab1_sz, emb_sz)
        self.i_h2 = nn.Embedding(vocab2_sz, emb_sz)
        self.i_h3 = nn.Embedding(vocab3_sz, emb_sz)
        # concatenated embeddings plus 3 extra float features feed a bidirectional LSTM
        self.rnn = nn.LSTM(emb_sz*3+3, n_hidden, n_layers, batch_first=True, bidirectional=True)
        self.drop = nn.Dropout(p)
        # one decoder head per vocabulary
        self.h_o1 = nn.Linear(n_hidden*2, vocab1_sz)
        self.h_o2 = nn.Linear(n_hidden*2, vocab2_sz)
        self.h_o3 = nn.Linear(n_hidden*2, vocab3_sz)
        self.h = [torch.zeros(n_layers*2, BS, n_hidden) for _ in range(2)]

    def forward(self, x):
        # x: [bs, sl, 6] -- three token-id columns followed by three float feature columns
        e1 = self.i_h1(x[:,:,0].long())
        e2 = self.i_h2(x[:,:,1].long())
        e3 = self.i_h3(x[:,:,2].long())
        bp = x[:,:,3:]
        e = torch.cat((e1, e2, e3, bp), dim=2)
        raw,h = self.rnn(e, self.h)
        do = self.drop(raw)
        out1 = self.h_o1(do)
        out2 = self.h_o2(do)
        out3 = self.h_o3(do)
        self.h = [h_.detach() for h_ in h]   # carry the hidden state over, detached from the graph
        return out1, out2, out3, raw, do

    def reset(self):
        "Reset the hidden states"
        for h in self.h: h.zero_()

loss_fn = CrossEntropyLossFlat()

def loss_func(out, targ):
    # the model also returns the raw and dropped-out activations; only the three logit tensors are scored
    out1 = out[0]
    out2 = out[1]
    out3 = out[2]
    targ1 = targ[:,:,0]
    targ2 = targ[:,:,1]
    targ3 = targ[:,:,2]
    l1 = loss_fn(out1, targ1)
    l2 = loss_fn(out2, targ2)
    l3 = loss_fn(out3, targ3)
    return l1+l2+l3
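A quick shape check with made-up sizes (this assumes the usual from fastai.text.all import * for CrossEntropyLossFlat, and that BS is defined before the model is built, since the hidden state uses it):

BS = 64
model = OVModel(vocab1_sz=5, vocab2_sz=4, vocab3_sz=8, emb_sz=50, n_hidden=128, n_layers=2, p=0.3)

# three token streams plus three float features, matching the emb_sz*3+3 LSTM input size
x = torch.cat([torch.randint(0, 4, (BS, 107, 3)).float(),
               torch.rand(BS, 107, 3)], dim=2)
targ = torch.randint(0, 4, (BS, 107, 3))

out = model(x)
print(out[0].shape)          # torch.Size([64, 107, 5]) -- logits for the first vocab
print(loss_func(out, targ))  # summed cross-entropy over the three streams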

Thank you Darek!

For future reference, here’s my current attempt at the masking transform. Setting a target to -100 should exclude that position from the loss calculation.

class MaskTransform(ItemTransform):
    def __init__(self, p): self.p = p
    def encodes(self, x):
        inp,tgt = x
        shp = inp.shape[0]                      # sequence length
        a = torch.ones(shp,1) * self.p
        mask = torch.bernoulli(a)               # 1 with probability p at each position
        # values used to overwrite masked positions: one id per token stream, zeros for the float features
        inpmask = torch.ones(shp,1) * torch.Tensor([4, 3, 7, 0, 0, 0])
        # non-masked positions get target -100 so only masked tokens contribute to the loss
        tgtmask = (torch.ones(shp,1) * torch.Tensor([-100, -100, -100])).type(torch.LongTensor)
        newinp = torch.where(mask == 1, inpmask, inp)
        newtgt = torch.where(mask != 1, tgtmask, tgt)
        return newinp, newtgt
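And a quick sanity check on a single item (dummy tensors with the shapes from this competition: 107 positions, 6 input channels, 3 target columns; the values are random):

inp = torch.cat([torch.randint(0, 4, (107, 3)).float(),  # three token streams
                 torch.rand(107, 3)], dim=1)             # three float features
tgt = torch.randint(0, 4, (107, 3))

masked_inp, masked_tgt = MaskTransform(p=0.15)((inp, tgt))
# roughly 1-p of the positions end up all -100 in the target, i.e. excluded from the loss
print((masked_tgt == -100).all(dim=1).float().mean())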