Since the model starts with a Linear layer, the error is most likely caused by the dimensions of your input tensor. Because you are working with images, the last dimension of each sample needs to be of size 128*128.

You can see what the shape of the input tensor is by grabbing a single batch and then looking at a single sample. How you get that batch depends on how it was created. Most likely this should work to get the batch: batch = next(iter(dis.train))

Once you have that, you just need to index into it to find the dimension size of one sample.
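
As a rough sketch of that check: the loader below is a hypothetical stand-in for dis.train, built from random tensors, and it assumes each batch is an (inputs, targets) pair. Your own loader may yield something different, so adjust the unpacking accordingly.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for dis.train: 8 random RGB images of size 128x128
# with dummy integer labels. Replace this with your real loader.
images = torch.randn(8, 3, 128, 128)
labels = torch.zeros(8, dtype=torch.long)
train_loader = DataLoader(TensorDataset(images, labels), batch_size=4)

# Grab one batch; this assumes the loader yields (inputs, targets) pairs.
inputs, targets = next(iter(train_loader))

batch_shape = tuple(inputs.shape)      # shape of the whole batch
sample_shape = tuple(inputs[0].shape)  # shape of a single sample

print("batch shape:", batch_shape)
print("sample shape:", sample_shape)
```

The first entry of the batch shape is the batch size; everything after it is the shape of one sample.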

The layer nn.Flatten(start_dim=1) changes the dimensions of the input from (3,128,128) to (3,128*128).
My question is: shouldn’t I use nn.Flatten(start_dim=0) instead of nn.Flatten(start_dim=1), since the next layer expects an input of size (3*128*128) rather than (3,128*128)?
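
One way to check this is to run both variants on a dummy tensor. The sketch below assumes the model receives batched input of shape (N, 3, 128, 128), which is what a DataLoader usually produces; the batch size of 4 and the Linear output size of 10 are made up for illustration.

```python
import torch
from torch import nn

# Assumed batched input: 4 RGB images of size 128x128.
batch = torch.randn(4, 3, 128, 128)

# start_dim=1 keeps dim 0 (the batch) and flattens the rest.
flat_keep_batch = nn.Flatten(start_dim=1)(batch)

# start_dim=0 flattens everything, merging the batch into one long vector.
flat_all = nn.Flatten(start_dim=0)(batch)

# For comparison, a single unbatched sample of shape (3, 128, 128).
single = torch.randn(3, 128, 128)
flat_single_1 = nn.Flatten(start_dim=1)(single)
flat_single_0 = nn.Flatten(start_dim=0)(single)

# Hypothetical next layer expecting 3*128*128 features per sample.
linear = nn.Linear(3 * 128 * 128, 10)
out = linear(flat_keep_batch)

print(flat_keep_batch.shape, flat_all.shape)
print(flat_single_1.shape, flat_single_0.shape)
print(out.shape)
```

With batched input, start_dim=1 gives one row of 3*128*128 features per sample, which is what the Linear layer consumes. The (3, 128*128) shape only shows up when a single sample is passed without a batch dimension, in which case the channel axis is treated as if it were the batch.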