Segmentation: Batch Size, Training-Set Size, and OOM Errors

Hi everyone,

I’m developing a fully convolutional network (FCN) for a segmentation problem on satellite imagery. So far I’ve mainly experimented with variations of the Tiramisu architecture, and because of the memory problems described below I’m starting to look more closely at U-Net.

The Tiramisu paper uses only a few hundred images in its training set, and the DSTL Kaggle competition provided only 25 training images. I have about 46,000.

My Tiramisu-inspired networks give decent preliminary results on a small handful of images, but I get out-of-memory (OOM) errors with batch sizes above about 8, depending on the details of the network. The Tiramisu paper uses batch sizes of 3 and 5. That seems fine for a few hundred images, but I had expected to need a much larger batch size to train on tens of thousands of images.
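A back-of-the-envelope estimate helps show why dense-block networks run out of memory so quickly. The sketch below assumes a naive implementation in which every layer of a dense block receives a freshly materialized concatenation of all earlier feature maps, stored as float32 for the backward pass. The channel counts are illustrative, not taken from the Tiramisu paper, and weights, gradients, and optimizer state are ignored. Under those assumptions the stored channels grow quadratically with the number of layers, and memory grows linearly with batch size and with image area.

```python
# Back-of-the-envelope activation memory for ONE dense block.
# Assumption: a naive implementation that materializes each layer's
# concatenated input as a new float32 tensor kept for backprop.
BYTES_PER_FLOAT32 = 4


def dense_block_activation_bytes(batch, height, width,
                                 in_channels, growth_rate, n_layers):
    """Bytes held by the concatenated inputs of every layer in the block."""
    stored_channels = 0
    for i in range(n_layers):
        # Layer i sees the block input plus the i feature maps produced so far,
        # so the total grows quadratically with n_layers.
        stored_channels += in_channels + i * growth_rate
    return batch * height * width * stored_channels * BYTES_PER_FLOAT32


def max_batch_size(budget_bytes, height, width,
                   in_channels, growth_rate, n_layers):
    """Largest batch whose block activations fit in budget_bytes."""
    per_image = dense_block_activation_bytes(
        1, height, width, in_channels, growth_rate, n_layers)
    return budget_bytes // per_image


# Illustrative numbers: 48 input channels, growth rate 16, 4 layers, 1 GiB budget.
GIB = 2 ** 30
print(max_batch_size(GIB, 256, 256, 48, 16, 4))  # 14
print(max_batch_size(GIB, 128, 128, 48, 16, 4))  # 56
```

Halving the crop side length quarters the per-image cost, so roughly four times as many images fit per batch, which is the trade-off behind the "use smaller images" option. A real network also stores normalization outputs, nonlinearities, and decoder activations, so its true footprint is larger than this estimate.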

My questions:

  1. How do these networks train so well on so little imagery? One thought is that each pixel acts as a separate training example, so passing a single 256x256 image through the network is similar to giving a classification CNN 256x256 (65,536) labeled examples. My worry is that they work so well only because there isn’t much variation in the images they are trained on.

  2. Are there tricks I may be missing that would allow a larger batch size without hitting OOM errors?

  3. Assuming not, I see a few ways forward. Which do people recommend?

    • Pick a subset of a few hundred images from the full 46,000 (by inspection or at random), train on that, and hope the network generalizes to the rest.
    • Train on, say, 30,000 images with a batch size of about 6.
    • Use smaller images (crops), which allows a larger batch size and also multiplies the number of training images available.
    • Look into architectures that are less memory-intensive. (Recommendations?)
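Gradient accumulation is one common way to get a larger effective batch without more memory: run several small micro-batches, let their gradients add up, then take a single optimizer step. The PyTorch sketch below uses a hypothetical helper, `train_step_accumulated` (not from this thread), and assumes a model that maps `(N, 3, H, W)` images to `(N, C, H, W)` logits, with integer masks of shape `(N, H, W)`. The loss is per-pixel cross-entropy, summed and divided by the pixel count of the full batch, so every pixel contributes one classification term. That is the sense in which one image supplies many training examples, although neighbouring pixels are strongly correlated rather than independent.

```python
import torch
import torch.nn.functional as F


def train_step_accumulated(model, optimizer, images, masks, micro_batch):
    """One optimizer step over the whole batch, run micro_batch images at a time.

    Hypothetical helper: images is (N, 3, H, W), masks is (N, H, W) with
    integer class indices, and model returns (N, C, H, W) logits.
    """
    optimizer.zero_grad()
    total_pixels = masks.numel()  # N * H * W for the full batch
    total_loss = 0.0
    for start in range(0, images.shape[0], micro_batch):
        x = images[start:start + micro_batch]
        y = masks[start:start + micro_batch]
        logits = model(x)
        # Per-pixel cross-entropy, normalized by the FULL batch's pixel count,
        # so the accumulated gradients equal those of one large batch.
        loss = F.cross_entropy(logits, y, reduction="sum") / total_pixels
        loss.backward()  # gradients add up in each parameter's .grad
        total_loss += loss.item()
    optimizer.step()
    return total_loss  # mean per-pixel loss over the full batch
```

One caveat for Tiramisu and U-Net variants: BatchNorm computes its statistics per micro-batch, so accumulation reproduces large-batch gradients exactly only for networks without batch-dependent layers.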

Thoughts on any of those three questions, or any general advice, would be greatly appreciated.

All of those may help. Here is some insight into where the memory goes, along with a possible (but not easy) solution. I haven’t tried it myself.


Thanks for the reference. At first glance it looks useful. Digging in now!

PyTorch 0.4 has a new feature that dramatically decreases memory use:

Do you know if this slows down training (for a fixed batch size)?

Apparently it slows training down a little in theory, but doesn’t seem to have much impact in practice.

Oh, great. This may get me off Keras :)
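The link to the PyTorch 0.4 feature did not survive in this thread. One memory-saving addition in that release was gradient checkpointing (`torch.utils.checkpoint`), and the sketch below assumes that is the feature meant. Checkpointing keeps activations only at segment boundaries and recomputes the rest during the backward pass, which is where the theoretical slowdown comes from (at most about one extra forward pass). The code targets a recent PyTorch release (the `use_reentrant` argument did not exist in 0.4), and `make_body` and `CheckpointedNet` are hypothetical names for illustration.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential


def make_body(channels=32, depth=8):
    """Hypothetical stack of conv + ReLU layers standing in for a deep FCN body."""
    layers = []
    for _ in range(depth):
        layers += [nn.Conv2d(channels, channels, kernel_size=3, padding=1), nn.ReLU()]
    return nn.Sequential(*layers)


class CheckpointedNet(nn.Module):
    def __init__(self, body, segments):
        super().__init__()
        self.body = body
        self.segments = segments  # number of chunks whose boundaries are stored

    def forward(self, x):
        if self.training:
            # Only segment-boundary activations are kept; the rest are
            # recomputed during backward, trading compute for memory.
            return checkpoint_sequential(self.body, self.segments, x,
                                         use_reentrant=False)
        return self.body(x)  # no backward pass at inference, so no checkpointing
```

Gradients are the same as without checkpointing; only peak memory and compute change. Layers that update state during the forward pass, such as BatchNorm running statistics, need care because they run again during recomputation.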