Beginner Q - Simple net vs cnn_learner for RGB images

I previously went through the homework for lesson 4 and built my own learner for the MNIST problem. Yesterday I made another version to tackle RGB images, classifying each image as either a black bear or a grizzly bear. I used the same MNIST loss function, sigmoid, and simple_net we built in chapter 4.

My original bear predictor (which used a third category) had an error rate of 0.05 on the same data. With my simple version, batch accuracy peaks at 70%.

So, my question is: what factors make cnn_learner significantly better than a simple network? I’ve tinkered with the learning rate and the number of parameters. I ran into problems trying to add more layers to simple_net, but I assume more layers would help. I’m not surprised that my model is significantly worse, but I realized it’s a big gap in my knowledge that I can’t explain exactly why.

The way I see it, the two main factors are the number of parameters and the depth of the network.
If you unroll an image as you would for a simple net, the number of parameters for a 20x20 image would be 400 × (hidden layer size) in just the first layer (assuming grayscale). If you use a convolutional filter instead, the number of parameters is only 25 (for a 5x5 filter). As you can see, the parameter count per layer is drastically lower for CNNs, so you can build deeper nets with fewer parameters. That’s one reason.
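As a rough illustration of that arithmetic (the sizes here are the same hypothetical ones as above, a 20x20 grayscale image and a made-up hidden-layer size of 30):

```python
# Hypothetical sizes: a 20x20 grayscale image and a hidden layer of 30 units
image_pixels = 20 * 20            # 400 inputs once the image is flattened
hidden_units = 30

# Fully connected first layer: every pixel connects to every hidden unit
fc_weights = image_pixels * hidden_units      # 12,000 weights (biases excluded)

# Convolutional first layer: one 5x5 filter is shared across the whole image
conv_weights = 5 * 5                          # 25 weights (biases excluded)

print(fc_weights, conv_weights)   # 12000 25
```

And the gap only widens with bigger images: the fully connected count scales with the number of pixels, while the filter stays 25 weights no matter the image size.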

The other reason I can think of is data locality, at least for images. Pixels next to one another are treated as neighbours and processed by the same filter in CNNs. However, when you flatten the image into a single array of pixels, neighbouring pixels are no longer close to one another. The network could still learn that some pixels should be treated as neighbours, but it takes significantly more effort to do so.
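A quick way to see the locality point with plain index arithmetic (nothing fastai-specific here): after flattening, a pixel’s horizontal neighbour stays next to it, but its vertical neighbour lands a full row-width away in the 1-D array:

```python
def flat_index(row, col, width=20):
    """Position of pixel (row, col) after flattening a `width`-pixel-wide image."""
    return row * width + col

# Horizontally adjacent pixels stay adjacent...
print(flat_index(2, 6) - flat_index(2, 5))   # 1
# ...but vertically adjacent pixels end up `width` positions apart
print(flat_index(3, 5) - flat_index(2, 5))   # 20
```

A fully connected layer has no built-in notion that index 45 and index 65 are the same column of neighbouring rows; a convolution gets that relationship for free from its 2-D filter.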

I’m still a beginner so my answer should be taken with a grain of salt, but these are the intuitive explanations I can think of. There are many other factors that go into it I’m sure.


The bear classifier in chapter 2 was built using a pretrained model, which means it already had a lot of “knowledge” about images in general. It was trained on the ImageNet dataset, which consists of a huge number of images. So you only needed to train the “head” (the last layers) of the model from scratch; the whole “body” (the earlier layers) was already very good at detecting image features. As for your simple net, you can improve it by adding some tweaks, which you’ll learn later in the book, and even further by using some unsupervised pretraining. But the pretrained model has still seen a lot of data and learned a lot about images from it. That’s what makes it superior, and that’s why transfer learning works so well.
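The freeze-the-body, train-the-head idea can be sketched in plain PyTorch (this is a toy illustration, not fastai’s actual cnn_learner internals; the layer sizes are made up):

```python
import torch.nn as nn

# A toy "pretrained" body plus a fresh head (sizes are illustrative only)
body = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten())
head = nn.Linear(8 * 26 * 26, 2)   # new layer, trained from scratch, 2 bear classes

# Freeze the body so the optimizer only updates the head's parameters
for p in body.parameters():
    p.requires_grad_(False)

trainable = [p for p in head.parameters() if p.requires_grad]
```

In fastai, `fine_tune` does roughly this for you: it trains the head with the body frozen first, then unfreezes everything for further training.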


I would imagine that if I ran my model through enough epochs it could eventually make up for the lack of pretraining, but it capped at 0.7031 accuracy, and no matter how much I trained it, the model could not pass that value. Is there any specific tweak that is likely to bridge the majority of that ~25% accuracy gap between my simple model and the fastai version?

Maybe not all the way, but you can always try deeper or wider networks and more data. You might also have to experiment with regularization. Generally, though, the only things that can make a network better are architecture, time spent training, and data.

If your architecture is too simple, there’s only so much you can do.


There is an important point to notice here. When training a network, you need to monitor not only its validation metrics (accuracy, in your case) but also the training and validation loss. You can also check accuracy on the training set by calling learn.validate(ds_idx=0).
If the model performs poorly on the training set, you might have reached the model’s capacity, and making the model deeper or trying a better architecture should help.
If your model performs notably better on the training set than on the validation set, that is a sign of overfitting: your model is learning features of the training set that do not generalize to new data. In this case you should try adding more data, using data augmentation techniques, and trying regularization.
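That decision rule can be summed up in a few lines (the thresholds below are arbitrary cut-offs for illustration, not fastai defaults):

```python
def diagnose(train_acc, valid_acc, target_acc=0.9, gap=0.05):
    """Rough rule of thumb for reading train vs. validation accuracy.
    target_acc and gap are arbitrary, illustrative thresholds."""
    if train_acc < target_acc:
        return "underfitting: try a deeper/wider model or a better architecture"
    if train_acc - valid_acc > gap:
        return "overfitting: try more data, augmentation, or regularization"
    return "looks ok: keep training or tune the learning rate"

# A model stuck at ~70% on its own training set is hitting a capacity limit
print(diagnose(0.70, 0.70))   # underfitting
```

In the case above, accuracy plateauing around 0.70 on the training data itself would point to the first branch: the simple net has run out of capacity.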
You will learn more about this stuff as you progress through the course and the book. There is a lot of interesting stuff ahead!
