How to handle grayscale (1-channel) and RGBA (4-channel) images with fastai?

I am trying to train a U-Net with models.resnet34 (fastai v1) on images that have a different number of channels. However, normalize(imagenet_stats) converts a grayscale image to a 3-channel image, and an RGBA image raises an error in normalize(imagenet_stats). What should I do to train a U-Net on images with a different number of channels?


SegmentationItemList.from_folder(path_img, convert_mode='RGBA')
SegmentationItemList.from_folder(path_img, convert_mode='L')

.normalize(imagenet_stats))

learn = unet_learner(data, models.resnet34, wd=1e-2)

You can call .normalize() without imagenet_stats, and it will grab a batch and compute the stats from it. This is needed because the stats for your 4-channel images have nested shape ((4,), (4,)), while imagenet_stats has shape ((3,), (3,)).

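Alternatively, you might want to grab a batch and save the stats explicitly, which can be done like this:

xb, yb = dls.one_batch()
# xb has shape (batch, channels, height, width); reduce over every dim except channels
custom_stats = (xb.mean(dim=(0, 2, 3)), xb.std(dim=(0, 2, 3)))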

[...]
  .normalize(*custom_stats)

Don’t forget the asterisk in front of custom_stats!
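
For reference, here is a minimal sketch in plain PyTorch of what per-channel stats look like for a 4-channel batch. It assumes the batch is laid out as (batch, channels, height, width) and uses a random tensor as a stand-in for a real batch; the names xb_norm, mean and std are just for illustration and nothing here is fastai-specific.

```python
import torch

# Stand-in for one batch of RGBA images: (batch, channels, height, width).
xb = torch.rand(8, 4, 64, 64)

# Reduce over batch, height and width so one value per channel remains.
mean = xb.mean(dim=(0, 2, 3))
std = xb.std(dim=(0, 2, 3))
custom_stats = (mean, std)

# Normalizing by hand: broadcast the per-channel stats over the other dims.
xb_norm = (xb - mean[None, :, None, None]) / std[None, :, None, None]

print(mean.shape, std.shape)  # one entry per channel
```

Each element of custom_stats has one entry per channel, which is why stats computed for 3-channel ImageNet images cannot be applied to a 4-channel batch.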

Thank you! I get this point, but models.resnet34 has a fixed 3-channel input. Do you know how to change it?

Both unet_learner and cnn_learner have an n_in parameter to specify the number of input channels.

Yep, as @Pomo mentioned, you can use unet_learner(n_in=4) to set it. If you're not using a pretrained resnet34, it might be nice to use the fastai xresnet34 (which also lets you specify the number of input channels, via the c_in parameter).
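
If n_in is not available, one common workaround is to swap the first convolution by hand. The sketch below is plain PyTorch and is not how fastai implements n_in; adapt_first_conv is a hypothetical helper, and the stand-alone Conv2d only mimics the shape of a ResNet stem. It assumes groups=1. For a single channel it sums the RGB filters (equivalent to feeding the gray image replicated three times); for more than three channels it copies the RGB filters and zero-initializes the extras.

```python
import torch
import torch.nn as nn

def adapt_first_conv(conv: nn.Conv2d, n_in: int) -> nn.Conv2d:
    """Return a copy of `conv` that accepts `n_in` input channels (assumes groups=1)."""
    new_conv = nn.Conv2d(n_in, conv.out_channels,
                         kernel_size=conv.kernel_size, stride=conv.stride,
                         padding=conv.padding, dilation=conv.dilation,
                         bias=conv.bias is not None)
    with torch.no_grad():
        w = conv.weight  # shape: (out_channels, in_channels, kh, kw)
        if n_in == 1:
            # Summing over input channels matches replicating gray to 3 channels.
            new_conv.weight.copy_(w.sum(dim=1, keepdim=True))
        else:
            # Reuse existing filters where possible; extra channels start at zero.
            new_conv.weight.zero_()
            k = min(w.shape[1], n_in)
            new_conv.weight[:, :k].copy_(w[:, :k])
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv

# Stand-in with the same shape as a ResNet first layer (weights here are random).
stem = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
gray_stem = adapt_first_conv(stem, 1)
rgba_stem = adapt_first_conv(stem, 4)
```

With a real pretrained network you would assign the result back to the model's first layer; where that layer lives depends on the architecture, so check it before relying on this.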

@jwuphysics @Pomo Thank you very much! I think the problem is that my fastai version is too old (1.0.59) and does not have the n_in or c_in parameters. I will try a newer version. Thanks again!

Error information:
unet_learner(data, arch, pretrained, blur_final, norm_type, split_on, blur, self_attention, y_range, last_cross, bottle, cut, **learn_kwargs)
119 self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
120 bottle=bottle), data.device)
--> 121 learn = Learner(data, model, **learn_kwargs)
122 learn.split(ifnone(split_on, meta['split']))
123 if pretrained: learn.freeze()

TypeError: __init__() got an unexpected keyword argument 'n_in'
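
The traceback shows why the error surfaces inside Learner: unet_learner has no n_in parameter of its own in this version, so the keyword falls into **learn_kwargs and is forwarded. A small sketch for checking this up front, using only the standard library; has_named_param is a hypothetical helper and the two functions are stand-ins that only mimic the shape of an older and a newer signature, not the real fastai ones.

```python
import inspect

def has_named_param(func, name):
    """True if `func` declares `name` explicitly (a **kwargs catch-all does not count)."""
    param = inspect.signature(func).parameters.get(name)
    catch_all = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    return param is not None and param.kind not in catch_all

# Hypothetical stand-ins for an older and a newer learner factory.
def old_unet_learner(data, arch, pretrained=True, **learn_kwargs):
    return learn_kwargs

def new_unet_learner(data, arch, pretrained=True, n_in=3, **learn_kwargs):
    return learn_kwargs

print(has_named_param(old_unet_learner, "n_in"))  # False
print(has_named_param(new_unet_learner, "n_in"))  # True
```

Running the same check against the unet_learner you have installed tells you whether n_in is supported before you build the learner.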