DeepSpeech Architecture Help Requested

I am currently trying to implement DeepSpeech using fastai, and I think I am pretty close except for one slight problem: it isn’t training. So I want to get some help from the community to hopefully figure out my issue. The part I’m most unsure about is the recurrent layer, where I have used an LSTM, but I would probably prefer a plain RNN since the paper specifically calls one out because of its lower memory requirements.

Here is a link to the paper: http://arxiv.org/abs/1412.5567

Here is a link to my github repo: https://github.com/kevinbird15/DeepSpeech-Fastai

A few things I’ve learned going through this: you feed spectrograms into the model and it spits out a sequence of characters. More specifically, you slice the spectrogram vertically into frames and feed each frame in along with something the paper calls context, or C, which is just a number of frames before and after the frame being predicted. Each of these small slices of the spectrogram goes into the model, which outputs a prediction of the character being spoken. That prediction sequence is then put into the CTC loss function and a loss is determined.
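
To make that concrete, here is a rough sketch of how those context windows could be built from a spectrogram. The shapes, the add_context name, and the 128 feature bins / context=5 values are assumptions I chose to line up with the first linear layer in the architecture below, not code from my repo:

import torch
import torch.nn.functional as F

def add_context(spec, context=5):
    # spec: (n_mels, n_frames) spectrogram -> (n_frames, n_mels * (2*context + 1)),
    # where each row is one frame plus `context` frames on either side of it
    n_mels, n_frames = spec.shape
    padded = F.pad(spec, (context, context))  # zero-pad so edge frames still get full context
    windows = [padded[:, t:t + 2*context + 1].reshape(-1) for t in range(n_frames)]
    return torch.stack(windows)

spec = torch.randn(128, 166)      # fake 128-bin spectrogram with 166 frames
x = add_context(spec)
print(x.shape)                    # torch.Size([166, 1408]) -> matches 128*2*5+128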

The current issue/question I’m having is converting the model from the paper into actual code, and that’s where I’m hoping the fastai community can help me!

This is what I have as my deepspeech architecture:

class DeepSpeech(nn.Module):
    def __init__(self, context=5, bs=64):
        super(DeepSpeech, self).__init__()
        self.bs = bs
        self.context = context
        self.h = None  # LSTM hidden state carried between batches
        self.flatten = nn.Flatten()
        self.relu = nn.ReLU()
        # one frame of 128 features plus `context` frames on each side
        self.h1 = nn.Linear(128*2*self.context+128, 2048)
        self.h2 = nn.Linear(2048, 2048)
        self.h3 = nn.Linear(2048, 2048)
        self.h4 = nn.LSTM(2048, 2048, bidirectional=False, batch_first=True)
        #self.h4 = nn.RNN(2048, 2048, nonlinearity="relu", bidirectional=True)
        self.h5 = nn.Linear(2048, 29)  # ct ∈ {a,b,c, . . . , z, space, apostrophe, blank}
        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, x):
        x = nn.Flatten(-2, -1)(x)  # merge the feature and context dimensions
        x = self.h1(x)  # paper uses a clipped ReLU here: .clamp(min=0, max=20)
        x = self.relu(x)
        x = self.h2(x)
        x = self.relu(x)
        x = self.h3(x)
        x = self.relu(x)
        # resize the carried hidden state when the batch size changes (e.g. the last batch)
        if self.h is None:
            pass
        elif self.h[0].shape[1] > x.size(0):
            self.h = tuple([each[:, :x.size(0), :] for each in self.h])
        elif self.h[0].shape[1] < x.size(0):
            self.h = None
        x, h = self.h4(x, self.h)
        self.h = to_detach(h, cpu=False)  # keep the state but cut the gradient history
        x = x.view(-1, 166, 1, 2048)  # hard-coded sequence length of 166
        x = x.sum(dim=2)
        x = self.h5(x)
        x = self.softmax(x)
        return x

The part I would really like somebody to tear apart and crush is the LSTM (self.h4), which should be bidirectional=True to match the paper, but I figured this was at least a starting point!

Basically, if I don’t do a bunch of weird stuff with the hidden state (self.h), I get shape mismatches whenever the last batch isn’t completely full, and that’s where my hacky work-around comes from. I would also ideally like to get back to .clamp(min=0, max=20) to match the paper’s clipped ReLU while keeping the code readable.
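
For the clipped ReLU part, one option I’m considering is wrapping the clamp in a tiny module so the forward stays readable. Just a sketch of the idea, not something that’s in the repo yet:

import torch.nn as nn

class ClippedReLU(nn.Module):
    "min(max(x, 0), 20), the clipped activation used in the DeepSpeech paper"
    def __init__(self, max_val=20.0):
        super().__init__()
        self.max_val = max_val

    def forward(self, x):
        return x.clamp(min=0, max=self.max_val)

# then in __init__: self.relu = ClippedReLU(), and the forward doesn't change

nn.Hardtanh(min_val=0., max_val=20.) would do the same thing if a built-in is preferable.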

My CTC loss has been wrong a lot of times already, but I feel ok(ish) about it at this point thanks to a ton of help from @scart97. If there are other questions or if anybody wants to see anything else that I’ve done, let me know.
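
For anyone else fighting with the shapes, this is roughly the bookkeeping nn.CTCLoss expects. The batch size, target length, and blank index here are made-up numbers for illustration, not what’s in my repo:

import torch
import torch.nn as nn

T, N, C = 166, 4, 29                               # frames per clip, batch size, characters
ctc = nn.CTCLoss(blank=28, zero_infinity=True)     # assuming blank is the last of the 29 classes

log_probs = torch.randn(T, N, C).log_softmax(2)    # CTCLoss wants (T, N, C): time first, so a
                                                   # batch_first model output needs .permute(1, 0, 2)
targets = torch.randint(0, 28, (N, 20), dtype=torch.long)   # character indices, no blanks
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), 20, dtype=torch.long)

loss = ctc(log_probs, targets, input_lengths, target_lengths)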

Is there a reason in particular why you want to keep the LSTM hidden states from one batch to the other?

That’s how I thought they were supposed to work, but no, I don’t have any specific attachment to carrying the hidden state over from the previous batch. Maybe I will try dropping that piece and just always give it None so it initializes to 0?
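
Concretely, I think that would just turn the LSTM call into something like this (a sketch of what I mean, not tested yet):

# hidden state defaults to zeros on every call; nothing is carried between batches
x, _ = self.h4(x)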

It kind of depends on the task, but usually you don’t want to carry your hidden states from one batch to the next unless they’re somehow related. Think of it like this:

If your batches contain the audio for different sentences, like this:

[
“Remind me to take out the spaghetti in 10 minutes”,
“When was Albert Einstein born?”
]

Then the state from the first sentence doesn’t really help the network understand the second one.
There are specific use cases where a network is used to model very long sequences (like XLNet) that don’t fit in memory, so they’re split up time-wise and the hidden states from the previous snippet are used to help with context:

[
“Albert Einstein was a German-born theoretical physicist who developed the theory of relativity…”,
“Near the beginning of his career, Einstein thought that Newtonian mechanics was no longer enough…”
]

To which of these two is your task more similar? If you’re not sure, it’s probably the former.

Yeah, definitely the former. That makes sense. I am still not getting great results, but I understand what you’re saying with the hidden state stuff.

So if it were all a single long audio file, then you would want the hidden state to be kept between batches.

And even then you only want to keep the state for specific network architectures, usually the ones that aim at modeling very long-range dependencies.
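
For completeness, keeping the state for a single long recording split into chunks would look roughly like this (lstm and chunks are placeholder names):

h = None
for chunk in chunks:                    # consecutive slices of the same audio file
    out, h = lstm(chunk, h)             # reuse the state from the previous chunk
    h = tuple(t.detach() for t in h)    # detach so you don't backprop through old chunks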

And another tip: I understand that for now you’re just trying to reproduce that paper, but it might be interesting to try applying the transformer architecture to this problem, as it has made RNNs obsolete almost everywhere in the past couple of years.

Happy coding!

Thanks, I am definitely planning on using this as a jumping-off point for a lot of experiments in that direction, Alex! Most of this will become a lot easier to do with v2 once we get fastai_audio added there as well, but I at least want to have something to build the v2 piece of audio ASR from. That’s why I’m trying to get this working at the moment: so that I can convert it to v2 and hopefully get it working in that context too. Being able to do ASR out of the box with v2 would be pretty great and would make experiments like the transformer idea you mentioned a lot more feasible.