I’m working on building a new architecture that I’m hoping doesn’t completely suck, but I am a little confused about what code I should be looking at for inspiration. The first place I’m looking is UNet, but I am having a hard time understanding exactly what is happening there, so I want to explain my understanding and see if I’m on the right track.
The first thing to do when building a new architecture is to make a class that inherits from nn.Module (from PyTorch).
After this, write an __init__ method. This is where you put anything specific to your model (how many layers, how many classes, etc.).
Then the forward method is created to actually use the new model architecture. This is where the architecture is actually wired together: it defines what the inputs are and how the outputs of the different layers are passed to each other. So in KBNet below, I initialize it with 1 layer and 2 classes, and then I use that information in forward (not layers I guess, but classes) to determine how many outputs the model has. Am I doing this correctly? My other question is how I would use something like ResNet in my forward: if I had something that I wanted to pass from one ResNet layer to another, would that be possible?
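On the ResNet question, here is a minimal sketch, assuming that "passing something from one layer to another" means a skip connection where a block's input is added back to its output. TinyResidualBlock and its channel count are hypothetical names chosen for illustration; this is not torchvision's ResNet implementation.

```python
import torch
import torch.nn as nn


class TinyResidualBlock(nn.Module):
    """Hypothetical block: two 3x3 convs plus a skip connection."""

    def __init__(self, channels=32):
        super().__init__()
        # padding=1 keeps height and width unchanged so the shapes match for the add
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        identity = x                      # keep the input around
        out = torch.relu(self.conv1(x))
        out = self.conv2(out)
        return torch.relu(out + identity)  # pass the input past both convs


block = TinyResidualBlock(channels=32)
y = block(torch.randn(2, 32, 16, 16))
print(y.shape)  # torch.Size([2, 32, 16, 16])
```

Since forward is ordinary Python, any intermediate tensor can be stored in a variable and reused later, as long as the shapes line up at the point where the tensors are combined.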
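    import torch
    import torch.nn as nn

    class KBNet(nn.Module):
        def __init__(self, layers=1, n_classes=2):
            super().__init__()
            self.layers = layers
            # 1x1 convolution: 32 input channels -> 64 output channels
            self.conv2d = nn.Conv2d(32, 64, 1)
            self.n_classes = n_classes

        def forward(self, x, y):
            # the same conv layer is applied to both inputs
            x = self.conv2d(x)
            y = self.conv2d(y)
            # concatenate along the channel dimension
            xy = torch.cat([x, y], dim=1)
            xy = torch.sigmoid(xy, self.n_classes)
            return xy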
What I would expect this to be is a model that has two conv layers that both take a 32-channel input and output 64 channels. It would then concatenate those outputs and produce a prediction from that, but when I create a new instance of this model with
new = KBNet(n_classes=5), I get this instead:
    KBNet(
      (conv2d): Conv2d (32, 64, kernel_size=(1, 1), stride=(1, 1))
    )
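One likely reading of that printout: printing an nn.Module lists only the submodules assigned as attributes in __init__. A single self.conv2d therefore shows up once even though forward calls it twice, and plain integers such as layers and n_classes are not shown at all. Below is a minimal sketch of the model I described, assuming two separate conv branches and a 1x1 conv "head" that maps the concatenated channels to n_classes. The names KBNetSketch, conv_x, conv_y and head are made up for illustration, and the head is only one of several ways n_classes could be used.

```python
import torch
import torch.nn as nn


class KBNetSketch(nn.Module):
    """Hypothetical variant: two conv branches plus a class head."""

    def __init__(self, n_classes=2):
        super().__init__()
        # two independent layers, each 32 -> 64 channels
        self.conv_x = nn.Conv2d(32, 64, kernel_size=1)
        self.conv_y = nn.Conv2d(32, 64, kernel_size=1)
        # assumed head: 64 + 64 concatenated channels -> n_classes channels
        self.head = nn.Conv2d(128, n_classes, kernel_size=1)

    def forward(self, x, y):
        x = self.conv_x(x)
        y = self.conv_y(y)
        xy = torch.cat([x, y], dim=1)        # shape (N, 128, H, W)
        return torch.sigmoid(self.head(xy))  # sigmoid takes only the tensor


model = KBNetSketch(n_classes=5)
print(model)  # lists conv_x, conv_y and head
out = model(torch.randn(1, 32, 8, 8), torch.randn(1, 32, 8, 8))
print(out.shape)  # torch.Size([1, 5, 8, 8])
```

In this sketch n_classes changes the model because it sets the output channels of a registered layer, which is why it appears in the printout and in the output shape.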
Definitely planning on digging into this more tomorrow, but just curious if anybody else has done this and has a good place to look or a good blog post to read about this.