Getting TypeError: 'LMModel7' object is not subscriptable during learn.save_encoder()

I’m trying to save the encoder for the following language model after training which has been defined here in the book:

class Dropout(Module):
    def __init__(self, p): self.p = p
    def forward(self, x):
        if not return x
        mask =*x.shape).bernoulli_(1-p)
        return x * mask.div_(1-p)

class LMModel7(Module):
    def __init__(self, vocab_sz, n_hidden, n_layers, p=0):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)
        self.drop = nn.Dropout(p)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h_o.weight = self.i_h.weight
        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]
    def forward(self, x):
        raw,h = self.rnn(self.i_h(x), self.h)
        out = self.drop(raw)
        self.h = [h_.detach() for h_ in h]
        return self.h_o(out),raw,out
    def reset(self): 
        for h in self.h: h.zero_()

learn = TextLearner(dls, LMModel7(len(vocab), 64, 2),
                    loss_func=CrossEntropyLossFlat(), metrics=accuracy)
learn.fit_one_cycle(1, 1e-2, wd=0.1)

But I’m getting the following error:

TypeError                                 Traceback (most recent call last)
<ipython-input-48-bb00513b4b05> in <module>
----> 1 learn.save_encoder('lang_model_enc')

/opt/conda/lib/python3.7/site-packages/fastai/text/ in save_encoder(self, file)
     93         "Save the encoder to `file` in the model directory"
     94         if rank_distrib(): return # don't save if child proc
---> 95         encoder = get_model(self.model)[0]
     96         if hasattr(encoder, 'module'): encoder = encoder.module
     97, join_path_file(file, self.path/self.model_dir, ext='.pth'))

TypeError: 'LMModel7' object is not subscriptable

Doing type(LMModel7) gives fastcore.meta.PrePostInitMeta.

Could someone please help me out?
How do I get rid of this and save the encoder so that I can load and use it later in some other file?

1 Like

I have same problem.
Perhaps… Should model be an instance of DistributedDataParallel?

get_model method

def get_model(model):
“Return the model maybe wrapped inside model.”
return model.module if isinstance(model, (DistributedDataParallel, nn.DataParallel)) else model