Training 4.5 million parameters takes 30 mins per epoch?

Hi, I’m still quite new to deep learning. I’m trying to reproduce a paper on light fields: Stereo Magnification: Learning View Synthesis using Multiplane Images. The paper has code available, but it’s in TensorFlow, so I thought this would be a good project since I’m also interested in light fields/view synthesis.

This is my COLAB.

I’ve been quite successful, in the sense that I’ve got a Colab notebook that trains the model, and it seems to work to a certain degree.

The problem is that each epoch takes some 30 minutes to train, and I wonder if you have some ideas on how to improve the training speed (maybe a smaller model?). The model is just as defined in the paper, and it turns out to have 4.5 million parameters, with 18 convolution layers.

I’ve actually tried creating a smaller model with just 7 layers and 1.5 million parameters, but it’s still just as slow. Furthermore, for some reason I’m only able to use a batch size of 1, as bigger batch sizes trigger out-of-memory errors.

Notice that, effectively, the input of the model is 11 images and the output is 10 images: the Multi-Plane Image, which can be used to synthesize new images from different points of view simply using alpha compositing. However, the model is only about 18 MB, and 10 to 11 images at a resolution of 220x220 isn’t that much data, or is it?

Question: is your data in your Google Drive?

I’m using Google Drive just to load/save the model, so that I can keep training it if I disconnect from and reconnect to a runtime.

The Spaces dataset (10 GB) I’m using lives on GitHub, and it’s cloned into Colab (into the default /content folder) by the first command.

I’m defining the model with skip connections using PyTorch and fastai’s ConvLayer. Am I doing it wrong? This is the code:

import torch
from torch import nn
from fastai.vision.all import ConvLayer, Module  # fastai's Module calls nn.Module.__init__ for us

class StereoMagnificationModel(Module):
  def __init__(self, ngf=33):
    self.ngf = ngf
    self.cnv1_1 = ConvLayer(ngf,ngf, ks=3, stride=1)                                  # 224
    self.cnv1_2 = ConvLayer(ngf,ngf*2, ks=3, stride=2)                                # 112

    self.cnv2_1 = ConvLayer(ngf*2,ngf*2, ks=3, stride=1)                              # 112
    self.cnv2_2 = ConvLayer(ngf*2,ngf*4, ks=3, stride=2)                              # 56

    self.cnv3_1 = ConvLayer(ngf*4,ngf*4, ks=3, stride=1)                              # 56
    self.cnv3_2 = ConvLayer(ngf*4,ngf*4, ks=3, stride=1)                              # 56
    self.cnv3_3 = ConvLayer(ngf*4,ngf*8, ks=3, stride=2)                              # 28

    self.cnv4_1 = ConvLayer(ngf*8,ngf*8, ks=3, stride=1, dilation=2, padding=2)       # 28
    self.cnv4_2 = ConvLayer(ngf*8,ngf*8, ks=3, stride=1, dilation=2, padding=2)       # 28
    self.cnv4_3 = ConvLayer(ngf*8,ngf*8, ks=3, stride=1, dilation=2, padding=2)       # 28

    self.cnv5_1 = ConvLayer(ngf*16,ngf*4, ks=4, stride=2, transpose=True, padding=1)  # 56
    self.cnv5_2 = ConvLayer(ngf*4,ngf*4, ks=3, stride=1)                              # 56
    self.cnv5_3 = ConvLayer(ngf*4,ngf*4, ks=3, stride=1)                              # 56

    self.cnv6_1 = ConvLayer(ngf*8,ngf*2, ks=4, stride=2, transpose=True, padding=1)   # 112
    self.cnv6_2 = ConvLayer(ngf*2,ngf*2, ks=3, stride=1)                              # 112

    self.cnv7_1 = ConvLayer(ngf*4,ngf, ks=4, stride=2, transpose=True, padding=1)     # 224
    self.cnv7_2 = ConvLayer(ngf,ngf, ks=3, stride=1)                                  # 224

    self.cnv8_1 = ConvLayer(ngf,ngf, ks=1, stride=1, norm_type=None, act_cls=nn.Tanh) # 224

  def forward(self, x):
    out_cnv1_1 = self.cnv1_1(x)
    out_cnv1_2 = self.cnv1_2(out_cnv1_1)

    out_cnv2_1 = self.cnv2_1(out_cnv1_2)
    out_cnv2_2 = self.cnv2_2(out_cnv2_1)

    out_cnv3_1 = self.cnv3_1(out_cnv2_2)
    out_cnv3_2 = self.cnv3_2(out_cnv3_1)
    out_cnv3_3 = self.cnv3_3(out_cnv3_2)

    out_cnv4_1 = self.cnv4_1(out_cnv3_3)
    out_cnv4_2 = self.cnv4_2(out_cnv4_1)
    out_cnv4_3 = self.cnv4_3(out_cnv4_2)

    # add skip connection
    in_cnv5_1 = torch.cat([out_cnv4_3, out_cnv3_3],1)

    out_cnv5_1 = self.cnv5_1(in_cnv5_1)
    out_cnv5_2 = self.cnv5_2(out_cnv5_1)
    out_cnv5_3 = self.cnv5_3(out_cnv5_2)

    # add skip connection
    in_cnv6_1 = torch.cat([out_cnv5_3, out_cnv2_2],1)

    out_cnv6_1 = self.cnv6_1(in_cnv6_1)
    out_cnv6_2 = self.cnv6_2(out_cnv6_1)

    # add skip connection
    in_cnv7_1 = torch.cat([out_cnv6_2, out_cnv1_2],1)

    out_cnv7_1 = self.cnv7_1(in_cnv7_1)
    out_cnv7_2 = self.cnv7_2(out_cnv7_1)

    out_cnv8_1 = self.cnv8_1(out_cnv7_2)

    return out_cnv8_1
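
For what it’s worth, a quick way to sanity-check the parameter count (just a sketch, assuming the fastai imports above):

model = StereoMagnificationModel(ngf=33)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")  # lands around 4.5M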

Your model seems fine. To speed things up, you should copy your dataset from Drive to the local Colab machine, as accessing files over the network (your Drive folder) tends to be slow.

Also, you can use mixed-precision training (`learn.to_fp16()` or `learn.to_native_fp16()`) and then increase your batch size.
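
Something like this, as a rough sketch (`dls`, `model`, and `my_loss` are placeholders for what you already have in your notebook):

from fastai.vision.all import Learner

learn = Learner(dls, model, loss_func=my_loss)
learn = learn.to_fp16()  # or learn.to_native_fp16(), depending on your fastai version
learn.fit_one_cycle(1, lr_max=2e-4)

With fp16 activations the intermediate feature maps take roughly half the memory, which is what lets you push the batch size up.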

Thanks Victor, using learn = learn.to_native_fp16() does help, and I’m now able to use at least a batch size of 2, so it’s training almost twice as fast.

My dataset is already on the local Colab machine’s hard drive. As I said, I’m only using Google Drive to store the model after training, not the dataset.

So I’m still wondering why it’s so slow, and whether this is normal at all. Is it really that computationally expensive? Maybe it’s the way I’m loading/processing the dataset, or the loss function? What’s a normal epoch training time on Colab (or a similarly sized machine) for a model with 18 layers, 33 input channels (11 RGB images) and 4.5 million parameters?

Parameter count != FLOPs (how many computations you need). Your model looks like an encoder-decoder architecture with the same output size as the input, so it’s definitely memory-intensive.
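
To make that concrete, here is a rough, back-of-the-envelope FLOP count for a single convolution (a sketch; the factor of 2 counts each multiply-add as two operations):

def conv2d_flops(c_in, c_out, k, h_out, w_out):
    # each output element needs c_in * k * k multiply-adds
    return 2 * c_in * c_out * k * k * h_out * w_out

# e.g. your first 3x3 conv at 224x224 with 33 channels in and out:
print(conv2d_flops(33, 33, 3, 224, 224) / 1e9)  # ~1 GFLOP from only ~10k parameters

Because your decoder runs back up to full resolution, the FLOPs (and the activation memory) stay high even though the parameter count is modest.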

You can:

  • Reduce the image resolution.
  • Use in-place activations: ConvLayer uses ReLU() instead of ReLU(inplace=True) by default (see the sketch after this list).
  • Convert the model to TorchScript (torch.jit.script(StereoMagnificationModel())). Sometimes it’s more efficient and trains faster.
  • Check your post-processing. Profile memory usage during the backward pass; if it’s high, try to use in-place operations where possible. Also:
    • mpi_from_net_output seems complicated. It should also create tensors directly on the right device.
    • mpi_render_view_torch -> I don’t see its definition.
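
Here is a sketch of the in-place activation, TorchScript and memory-profiling points, assuming fastai’s ConvLayer (model, my_loss, xb and yb are placeholders for your own model, loss and one batch of data):

from functools import partial
import torch
from torch import nn
from fastai.vision.all import ConvLayer

# in-place ReLU avoids allocating an extra activation tensor per layer
inplace_relu = partial(nn.ReLU, inplace=True)
conv = ConvLayer(33, 33, ks=3, stride=1, act_cls=inplace_relu)

# optionally script the whole model; sometimes the fused graph runs faster
scripted = torch.jit.script(StereoMagnificationModel())

# rough peak-memory check around one forward/backward pass
torch.cuda.reset_peak_memory_stats()
loss = my_loss(model(xb), yb)
loss.backward()
print(torch.cuda.max_memory_allocated() / 2**30, "GiB peak")  # memory_summary() gives more detail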

If you are willing to change the model:

  • Remove some batch norm layers; they are very memory-intensive, and batch norm with a batch size < 12 usually doesn’t work well anyway. Try replacing them with another normalization layer, such as instance normalization (see the sketch below this list).
  • Reduce the model’s channel counts.
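
For the normalization swap, fastai’s ConvLayer exposes this through norm_type, so only the layer definitions need to change (a sketch):

from fastai.vision.all import ConvLayer, NormType

# inside __init__, each layer becomes e.g.
# self.cnv1_1 = ConvLayer(ngf, ngf, ks=3, stride=1, norm_type=NormType.Instance)
conv = ConvLayer(33, 33, ks=3, stride=1, norm_type=NormType.Instance)

Instance norm statistics are computed per sample, so they don’t degrade at batch size 1 the way batch norm statistics do.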

Finally, use a GPU with more memory. From the paper: “Training takes about one week on a Tesla P100 GPU.” Those have 12 or 16 GB. They also said: “We train the network using the ADAM solver [Kingma and Ba 2014] for 600K iterations with learning rate 0.0002, β1 = 0.9, β2 = 0.999, and batch size 1.” So I doubt they used batch norm layers.

@vferrer Thanks for the help! I’ve tried instance normalization and now I can do an epoch in about 15 minutes, which is a great improvement indeed! I’ll refactor the code to create tensors directly on the right device and try to profile memory usage (I don’t know how to do that yet).

In the meantime I’ve been adding a perceptual loss using VGG19, as in the paper; I’ll share my progress soon (right now I’m fighting Python’s pickle protocol).

On the post-processing side: mpi_render_view_torch is defined in my GitHub repo [https://github.com/Findeton/mpi_vision](https://github.com/Findeton/mpi_vision/blob/main/utils.py#L267). There’s quite a bit of code there, but it’s just a translation of the original code to PyTorch and I’ve already tested it a bit. There’s definitely room to optimise it, but the post-processing is necessary: the loss function needs to compare two images, while the model’s output is a stack of images (a Multi-Plane Image), which we then use to create novel views from different points of view.
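
To give an idea of what the compositing does, the core is just repeated back-to-front “over” blending of the RGBA planes. This is only a simplified sketch (the real mpi_render_view_torch also warps each plane into the target view first), and it assumes the planes are stacked along the first axis, ordered from back to front:

import torch

def over_composite(rgba_layers):
    # rgba_layers: (num_planes, batch, height, width, 4), back plane first
    out = torch.zeros_like(rgba_layers[0, ..., :3])
    for layer in rgba_layers:
        rgb, alpha = layer[..., :3], layer[..., 3:4]
        out = rgb * alpha + out * (1.0 - alpha)
    return out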

I tried your code.
The dataloader is fast, so that is not the problem.
I could run up to bs=8 without issues.
It’s probably the cloud instance that is slow.

So right now I’m having trouble getting the perceptual loss function (using VGG16) to work. It’s supposed to be easy-ish, but when I train the model the validation loss stays exactly the same, so it looks like it’s not learning at all. I’ve basically based the code on this public gist. Specifically, my code for the loss is this:

import torch
import torchvision

# mpi_from_net_output and mpi_render_view_torch come from the repo's utils.py;
# `device` is assumed to be defined earlier in the notebook (e.g. torch.device('cuda'))

class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self, resize=True):
        super().__init__()
        # slice a single pretrained VGG16 into the four feature blocks used for the loss
        vgg_features = torchvision.models.vgg16(pretrained=True).features
        blocks = [vgg_features[:4].eval(),
                  vgg_features[4:9].eval(),
                  vgg_features[9:16].eval(),
                  vgg_features[16:23].eval()]
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False

        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate

        # ImageNet normalization constants expected by VGG
        self.mean_const = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
        self.std_const = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

        self.resize = resize

    def forward(self, mpi_pred, dep):
        # turn the network output into RGBA planes and render the target viewpoint
        rgba_layers = mpi_from_net_output(mpi_pred, dep)
        rel_pose = torch.matmul(dep['tgt_img_cfw'], dep['ref_img_wfc'])

        pred = mpi_render_view_torch(rgba_layers, rel_pose, dep['mpi_planes'][0], dep['intrinsics'])
        target = dep['tgt_img']

        # NHWC -> NCHW, then normalize for VGG
        pred = pred.permute(0, 3, 1, 2)
        target = target.permute(0, 3, 1, 2)

        pred = (pred - self.mean_const) / self.std_const
        target = (target - self.mean_const) / self.std_const
        if self.resize:
            pred = self.transform(pred, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)

        # L1 on pixels plus L1 on each VGG feature block
        x, y = pred, target
        loss = torch.nn.functional.l1_loss(x, y)
        for block in self.blocks:
            x = block(x)
            y = block(y)
            loss += torch.nn.functional.l1_loss(x, y)
        return loss

And I have another COLAB notebook if you want to try training it (to no avail so far). Any idea what the source of the problem might be?

I don’t know. Try a simpler loss and see if it works; VGGPerceptualLoss is quite complicated.
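
For example, you could keep your rendering step but drop the VGG part entirely, as a sanity check (a sketch reusing the helpers from your notebook; if this trains, the problem is in the VGG branch, and if not, it’s upstream in the rendering):

import torch

class RenderL1Loss(torch.nn.Module):
    # hypothetical sanity-check loss: render the predicted MPI and compare it to
    # the target view with plain L1, no VGG features involved
    def forward(self, mpi_pred, dep):
        rgba_layers = mpi_from_net_output(mpi_pred, dep)
        rel_pose = torch.matmul(dep['tgt_img_cfw'], dep['ref_img_wfc'])
        pred = mpi_render_view_torch(rgba_layers, rel_pose,
                                     dep['mpi_planes'][0], dep['intrinsics'])
        return torch.nn.functional.l1_loss(pred, dep['tgt_img'])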