Stacked model won't work in fast.ai

I would like to build an image classifier that takes four (greyscale) images/channels at a time and outputs a single class label. I’ve tried making four-channel images, but as the channels are quite different from each other (not RGBY), the model wasn’t very good. Now I’m planning to train one model for each ‘channel’, then run all four models in parallel, combine their output class scores at the end, and add one more fc layer. I’m hoping this model can give me better accuracy and tell me which channel is the most important by looking at the weights of the last fc layer (similar to attention).

My method is:

# Rebuild each single-channel learner and restore its saved weights.
def _restore_learner(channel_data):
    """Create a cnn_learner for one channel's data and load its saved weights."""
    restored = cnn_learner(channel_data, ...)
    restored.load(...)
    return restored

learner_1 = _restore_learner(data_1)
learner_2 = _restore_learner(data_2)
learner_3 = _restore_learner(data_3)
learner_4 = _restore_learner(data_4)

# The underlying models are what stacked_model wraps (one per channel).
model_1, model_2, model_3, model_4 = (
    lrn.model for lrn in (learner_1, learner_2, learner_3, learner_4)
)

# my model:
class stacked_model(nn.Module):
    """Run one CNN per input channel and fuse their outputs with a linear layer.

    Each of the four input channels is replicated to 3 channels, because the
    branch backbones' first conv expects 3-channel input. Each replicated
    channel goes through its own branch model. The flattened branch outputs are
    concatenated and passed through ``self.fc``.

    With fastai, pass an *instance* to the plain ``Learner``
    (``Learner(data, stacked_model(), ...)``). ``cnn_learner`` expects an
    architecture function, not a ready-made model.

    Args:
        model_1, model_2, model_3, model_4: branch models for input channels
            0..3. When omitted (None), the module-level globals ``model_1`` ..
            ``model_4`` are looked up when the instance is created.
        n_classes: outputs of the final linear layer. The default of 2 matches
            the original ``nn.Linear(8, 2)``.
        n_out_per_model: width of each branch's flattened output. The default
            of 2 gives ``fc`` 4 * 2 = 8 inputs, as before.

    Raises:
        TypeError: if a branch model is omitted and no global of that name exists.
    """

    def __init__(self, model_1=None, model_2=None, model_3=None, model_4=None,
                 n_classes=2, n_out_per_model=2):
        super().__init__()
        # Resolve defaults now rather than binding the globals once at
        # class-definition time. The class can then be defined before the
        # models exist, and re-loaded learners are picked up.
        # Attribute names and order are kept so saved state_dicts still load.
        self.model_1 = self._resolve(model_1, 'model_1')
        self.model_2 = self._resolve(model_2, 'model_2')
        self.model_3 = self._resolve(model_3, 'model_3')
        self.model_4 = self._resolve(model_4, 'model_4')
        self.fc = nn.Linear(4 * n_out_per_model, n_classes)

    @staticmethod
    def _resolve(model, name):
        """Return ``model``, or the module-level global ``name`` if ``model`` is None."""
        if model is not None:
            return model
        try:
            return globals()[name]
        except KeyError:
            raise TypeError(
                f"stacked_model: pass `{name}` explicitly or define a global `{name}`"
            ) from None

    def forward(self, x):
        """Classify a batch from its four channels.

        Args:
            x: 4-D tensor indexed as ``x[:, channel, :, :]``. Only channels 0..3
                are used, one per branch.

        Returns:
            Output of ``self.fc`` with shape (batch, n_classes).
        """
        branches = (self.model_1, self.model_2, self.model_3, self.model_4)
        flat_outputs = []
        for channel, branch in enumerate(branches):
            # Copy the single greyscale channel into 3 identical channels.
            rgb_like = x[:, channel:channel + 1].repeat(1, 3, 1, 1)
            branch_out = branch(rgb_like)
            flat_outputs.append(branch_out.view(branch_out.size(0), -1))
        # Concatenating along dim 1 gives the same layout as the original
        # stack(dim=1) followed by view(batch, -1).
        return self.fc(torch.cat(flat_outputs, dim=1))

The model runs on its own in PyTorch, e.g. stacked_model().forward(x), but it won’t run in fastai. The error is:

# cnn_learner's second argument must be an architecture function (such as
# models.resnet34). It builds a body from it, attaches a new head, and runs the
# body's children one after another (nn.Sequential(*body.children())).
# Given the stacked_model class, that most likely chains the branch models in
# sequence, so one branch's (batch, 2) output reaches the next branch's first
# conv, which is the RuntimeError in the traceback.
# Wrap a ready-made instance with the plain Learner instead.
learner = Learner(data, stacked_model(), metrics=[accuracy])

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-26-c4e345be6722> in <module>
      6     #loss_func=F.binary_cross_entropy_with_logits,
      7     path='',
----> 8     metrics=[accuracy]
      9 )

/mnt/sdf/py_torch/lib/python3.6/site-packages/fastai/vision/learner.py in cnn_learner(data, base_arch, cut, pretrained, lin_ftrs, ps, custom_head, split_on, bn_final, init, concat_pool, **kwargs)
     96     meta = cnn_config(base_arch)
     97     model = create_cnn_model(base_arch, data.c, cut, pretrained, lin_ftrs, ps=ps, custom_head=custom_head,
---> 98         split_on=split_on, bn_final=bn_final, concat_pool=concat_pool)
     99     learn = Learner(data, model, **kwargs)
    100     learn.split(split_on or meta['split'])

/mnt/sdf/py_torch/lib/python3.6/site-packages/fastai/vision/learner.py in create_cnn_model(base_arch, nc, cut, pretrained, lin_ftrs, ps, custom_head, split_on, bn_final, concat_pool)
     84     body = create_body(base_arch, pretrained, cut)
     85     if custom_head is None:
---> 86         nf = num_features_model(nn.Sequential(*body.children())) * (2 if concat_pool else 1)
     87         head = create_head(nf, nc, lin_ftrs, ps=ps, concat_pool=concat_pool, bn_final=bn_final)
     88     else: head = custom_head

/mnt/sdf/py_torch/lib/python3.6/site-packages/fastai/callbacks/hooks.py in num_features_model(m)
    119     sz = 64
    120     while True:
--> 121         try: return model_sizes(m, size=(sz,sz))[-1][1]
    122         except Exception as e:
    123             sz *= 2

/mnt/sdf/py_torch/lib/python3.6/site-packages/fastai/callbacks/hooks.py in model_sizes(m, size)
    112     "Pass a dummy input through the model `m` to get the various sizes of activations."
    113     with hook_outputs(m) as hooks:
--> 114         x = dummy_eval(m, size)
    115         return [o.stored.shape for o in hooks]
    116 

/mnt/sdf/py_torch/lib/python3.6/site-packages/fastai/callbacks/hooks.py in dummy_eval(m, size)
    107 def dummy_eval(m:nn.Module, size:tuple=(64,64)):
    108     "Pass a `dummy_batch` in evaluation mode in `m` with `size`."
--> 109     return m.eval()(dummy_batch(m, size))
    110 
    111 def model_sizes(m:nn.Module, size:tuple=(64,64))->Tuple[Sizes,Tensor,Hooks]:

/mnt/sdf/py_torch/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

/mnt/sdf/py_torch/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
     90     def forward(self, input):
     91         for module in self._modules.values():
---> 92             input = module(input)
     93         return input
     94 

/mnt/sdf/py_torch/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

/mnt/sdf/py_torch/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
     90     def forward(self, input):
     91         for module in self._modules.values():
---> 92             input = module(input)
     93         return input
     94 

/mnt/sdf/py_torch/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

/mnt/sdf/py_torch/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
     90     def forward(self, input):
     91         for module in self._modules.values():
---> 92             input = module(input)
     93         return input
     94 

/mnt/sdf/py_torch/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

/mnt/sdf/py_torch/lib/python3.6/site-packages/torch/nn/modules/conv.py in forward(self, input)
    336                             _pair(0), self.dilation, self.groups)
    337         return F.conv2d(input, self.weight, self.bias, self.stride,
--> 338                         self.padding, self.dilation, self.groups)
    339 
    340 

RuntimeError: Expected 4-dimensional input for 4-dimensional weight 64 3, but got 2-dimensional input of size [1, 2] instead

Strangely, if I change the order in which the self.model_x attributes are defined in my stacked_model class, the error message changes. For example, if I put my fc layer nn.Linear(8,2) first (before self.model_1), the error becomes:

RuntimeError: size mismatch, m1: [16384 x 2048], m2: [8 x 2] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:961
1 Like

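Just a thought: try using Learner instead of cnn_learner. cnn_learner expects an architecture that it can cut into a body and then attach its own custom head to, so it can’t take an arbitrary model like this one. It also runs the body’s layers one after another, which would explain why reordering the attributes changes the error.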

No, you can keep your model exactly as it is; just use Learner instead. (I’m not 100% sure it’ll work, but it’s a start.)

Maybe you need to pass a model instance (stacked_model()) rather than the model class to the learner?