CNN + LSTM: How to feed images into LSTM model?

Hi, I am trying to create a model similar to the LSTM RNN from lesson 8 (course v4), but instead of using text input data, I want to feed in a sequence of images. My input data is grayscale images with a batch shape of torch.Size([512, 1, 1, 128]), and I want to classify them (there are only two classes). I'm basically trying to combine a CNN with an LSTM. Here's the lesson 8 code that builds an LSTM from scratch. Any suggestions on how to modify it so that it can process images?

class LMModel6(Module):
    def __init__(self, vocab_sz, n_hidden, n_layers):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)   # token index -> embedding
        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)
        self.h_o = nn.Linear(n_hidden, vocab_sz)      # hidden state -> output logits
        # hidden and cell state; bs is the global batch size from the notebook
        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]
        
    def forward(self, x):
        res, h = self.rnn(self.i_h(x), self.h)
        # truncated BPTT: keep the state values but drop the gradient history
        self.h = [h_.detach() for h_ in h]
        return self.h_o(res)
    
    def reset(self): 
        for h in self.h: h.zero_()

I can see that self.i_h = nn.Embedding(vocab_sz, n_hidden) has to be changed, and I'm guessing a ResBlock should go there instead, but I can't find any answer on how to do it.
Thanks in advance folks!


:confused: This is not the shape of a batch of images.

Are you trying to run a sequence of images through an RNN and classify the sequence?

Or are you trying to treat the pixels of a single image as a sequence and classify that sequence? :slightly_smiling_face:

Well, actually I was wrong, because my pipeline turns the grayscale images (1 channel) into RGB images with 3 channels, which are then normalized with ImageNet stats:

dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   get_x=GetX,
                   get_y=GetY,
                   batch_tfms=Normalize.from_stats(*imagenet_stats))

So my batch shape is actually torch.Size([512, 3, 1, 128]), where 3 is the number of channels. I know the shape looks strange, because every image is 1×128 pixels; it looks like this:

Each small picture is a spectrogram covering 0.01 seconds of an audio file; it is a slice of one long spectrogram (torch.Size([398709, 1, 128])) computed for the whole recording. So answering your question…

The second one is closer to what I'm trying to do :smiley: For each small picture, I'm trying to predict whether it belongs to class 0 or class 1.

This is not a batch of images either.
A batch of 8 greyscale (1-channel) images of size 128×128 would have shape (8, 1, 128, 128), so a sequence of images needs an extra dim to account for the time dimension. In my case it is

(bs, t, ch, h, w)

There are various approaches to this problem. You could encode each image independently into a latent feature space, let's say of size 128, before feeding the recurrent layer; each image then becomes a vector of shape (1, 128) that you can feed to a traditional nn.LSTM layer (not the one above).
Another approach is to modify Jeremy's LSTM to accept a sequence of images, i.e. integrate the convolutional layer into the recurrent part, which gives you a convolutional recurrent layer, a.k.a. a ConvLSTM (see the sketch below).
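
To make the ConvLSTM idea concrete, here is a minimal sketch of a single cell: the usual LSTM gate computations, but with a convolution in place of the matrix multiplies. The class name, channel counts and kernel size here are my own illustrative choices, not anything from the course or the repos below:

import torch
import torch.nn as nn

class ConvLSTMCell(nn.Module):
    "A single ConvLSTM step: LSTM gates computed with a convolution."
    def __init__(self, in_ch, hidden_ch, ks=3):
        super().__init__()
        # one conv produces all four gates (input, forget, cell, output) at once
        self.gates = nn.Conv2d(in_ch + hidden_ch, 4 * hidden_ch, ks, padding=ks // 2)

    def forward(self, x, state):
        h, c = state                      # each of shape (bs, hidden_ch, H, W)
        i, f, g, o = self.gates(torch.cat([x, h], dim=1)).chunk(4, dim=1)
        c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, c

# unroll over the time dimension of a (bs, t, ch, h, w) batch
cell = ConvLSTMCell(1, 16)
x = torch.rand(8, 5, 1, 1, 128)
h = c = torch.zeros(8, 16, 1, 128)
for t in range(x.shape[1]):
    h, c = cell(x[:, t], (h, c))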

I have been playing with this type of problem for a while; check:
https://github.com/tcapelle/action_recognition for a video classification using CNN -> LSTM approach
https://github.com/tcapelle/moving_mnist/ for a CNN -> ConvLSTM/GRU approach (also transformer, PhyCell and seq2seq)

Both repos are ported to fastai V2.


Hi Pavlo,

Thomas’s code gives you (and me) some great examples and starting points. At the same time, I think it’s important that you have a clear concept of what you are trying to do. Please excuse me if I make any wrong assumptions about what you already understand.

RNNs are typically applied to a time dimension. The essence is that what has gone before is the cause of what comes after. The RNN learns the causal relationship in its hidden state and gives you a conclusion (a prediction or a class) after seeing the whole sequence. (You could in theory apply an RNN to the frequency dimension, but there is not much of a cause-and-effect relationship there, and there are more effective ways to extract its features.)

It looks like you want to classify time sequences (length 398709) made of [1, 128] elements, each of which you are calling an image. If this is right, then you are running a sequence of “images” through an RNN to classify the sequence, not treating the pixels of a single image as a sequence.

To get some clarity, stop calling these [1,128] things images. Images are inherently two dimensional, with structure along both dimensions. Instead, think of them as vectors of 128 features arranged in time sequence. Each single element of your training set is therefore 128 features by 398709 time steps. Channels and embeddings are irrelevant here.

As a roadmap, I suggest starting with the very simplest LSTM: it takes the 128 raw features and processes the whole time sequence, in batches. It will probably give you a decent result.
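
To illustrate, a minimal sketch of that simplest model could look like this (the class name, hidden size, and the per-time-step head are my assumptions; adapt them to how your labels are arranged):

import torch
import torch.nn as nn

class SimpleSeqClassifier(nn.Module):
    "LSTM over raw 128-dim feature vectors, one class prediction per time step."
    def __init__(self, n_features=128, n_hidden=256, n_classes=2):
        super().__init__()
        self.lstm = nn.LSTM(n_features, n_hidden, batch_first=True)
        self.head = nn.Linear(n_hidden, n_classes)

    def forward(self, x):        # x: (bs, seq_len, 128)
        out, _ = self.lstm(x)    # out: (bs, seq_len, n_hidden)
        return self.head(out)    # (bs, seq_len, n_classes)

x = torch.rand(4, 100, 128)      # 4 sequences of 100 time steps
SimpleSeqClassifier()(x).shape   # torch.Size([4, 100, 2])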

But if this simple model does not train well enough (and I think this is what you are trying to do), then you could pre-process the elements of the sequences with a CNN. That is, apply a trainable CNN to each element (128 features) of the time sequence to extract features, then feed these extracted features (a time series of feature vectors) to the LSTM. I’m pretty sure you can do all this in batches without loops.
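
As a sketch of that idea (the Conv1d feature extractor and all layer sizes are my own choices), the time steps can be folded into the batch dimension so the CNN runs on every element at once, with no Python loop:

import torch
import torch.nn as nn

class ConvFeatureLSTM(nn.Module):
    "Conv1d features extracted from every 128-feature element, then an LSTM over time."
    def __init__(self, n_hidden=256, n_classes=2):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv1d(1, 16, 5, padding=2), nn.ReLU(),
            nn.Conv1d(16, 32, 5, padding=2), nn.ReLU(),
            nn.AdaptiveAvgPool1d(8),          # -> 32 * 8 = 256 features per step
        )
        self.lstm = nn.LSTM(256, n_hidden, batch_first=True)
        self.head = nn.Linear(n_hidden, n_classes)

    def forward(self, x):                         # x: (bs, seq_len, 128)
        bs, t, f = x.shape
        z = self.features(x.view(bs * t, 1, f))   # CNN on every element at once
        out, _ = self.lstm(z.view(bs, t, -1))     # back to (bs, seq_len, 256)
        return self.head(out)                     # one prediction per time step

x = torch.rand(4, 100, 128)
ConvFeatureLSTM()(x).shape                        # torch.Size([4, 100, 2])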

There are many, many approaches to the problem you describe. Lots have been discussed in these forums: give the entire spectrogram (truly an image!) to a resnet classifier; various combinations of CNNs and RNNs; CNNs along both the frequency and time dimensions; ROCKET; self-attention and transformers. You will have to experiment to discover what works. But I recommend starting with the simplest model and gradually adding complexity based on its performance and your understanding.

Hope this helps you get started. Good luck! :slightly_smiling_face:

Thanks a ton to both of you! @Pomo I have the concept more or less clear in my mind, but I’m very glad you pointed those things out to me; I am indeed very new to the field and know very little about working with time series. An RNN does make sense here because I’m classifying speech, where it’s helpful to know what came before in order to predict what comes next; in fact, this is how speech recognition systems normally work. Taking into account what you both mentioned:

I will indeed try implementing this approach and feeding the raw 128 features as a time series directly into the LSTM. I might have some questions along the way, though, if you would allow me :slight_smile:

I put together a simple example to get you started; you can replace the head of the model depending on the task:

from fastai.vision.all import *

@delegates(create_cnn_model)
class Encoder(Module):
    def __init__(self, arch=resnet34, n_in=3, **kwargs):
        "Encoder based on a resnet"
        model = create_cnn_model(arch, n_out=1, n_in=n_in, pretrained=True, **kwargs)
        self.body = model[0]
        # keep only the AdaptiveConcatPool2d + Flatten layers of the head
        self.pool = nn.Sequential(*(model[1][0:2]))

    def forward(self, x):
        return self.pool(self.body(x))

class CNN_LSTM(Module):
    def __init__(self, encoder, hidden_dim, num_layers=1):
        # *2 because AdaptiveConcatPool2d concatenates average and max pooling
        nf = num_features_model(nn.Sequential(*encoder.body.children())) * 2
        self.encoder = encoder
        # batch_first=True so the LSTM expects input of shape (bs, seq_len, features)
        self.lstm = nn.LSTM(nf, hidden_dim, num_layers=num_layers, batch_first=True)
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        bs, seq_len, c, h, w = x.shape
        x = x.view(bs * seq_len, c, h, w)  # fold time into the batch dimension
        x = self.encoder(x)                # encode every frame independently
        x = x.view(bs, seq_len, -1)        # unfold back to (bs, seq_len, nf)
        x, (h, c) = self.lstm(x)
        return self.head(x)

encoder = Encoder()
model = CNN_LSTM(encoder, 128)

bs, seq_len = 8, 5
x = torch.rand(bs, seq_len, 3, 128, 128)
model(x).shape
>> torch.Size([8, 5, 1])
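
Note that the head above returns one prediction per time step, shape (bs, seq_len, 1). If you want a single label for the whole sequence instead, you could keep only the last step, e.g. self.head(x[:, -1]), or pool over the time dimension.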

This looks great, thank you so much @tcapelle! I’m trying different approaches now and will post what works best for me!