SemTorch: A Semantic Segmentation library built on top of FastAI

Hi, guys:

I am happy to announce that I have released SemTorch.

This library allows you to train 5 different Segmentation Models: UNet, DeepLabV3+, HRNet, Mask-RCNN and U²-Net, all in the same way.

For example:

# SemTorch
from semtorch import get_segmentation_learner

learn = get_segmentation_learner(dls=dls, number_classes=2, segmentation_type="Semantic Segmentation",
                                 architecture_name="deeplabv3+", backbone_name="resnet50",
                                 metrics=[tumour, Dice(), JaccardCoeff()], wd=1e-2,
                                 splitter=segmentron_splitter).to_fp16()

This library was used in my other project, Deep-Tumour-Spheroid, where I trained segmentation models for brain tumours.

The notebooks can be found here. They are an example of how easy it is to train a model with this library. You can use SemTorch with your own datasets!
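
If you are starting from raw images and masks, here is a minimal sketch of building the dls with fastai (the data/images and data/masks layout, the label_func mapping and the codes are hypothetical; adapt them to your dataset):

from fastai.vision.all import *

path = Path("data")  # hypothetical layout: data/images and data/masks
dls = SegmentationDataLoaders.from_label_func(
    path, bs=4,
    fnames=get_image_files(path/"images"),
    label_func=lambda o: path/"masks"/o.name,  # hypothetical image-to-mask mapping
    codes=["background", "tumour"],            # two classes, matching number_classes=2
)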

In addition, if you want to know more about this project, you can go to the Deep-Tumour-Spheroid repository.

Deeper look at all the parameters of SemTorch
The whole library is focused on this function, which will get new models and options over time.

def get_segmentation_learner(dls, number_classes, segmentation_type, architecture_name, backbone_name,
                             loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=trainable_params, 
                             cbs=None, pretrained=True, normalize=True, image_size=None, metrics=None, 
                             path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True,
                             moms=(0.95,0.85,0.95)):

This function returns a learner for the provided architecture and backbone.

Parameters:

  • dls (DataLoader): the dataloader to use with the learner
  • number_classes (int): the number of classes in the project. It should be >= 2
  • segmentation_type (str): only Semantic Segmentation is accepted for now
  • architecture_name (str): name of the architecture. The following ones are supported: unet, deeplabv3+, hrnet, maskrcnn and u2^net
  • backbone_name (str): name of the backbone
  • loss_func (callable): the loss function to optimize
  • opt_func (callable): the optimizer function (defaults to Adam)
  • lr (float): learning rate
  • splitter (callable): splitter function used for freezing the learner
  • cbs (List[cb]): list of callbacks
  • pretrained (bool): whether to use a pretrained backbone
  • normalize (bool): whether normalization is applied
  • image_size (int): REQUIRED for MaskRCNN. It indicates the desired size of the image.
  • metrics (List[metric]): list of metrics
  • path (str or Path): base path of the learner
  • model_dir (str): the directory in which to save models
  • wd (float): weight decay
  • wd_bn_bias (bool): whether to apply weight decay to BatchNorm layers and bias parameters
  • train_bn (bool): whether to train BatchNorm layers even when the rest of the model is frozen
  • moms (Tuple(float)): tuple of momentums

Returns:

  • learner: the Learner object built for the provided architecture and backbone
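
Once you have the learner, training follows the usual fastai workflow. A quick sketch (the epoch count and learning rate are illustrative):

learn.lr_find()                # find a suitable learning rate
learn.fit_one_cycle(10, 1e-3)  # train with the 1cycle policy
learn.show_results()           # visualize some predictions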

Supported configs

| Architecture | Supported configs | Backbones |
|---|---|---|
| unet | Semantic Segmentation, binary; Semantic Segmentation, multiple | resnet18, resnet34, resnet50, resnet101, resnet152, xresnet18, xresnet34, xresnet50, xresnet101, xresnet152, squeezenet1_0, squeezenet1_1, densenet121, densenet169, densenet201, densenet161, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn, alexnet |
| deeplabv3+ | Semantic Segmentation, binary; Semantic Segmentation, multiple | resnet18, resnet34, resnet50, resnet101, resnet152, resnet50c, resnet101c, resnet152c, xception65, mobilenet_v2 |
| hrnet | Semantic Segmentation, binary; Semantic Segmentation, multiple | hrnet_w18_small_model_v1, hrnet_w18_small_model_v2, hrnet_w18, hrnet_w30, hrnet_w32, hrnet_w48 |
| maskrcnn | Semantic Segmentation, binary | resnet50 |
| u2^net | Semantic Segmentation, binary | small, normal |
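
For instance, switching architecture is just a matter of changing two arguments. A minimal sketch with HRNet (dls is assumed to be your own DataLoaders):

learn = get_segmentation_learner(dls=dls, number_classes=2,
                                 segmentation_type="Semantic Segmentation",
                                 architecture_name="hrnet",
                                 backbone_name="hrnet_w18")
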
51 Likes

What nice work!

I was wondering if more models could be added. Concretely, I was thinking of HoverNet (https://github.com/vqdang/hover_net)

Yes, that’s the objective: keep adding different models over time! I’ll take a look into this new architecture.

Let me know all the features that you want.

@Joan Give it a star or share it if you find it useful! It will help me!

2 Likes

This is awesome! Just wondering why not integrate this into fastai itself?

3 Likes

It would be nice! I needed to make several workarounds to make it work as an external package.

What do you think @jeremy?

2 Likes

Thanks a lot, I have been waiting for something like this for a while now.
After practicing fastai v1 for a year, I got expert in it top down, and was hesitant about checking out the new one, but this has done the job. For my next 30 days: check the hell out of fastai 2.

1 Like

Hi WaterKnight, hope you are having an excellent day!
Great work, it makes me wonder what medicine will be like in 100 years’ time.

Cheers mrfabulous! :smiley: :smiley:

Thank you @mrfabulous1 @shadab.sayeed

Amazing! Thank you for sharing this.

Quick Q: It says that only Semantic Segmentation is supported but isn’t MaskRCNN an Instance Segmentation architecture? Also in this notebook of yours, I see a bunch of TensorBBoxes and the results show bounding boxes as well. Am I reading this correctly?

Again, thank you for your work!

You are. To quote an article about MaskRCNN:

Object Instance Segmentation is a recent approach that gives us the best of both worlds. It integrates the object detection task, where the goal is to detect an object class along with its bounding box prediction in an image, and the semantic segmentation task, which classifies each pixel into pre-defined categories. Thus, it enables us to detect objects in an image while precisely segmenting a mask for each object instance.

So, we must have the bounding boxes to pair with the segmentation masks in order to train the model. (Note: that is if we were to follow the original implementation verbatim.)

This repo also has a good visual walkthrough of the architecture/what the pipeline could look like in 6 steps: https://github.com/matterport/Mask_RCNN

That being said, if you read a bit further, it is possible to train a model without any ground truth bounding boxes:

Bounding Boxes : Some datasets provide bounding boxes and some provide masks only. To support training on multiple datasets we opted to ignore the bounding boxes that come with the dataset and generate them on the fly instead. We pick the smallest box that encapsulates all the pixels of the mask as the bounding box. This simplifies the implementation and also makes it easy to apply image augmentations that would otherwise be harder to apply to bounding boxes, such as image rotation.

To validate this approach, we compared our computed bounding boxes to those provided by the COCO dataset. We found that ~2% of bounding boxes differed by 1px or more, ~0.05% differed by 5px or more, and only 0.01% differed by 10px or more.

(Granted this is in TF, so if @WaterKnight wanted to try to implement this adjustment it may be slightly more challenging, but I feel it would be worth it, as I would certainly be using it non-stop :slight_smile: )

6 Likes

I think you mean ground truth bboxes?

We pick the smallest box that encapsulates all the pixels of the mask as the bounding box.

1 Like

Thank you very much for taking the time to look at my notebooks! Also, thank you very much for finding it useful!

Yes, I tried to use Mask-RCNN as a semantic segmentation model for tumours. This library was created from all the code of my final degree project at university. I tried to decouple it from the project, but as I am working full time right now, I tried to publish it as fast as possible.

I am thinking of extending it in the future because I have received a lot of positive feedback and a lot of people find it useful!

The BoundingBox is generated from the Mask on the fly. So if you want to apply any kind of augmentation, you’ll just need to apply it to the Mask.

The reason why I’m saying that

Semantic Segmentation - Binary is supported

is that the metrics and visualization are tightly tied to this problem.

As you can see:

What do you suggest?

I did, edited :sweat_smile:

Very interesting. Thanks for the response.

I’ve managed to tweak the torchvision MaskRCNN reference code to work with MobileNetV2. That notebook is a mess, but I’ll be happy to clean it up and upload it to git if you think that’s useful?

It can be found in cell #12 of https://github.com/WaterKnight1998/Deep-Tumour-Spheroid/blob/develop/notebooks/Mask-RCNN.ipynb

import numpy as np
from fastai.vision.all import PILMask, TensorBBox

def get_bbox(o):
    label_path = get_y_fn(o)  # get_y_fn maps an image file to its mask path
    mask = PILMask.create(label_path)
    pos = np.where(mask)  # indices of all foreground (non-zero) pixels
    xmin, xmax = np.min(pos[1]), np.max(pos[1])  # columns give the x range
    ymin, ymax = np.min(pos[0]), np.max(pos[0])  # rows give the y range

    return TensorBBox.create([xmin, ymin, xmax, ymax])

2 Likes

Well, my question would have been: does your library currently support segmentation-mask-only labels for MaskRCNN? But it sounds like you’ve done this, no? :eyes:

It sounds great! I’ll be delighted to see it. If you look at the notebook, you’ll see it supports TorchScript. Do you support it in your backbone? Is FP16 working? I opened an issue on the torchvision repo and they made
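
For reference, here is a rough way to sanity-check both points on a torchvision detection model (a sketch; model stands for whichever Mask-RCNN variant you built):

import torch

# TorchScript: torchvision detection models are scriptable
scripted = torch.jit.script(model)

# FP16: run inference under autocast on a CUDA device
model.eval().cuda()
with torch.no_grad(), torch.cuda.amp.autocast():
    preds = model([torch.rand(3, 512, 512, device="cuda")])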

Yes, that’s right. If you look into the learner closely, the training is pretty standard.

The problem is with data loading, metrics and visualization! In my final degree project I didn’t have time to make it more general. But since @rsomani95 and @muellerzr would like to have it, I can give it a try!

1 Like

No clue, I haven’t tried that yet, as the card I currently have access to doesn’t support FP16. I’ll clean it up and share here once it’s up on git.

Okay, thank you very much!

@rsomani95 @muellerzr It may take me some time to improve Mask-RCNN as I don’t have my RTX card right now; I was waiting to purchase an RTX 3080, but there is no stock…

I’ll turn on a server on the cloud or use Google Colab. I will update here with my progress.

3 Likes

@WaterKnight I’m a bit rusty with the components of Mask R-CNN, but I think this code chunk is what you want for MobileNet-V2:

import torchvision.models as models
from torchvision.models.detection import MaskRCNN
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.backbone_utils import BackboneWithFPN

# convolutional feature extractor of MobileNetV2
mobilenet = models.mobilenet_v2(pretrained=True).features

# BackboneWithFPN wraps the backbone with an IntermediateLayerGetter and an FPN;
# the dict maps module names inside `features` (e.g. '17', '18') to output names
backbone = BackboneWithFPN(mobilenet, {'17': '0', '18': '1'}, [320, 1280], 256)

# the FPN emits 2 feature maps plus a pooled level, so the anchor generator
# needs one size tuple per map (3 here) instead of the 5-level default
anchor_generator = AnchorGenerator(sizes=((32,), (64,), (128,)),
                                   aspect_ratios=((0.5, 1.0, 2.0),) * 3)

model = MaskRCNN(backbone,
                 num_classes=2,  # background + 1 foreground class
                 rpn_anchor_generator=anchor_generator)
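
A quick smoke test of the assembled model could look like this (in eval mode the model returns one dict of boxes, labels, scores and masks per input image):

import torch

model.eval()
with torch.no_grad():
    out = model([torch.rand(3, 512, 512)])  # list of 3xHxW tensors
print(out[0].keys())  # dict_keys(['boxes', 'labels', 'scores', 'masks'])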



1 Like