Semantic Segmentation of 3D stacks with Dynamic Unet


I am new to this community and to fast.ai. I am working on a semantic segmentation problem: segmenting highly dense clumps of cells from 3D microscopy data, and I am tempted to use the Dynamic Unet in fast.ai. Is it possible to feed 3D data to the Dynamic Unet?

If not, is it possible to modify the source code?



No, you would have to redo the architecture, but that shouldn’t be too hard.
Be aware that 3D data will be extremely expensive in terms of VRAM,
so it may be best to start with a simple prototype.

  • What size is your data?
    I can leave you a simple (non-dynamic) Unet that I built some time ago:
class MyUnetBlock(Module):
    """A quasi-UNet block, using `PixelShuffle_ICNR upsampling`.

    The block upsamples the decoder feature map `up_in`, concatenates it with
    the batch-normalised skip connection `s` along the channel axis, applies
    the activation and then two `ConvLayer`s.

    Parameters
    ----------
    up_in_c : number of channels of the tensor coming from the layer below.
    x_in_c : number of channels of the skip-connection tensor.
    final_div : if True the block outputs `up_in_c//2 + x_in_c` channels,
        otherwise it outputs `x_in_c` channels.
    blur, act_cls, norm_type : forwarded to `PixelShuffle_ICNR` / `ConvLayer`.
    self_attention : if True, `SA(nf)` is passed as `xtra` to the second conv.
    init : initialiser applied to both convolutions through `apply_init`.
    **kwargs : forwarded to both `ConvLayer`s.
    """
    # NOTE(review): no explicit `super().__init__()` call here -- presumably
    # `Module` is fastai's `Module`, which does that automatically; confirm.
    def __init__(self, up_in_c, x_in_c, final_div=True, blur=False, act_cls=defaults.activation,
                 self_attention=False, init=nn.init.kaiming_normal_, norm_type=None, **kwargs):
        # Upsampling halves the channel count of the incoming tensor.
        self.shuf = PixelShuffle_ICNR(up_in_c, up_in_c//2, blur=blur, act_cls=act_cls, norm_type=norm_type)
        # Normalises the skip connection before it is concatenated in `forward`.
        self.bn = BatchNorm(x_in_c)
        ni = up_in_c//2 + x_in_c
        nf = ni if final_div else x_in_c
        self.conv1 = ConvLayer(ni, nf, act_cls=act_cls, norm_type=norm_type, **kwargs)
        self.conv2 = ConvLayer(nf, nf, act_cls=act_cls, norm_type=norm_type,
                               xtra=SA(nf) if self_attention else None, **kwargs)
        self.relu = act_cls()
        apply_init(nn.Sequential(self.conv1, self.conv2), init)

    def forward(self, up_in, s):
        """Merge `up_in` (from the layer below) with the skip connection `s`."""
        up_out = self.shuf(up_in)
        ssh = s.shape[-2:]
        # The upsampled map can differ from the skip connection in its last
        # two dimensions; resize it so the two can be concatenated.
        if ssh != up_out.shape[-2:]:
            up_out = F.interpolate(up_out, ssh, mode='nearest')
        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv2(self.conv1(cat_x))
class SimpleUnet(Module):
    """A small fixed-depth U-Net: a stem, three down stages, a middle block,
    three `MyUnetBlock` up stages and a head.

    Parameters
    ----------
    n_in : number of input channels.
    n_classes : number of output channels produced by the head.
    act_cls : activation class passed to the layers.
    norm_type : normalisation type passed to the middle, up and head layers.
    sa : currently unused.
        NOTE(review): self-attention is always enabled on the first up block
        (`self_attention=(i==0)`) regardless of this flag -- confirm intent.
    l_szs : four channel sizes; consecutive pairs define the three down
        stages and, reversed, the three up stages. The list is only read,
        never mutated.
    """
    def __init__(self, n_in=3, n_classes=7, act_cls=defaults.activation,
                 norm_type=None, sa=False, l_szs = [32,48,96,128]):
        "A simpler UNET torch.jit compatible"
        # Validate before any layer is built from `l_szs`.
        assert len(l_szs)==4, '3 layers supported (4 values)'
        self.stem = nn.Sequential(ConvLayer(n_in, 16, ks=3, stride=2, act_cls=act_cls),
                                  ConvLayer(16, l_szs[0], ks=3, stride=1, act_cls=act_cls),
                                  nn.MaxPool2d(3, stride=2, padding=1))
        # Encoder: (l_szs[0]->l_szs[1]), (l_szs[1]->l_szs[2]), (l_szs[2]->l_szs[3]).
        szs = zip(l_szs[0:-1], l_szs[1:])
        self.l1, self.l2, self.l3 = [self.make_resblock(ni,nf) for (ni,nf) in szs]
        ni = l_szs[-1]
        self.middle_conv = nn.Sequential(BatchNorm(ni), nn.ReLU(),
            ConvLayer(ni, ni * 2, act_cls=act_cls, norm_type=norm_type),
            ConvLayer(ni * 2, ni, act_cls=act_cls, norm_type=norm_type))
        # Decoder: walk the channel sizes in reverse order.
        l_szs_r = l_szs[::-1]
        r_szs = zip(l_szs_r[0:-1], l_szs_r[1:])
        unet_blocks = []
        for i, (ni, nf) in enumerate(r_szs):
            # Only the last up block keeps the full concatenated width
            # (final_div=True); the first one gets self-attention.
            unet_blocks += [MyUnetBlock(ni, nf, final_div=(i==len(l_szs)-2), act_cls=act_cls,
                                       norm_type=norm_type, self_attention=(i==0))]
        self.u3, self.u2, self.u1 = unet_blocks
        # Head input = output of `u1` (l_szs[0] + l_szs[1]//2 channels)
        # concatenated with the raw input (`n_in` channels) in `forward`.
        self.head = nn.Sequential(ResBlock(1, l_szs[0]+l_szs[1]//2+n_in, 32, stride=1),
                                  ConvLayer(32, n_classes, ks=1, act_cls=None, norm_type=norm_type))

    @staticmethod
    def unet_out_sz(up_in_c, x_in_c, final_div):
        """Return `up_in_c//2 + x_in_c`, halved when `final_div` is False.

        NOTE(review): not called anywhere in this class, and for
        `final_div=False` this gives `ni//2` whereas `MyUnetBlock` uses
        `x_in_c` -- confirm which one is intended.
        """
        ni = up_in_c//2 + x_in_c
        nf = ni if final_div else ni//2
        return nf

    def make_resblock(self, ni, nf):
        """Build one down stage: a `ResBlock` at `ni` channels followed by a
        stride-2 `ConvLayer` going from `ni` to `nf` channels."""
        return nn.Sequential(ResBlock(1, ni, ni, stride=1), ConvLayer(ni, nf, stride=2))

    def forward(self, x):
        """Run the encoder, the middle block and the decoder, then the head."""
        out = self.stem(x)
        out1 = self.l1(out)
        out2 = self.l2(out1)
        out3 = self.l3(out2)
        res = self.middle_conv(out3)
        # Each up block receives the encoder output one level above it.
        res = self.u3(res, out2)
        res = self.u2(res, out1)
        res = self.u1(res, out)
        # Bring the decoder output back to the input's last-two-dims size so
        # it can be concatenated with the raw input.
        if res.shape[-2:] != x.shape[-2:]:
            res = F.interpolate(res, x.shape[-2:], mode='nearest')
        res = torch.cat([x, res], dim=1)
        res = self.head(res)
        return res
def simple_unet_split(m):
    """Split `m` into parameter groups, one group per sub-module.

    The result lists the `params` of the three down stages first, then those
    of the middle block, the three up blocks and the head, in that order.
    NOTE(review): `m.stem` is not part of any group -- confirm that this is
    intended.
    """
    down_stages = (m.l1, m.l2, m.l3)
    up_stages = (m.middle_conv, m.u3, m.u2, m.u1, m.head)
    down_groups = L(*down_stages).map(params)
    up_groups = L(*up_stages).map(params)
    groups = list(down_groups)
    groups.extend(up_groups)
    return groups

You would have to

  • Replace the hardcoded 2’s with 3’s (e.g. `nn.MaxPool2d` → `nn.MaxPool3d`, `shape[-2:]` → `shape[-3:]`).
  • Add the 3D param to the ConvLayer (`ndim=3`).
  • Replace the nn.PixelShuffle, as it does not support 3D tensors.

Were you able to make it work?