Flatten layer of PyTorch

I am trying to build a simple classification network in PyTorch, but I do not know how to flatten the output of the convolution layers.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        main = nn.Sequential()
        self._conv_block(main, 'conv_0', 3, 6, 5)
        main.add_module('max_pool_0_2_2', nn.MaxPool2d(2,2))
        self._conv_block(main, 'conv_1', 6, 16, 3)
        main.add_module('max_pool_1_2_2', nn.MaxPool2d(2,2)) 
        #how could I flatten the convolution layer?
        
        self._main = main
        
    def forward(self, x):
        return x
        
    def _conv_block(main, name, inp_filter_size, out_filter_size, kernal_size):        
        main.add_module('{}-{}.{}.conv'.format(name, inp_filter_size, out_filter_size), 
                        nn.Conv2d(inp_filter_size, out_filter_size, kernal_size, 1, 1))
        main.add_module('{}-{}.batchnorm'.format(name, out_filter_size), nn.BatchNorm2d(out_filter_size))
        main.add_module('{}-{}.relu'.format(name, out_filter_size), nn.ReLU())

Thanks

You can use something like this:

x.view(-1, flattened_feature_size)  # example: 32*32*64

You can check for a full example here.

Thanks, I have seen this before, but I cannot apply this solution directly to my model, because my model is built with a sequential container.

# Snippet from the question: the same Sequential assembly as above, with a
# placeholder 'flatten' entry whose implementation is what is being asked for.
main = nn.Sequential()
self._conv_block(main, 'conv_0', 3, 6, 5)
main.add_module('max_pool_0_2_2', nn.MaxPool2d(2,2))
self._conv_block(main, 'conv_1', 6, 16, 3)
main.add_module('max_pool_1_2_2', nn.MaxPool2d(2,2)) 
main.add_module('flatten', make_it_flatten)  # placeholder: needs an nn.Module that flattens

my problem is: what should I put in “make_it_flatten”?

Can you share the code in the form of a gist, so that it is easier to help? I am not sure if PyTorch has something like flatten. Even if you have a sequential container, you can use view to change the shape of the tensor outside the container. All Sequential is doing here is applying a bunch of computations to the input tensor and generating an output tensor.

No problem, I will give another solution a shot.

class Flatten(nn.Module):
    """Collapse every non-batch dimension into one: (N, *) -> (N, prod(*))."""

    def forward(self, x):
        batch_size = x.size(0)
        return x.view(batch_size, -1)

class Net(nn.Module):
    """10-class classifier: two conv/pool stages, a Flatten module, then
    three linear layers (120 -> 84 -> 10).

    Assumes 32x32 input images (e.g. CIFAR-10) — TODO confirm with caller.
    After conv(k=5, pad=1) -> pool(2) -> conv(k=3, pad=1) -> pool(2), a
    32x32 input yields a 16 x 7 x 7 feature map, i.e. 784 features per
    sample. This matches the "m1: [4 x 784]" in the reported size-mismatch
    error, so the first linear layer must take 16*7*7 inputs, not 16*3*3.
    """

    def __init__(self):
        super(Net, self).__init__()

        main = nn.Sequential()

        self._conv_block(main, 'conv_0', 3, 6, 5)
        main.add_module('max_pool_0_2_2', nn.MaxPool2d(2, 2))

        self._conv_block(main, 'conv_1', 6, 16, 3)
        main.add_module('max_pool_1_2_2', nn.MaxPool2d(2, 2))

        main.add_module('flatten', Flatten())
        # Fixed: the flattened feature count is 16*7*7 (= 784), not 16*3*3.
        self._linear_block(main, 'linear_0', 16*7*7, 120)
        self._linear_block(main, 'linear_1', 120, 84)
        main.add_module('linear_2-84-10.linear', nn.Linear(84, 10))

        self._main = main

    def forward(self, x):
        # Apply every registered child module in order (only _main here).
        for module in self._modules.values():
            x = module(x)
        return x

    def _conv_block(self, main, name, inp_filter_size, out_filter_size, kernal_size):
        # Conv -> BatchNorm -> ReLU block with descriptive module names.
        main.add_module('{}-{}.{}.conv'.format(name, inp_filter_size, out_filter_size),
                        nn.Conv2d(inp_filter_size, out_filter_size, kernal_size, 1, 1))
        main.add_module('{}-{}.batchnorm'.format(name, out_filter_size), nn.BatchNorm2d(out_filter_size))
        main.add_module('{}-{}.relu'.format(name, out_filter_size), nn.ReLU())

    def _linear_block(self, main, name, inp_filter_size, out_filter_size):
        # Linear -> ReLU block.
        main.add_module('{}-{}.{}.linear'.format(name, inp_filter_size, out_filter_size),
                        nn.Linear(inp_filter_size, out_filter_size))
        main.add_module('{}-{}'.format(name, out_filter_size), nn.ReLU())

This gives me the error message:

RuntimeError: size mismatch, m1: [4 x 784], m2: [144 x 120] at /b/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:1237

If I change Flatten to

class Flatten(nn.Module):
    """Flatten variant with a hard-coded feature size.

    NOTE(review): this is the *failing* variant quoted in the thread. It
    hard-codes the flattened width as 16*3*3 = 144, but the conv stack
    actually produces 784 features per sample (a batch of 4 gives the 3136
    elements named in the error below), and 3136 is not divisible by 144, so
    the view is invalid.
    """
    def forward(self, x):        
        x = x.view(-1, 16*3*3)
        return x

It gives me the error message:

RuntimeError: size ‘[-1 x 144]’ is invalid for input with 3136 elements at /b/wheel/pytorch-src/torch/lib/TH/THStorage.c:55

All of the code is placed at pastebin, 90 lines.

I found the answer on the PyTorch forum: my input size was incorrect. I hope that in the future PyTorch can calculate the input size automatically.

1 Like