Hello everyone,
i have a thermography dataset with an image size of 126x256x256 (channel x heigh x width) and want to perform a segmentation task. I was thinking of adding a 1x1 conv layer between the input and the dynamic U-net . The idea is to reduce the number of channels and therefore make the spatial dependency (256x256 is spatial/ 126 is temporal) easier to learn. The conv 1x1 layer should reduce the dimension from 126 to 20. I tried the following:
Define a custom U-net class (adapted from Dynamic U-net), which adds the conv1 layer in the end
class CustomUnet(SequentialEx):
def __init__(self, encoder, n_out, img_size, blur=False, blur_final=True, self_attention=False,
y_range=None, last_cross=True, bottle=False, act_cls=defaults.activation,
init=nn.init.kaiming_normal_, norm_type=None, **kwargs):
imsize = img_size
sizes = model_sizes(encoder, size=imsize)
sz_chg_idxs = list(reversed(_get_sz_change_idxs(sizes)))
self.sfs = hook_outputs([encoder[i] for i in sz_chg_idxs], detach=False)
x = dummy_eval(encoder, imsize).detach()
ni = sizes[-1][1]
middle_conv = nn.Sequential(ConvLayer(ni, ni*2, act_cls=act_cls, norm_type=norm_type, **kwargs),
ConvLayer(ni*2, ni, act_cls=act_cls, norm_type=norm_type, **kwargs)).eval()
x = middle_conv(x)
layers = [ encoder, BatchNorm(ni), nn.ReLU(), middle_conv]
for i,idx in enumerate(sz_chg_idxs):
not_final = i!=len(sz_chg_idxs)-1
up_in_c, x_in_c = int(x.shape[1]), int(sizes[idx][1])
do_blur = blur and (not_final or blur_final)
sa = self_attention and (i==len(sz_chg_idxs)-3)
unet_block = UnetBlock(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=do_blur, self_attention=sa,
act_cls=act_cls, init=init, norm_type=norm_type, **kwargs).eval()
layers.append(unet_block)
x = unet_block(x)
ni = x.shape[1]
if imsize != sizes[0][-2:]: layers.append(PixelShuffle_ICNR(ni, act_cls=act_cls, norm_type=norm_type))
layers.append(ResizeToOrig())
if last_cross:
layers.append(MergeLayer(dense=True))
ni += in_channels(encoder)
layers.append(ResBlock(1, ni, ni//2 if bottle else ni, act_cls=act_cls, norm_type=norm_type, **kwargs))
conv1=[nn.Sequential(ConvLayer(126, 20, ks=1, padding=0, act_cls=act_cls, norm_type=norm_type, **kwargs))]
layers=conv1 + layers
layers += [ConvLayer(ni, n_out, ks=1, act_cls=None, norm_type=norm_type, **kwargs)]
apply_init(nn.Sequential(layers[3], layers[-2]), init)
if y_range is not None: layers.append(SigmoidRange(*y_range))
layers.append(ToTensorBase())
super().__init__(*layers)
def __del__(self):
if hasattr(self, "sfs"): self.sfs.remove()
Then create the U-net model and learner:
def create_unet_model1(arch, n_out, img_size, pretrained=True, cut=None, n_in=3, **kwargs):
"Create custom unet architecture"
meta = model_meta.get(arch, _default_meta)
body = create_body(arch, n_in, pretrained, ifnone(cut, meta['cut']))
model = CustomUnet(body, n_out, img_size, **kwargs)
return model
def unet_learner1(dls, arch, normalize=True, n_out=None, pretrained=True, config=None,
# learner args
loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=None, cbs=None, metrics=None, path=None,
model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95,0.85,0.95),
# other model args
**kwargs):
"Build a unet learner from `dls` and `arch`"
if config:
warnings.warn('config param is deprecated. Pass your args directly to unet_learner.')
kwargs = {**config, **kwargs}
meta = model_meta.get(arch, _default_meta)
if normalize: _add_norm(dls, meta, pretrained)
n_out = ifnone(n_out, get_c(dls))
assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
img_size = dls.one_batch()[0].shape[-2:]
assert img_size, "image size could not be inferred from data"
model = create_unet_model1(arch, n_out, img_size, pretrained=pretrained, **kwargs)
splitter=ifnone(splitter, meta['split'])
learn = Learner(dls=dls, model=model, loss_func=loss_func, opt_func=opt_func, lr=lr, splitter=splitter, cbs=cbs,
metrics=metrics, path=path, model_dir=model_dir, wd=wd, wd_bn_bias=wd_bn_bias, train_bn=train_bn,
moms=moms)
if pretrained: learn.freeze()
# keep track of args for loggers
store_attr('arch,normalize,n_out,pretrained', self=learn, **kwargs)
return learn, model
Apply the functions:
learn1, model= unet_learner1(dl, resnet34,normalize=True, n_in=20,n_out=2, pretrained=False, loss_func=loss_fn_ce, metrics=[acc_metric, roc], cbs=[ShowGraphCallback(), CSVLogger(fname='history')])
The model look like this:
CustomUnet(
(layers): ModuleList(
(0): Sequential(
(0): ConvLayer(
(0): Conv2d(126, 20, kernel_size=(1, 1), stride=(1, 1))
(1): ReLU()
)
)
(1): Sequential(
(0): Conv2d(20, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(5): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(3): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(6): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(3): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(4): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(5): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(7): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
(4): Sequential(
(0): ConvLayer(
(0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
)
(1): ConvLayer(
(0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
)
)
(5): UnetBlock(
(shuf): PixelShuffle_ICNR(
(0): ConvLayer(
(0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))
(1): ReLU()
)
(1): PixelShuffle(upscale_factor=2)
)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): ConvLayer(
(0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
)
(conv2): ConvLayer(
(0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
)
(relu): ReLU()
)
(6): UnetBlock(
(shuf): PixelShuffle_ICNR(
(0): ConvLayer(
(0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))
(1): ReLU()
)
(1): PixelShuffle(upscale_factor=2)
)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): ConvLayer(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
)
(conv2): ConvLayer(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
)
(relu): ReLU()
)
(7): UnetBlock(
(shuf): PixelShuffle_ICNR(
(0): ConvLayer(
(0): Conv2d(384, 768, kernel_size=(1, 1), stride=(1, 1))
(1): ReLU()
)
(1): PixelShuffle(upscale_factor=2)
)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): ConvLayer(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
)
(conv2): ConvLayer(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
)
(relu): ReLU()
)
(8): UnetBlock(
(shuf): PixelShuffle_ICNR(
(0): ConvLayer(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))
(1): ReLU()
)
(1): PixelShuffle(upscale_factor=2)
)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): ConvLayer(
(0): Conv2d(192, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
)
(conv2): ConvLayer(
(0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
)
(relu): ReLU()
)
(9): PixelShuffle_ICNR(
(0): ConvLayer(
(0): Conv2d(96, 384, kernel_size=(1, 1), stride=(1, 1))
(1): ReLU()
)
(1): PixelShuffle(upscale_factor=2)
)
(10): ResizeToOrig()
(11): MergeLayer()
(12): ResBlock(
(convpath): Sequential(
(0): ConvLayer(
(0): Conv2d(116, 116, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
)
(1): ConvLayer(
(0): Conv2d(116, 116, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
(idpath): Sequential()
(act): ReLU(inplace=True)
)
(13): ConvLayer(
(0): Conv2d(116, 2, kernel_size=(1, 1), stride=(1, 1))
)
(14): fastai.layers.ToTensorBase(tensor_cls=<class 'fastai.torch_core.TensorBase'>)
)
When I try to use the model I get however:
RuntimeError: Given groups=1, weight of size [116, 116, 3, 3], expected input[1, 222, 256, 256] to have 116 channels, but got 222 channels instead
Does anyone know what I am doing wrong here? Is the 1x1 conv layer at the right position in the code and defined correctly?
Thank you for your help
All the best,
Simon