I would like to build an image classifier that takes four (greyscale) images/channels at a time and outputs a single class label. I’ve tried making four-channel images, but as the channels are quite different from each other (not RGBY), the model wasn’t very good. Now I’m planning to train one model for each ‘channel’, stack all four models in parallel, combine their output class scores at the end, and add one more fc layer. I’m hoping this model can give me better accuracy and also tell me which channel is the most important by looking at the weights of the last fc layer (similar to attention).
My method is:
# Build one learner per channel, then load its saved state.
# (`...` stands for arguments elided in the original snippet.)
learner_1 = cnn_learner(data_1, ...)
learner_1.load(...)
learner_2 = cnn_learner(data_2, ...)
learner_2.load(...)
learner_3 = cnn_learner(data_3, ...)
learner_3.load(...)
learner_4 = cnn_learner(data_4, ...)
learner_4.load(...)

# Pull the underlying model object out of each learner.
model_1, model_2, model_3, model_4 = (
    learner.model
    for learner in (learner_1, learner_2, learner_3, learner_4)
)
# my model:
class stacked_model(nn.Module):
    """Run four single-channel models side by side and fuse their outputs.

    Channel ``i`` of the input is replicated three times and fed to
    ``model_{i+1}``. The four outputs are stacked, flattened per sample and
    passed through one final linear layer (``fc``).

    NOTE(review): fastai's ``cnn_learner`` appears to rebuild its
    ``base_arch`` from the module's children (ignoring this ``forward``);
    presumably an *instance* should be handed to ``Learner`` directly --
    confirm against the fastai docs.
    """

    def __init__(self, model_1=model_1, model_2=model_2, model_3=model_3, model_4=model_4):
        """Store the four branch models and create the fusion layer.

        Args:
            model_1: module applied to input channel 0.
            model_2: module applied to input channel 1.
            model_3: module applied to input channel 2.
            model_4: module applied to input channel 3.
        """
        super().__init__()
        self.model_1 = model_1
        self.model_2 = model_2
        self.model_3 = model_3
        self.model_4 = model_4
        # 8 inputs -> 2 outputs; presumably 4 branches x 2 class scores each
        # -- TODO confirm against the branch heads.
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        """Return the fused output for a batch of four-channel images.

        Args:
            x: 4-D tensor indexed as ``x[:, channel, :, :]`` with at least
                four channels along dim 1.

        Returns:
            Output of ``self.fc`` applied to the flattened branch outputs.
        """
        branches = (self.model_1, self.model_2, self.model_3, self.model_4)
        branch_outputs = []
        for channel, branch in enumerate(branches):
            # Replicate the single plane three times along dim 1 so the
            # branch receives a three-channel image.
            plane = x[:, channel, :, :]
            branch_outputs.append(branch(torch.stack((plane, plane, plane), dim=1)))

        out = torch.stack(branch_outputs, dim=1)
        # Flatten everything but the batch dimension
        # (fix: the original called the non-existent ``out.vew``).
        x = out.view(out.size(0), -1)
        x = self.fc(x)
        return x
The model runs by itself in PyTorch, e.g. stacked_model().forward(x), but it won’t run in fastai. The call below raises the following error:
learner = cnn_learner(data, stacked_model,...)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-26-c4e345be6722> in <module>
6 #loss_func=F.binary_cross_entropy_with_logits,
7 path='',
----> 8 metrics=[accuracy]
9 )
/mnt/sdf/py_torch/lib/python3.6/site-packages/fastai/vision/learner.py in cnn_learner(data, base_arch, cut, pretrained, lin_ftrs, ps, custom_head, split_on, bn_final, init, concat_pool, **kwargs)
96 meta = cnn_config(base_arch)
97 model = create_cnn_model(base_arch, data.c, cut, pretrained, lin_ftrs, ps=ps, custom_head=custom_head,
---> 98 split_on=split_on, bn_final=bn_final, concat_pool=concat_pool)
99 learn = Learner(data, model, **kwargs)
100 learn.split(split_on or meta['split'])
/mnt/sdf/py_torch/lib/python3.6/site-packages/fastai/vision/learner.py in create_cnn_model(base_arch, nc, cut, pretrained, lin_ftrs, ps, custom_head, split_on, bn_final, concat_pool)
84 body = create_body(base_arch, pretrained, cut)
85 if custom_head is None:
---> 86 nf = num_features_model(nn.Sequential(*body.children())) * (2 if concat_pool else 1)
87 head = create_head(nf, nc, lin_ftrs, ps=ps, concat_pool=concat_pool, bn_final=bn_final)
88 else: head = custom_head
/mnt/sdf/py_torch/lib/python3.6/site-packages/fastai/callbacks/hooks.py in num_features_model(m)
119 sz = 64
120 while True:
--> 121 try: return model_sizes(m, size=(sz,sz))[-1][1]
122 except Exception as e:
123 sz *= 2
/mnt/sdf/py_torch/lib/python3.6/site-packages/fastai/callbacks/hooks.py in model_sizes(m, size)
112 "Pass a dummy input through the model `m` to get the various sizes of activations."
113 with hook_outputs(m) as hooks:
--> 114 x = dummy_eval(m, size)
115 return [o.stored.shape for o in hooks]
116
/mnt/sdf/py_torch/lib/python3.6/site-packages/fastai/callbacks/hooks.py in dummy_eval(m, size)
107 def dummy_eval(m:nn.Module, size:tuple=(64,64)):
108 "Pass a `dummy_batch` in evaluation mode in `m` with `size`."
--> 109 return m.eval()(dummy_batch(m, size))
110
111 def model_sizes(m:nn.Module, size:tuple=(64,64))->Tuple[Sizes,Tensor,Hooks]:
/mnt/sdf/py_torch/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
/mnt/sdf/py_torch/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
90 def forward(self, input):
91 for module in self._modules.values():
---> 92 input = module(input)
93 return input
94
/mnt/sdf/py_torch/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
/mnt/sdf/py_torch/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
90 def forward(self, input):
91 for module in self._modules.values():
---> 92 input = module(input)
93 return input
94
/mnt/sdf/py_torch/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
/mnt/sdf/py_torch/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
90 def forward(self, input):
91 for module in self._modules.values():
---> 92 input = module(input)
93 return input
94
/mnt/sdf/py_torch/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
/mnt/sdf/py_torch/lib/python3.6/site-packages/torch/nn/modules/conv.py in forward(self, input)
336 _pair(0), self.dilation, self.groups)
337 return F.conv2d(input, self.weight, self.bias, self.stride,
--> 338 self.padding, self.dilation, self.groups)
339
340
RuntimeError: Expected 4-dimensional input for 4-dimensional weight 64 3, but got 2-dimensional input of size [1, 2] instead
Strangely, if I change the order in which the self.model_x attributes are defined in my stacked_model class, the error message changes. For example, if I define my fc layer nn.Linear(8,2) first (before self.model_1), the error becomes:
RuntimeError: size mismatch, m1: [16384 x 2048], m2: [8 x 2] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:961