no, you would have to redo the arch. But shouldn’t be that hard.
Anyway, for 3d data, it will be utterly expensive on VRAM.
So maybe start with a simple prototype.
- what size is your data?
I can leave you a simple Unet (non dynamic) that I built some time ago:
class MyUnetBlock(Module):
    "A quasi-UNet block, using `PixelShuffle_ICNR` upsampling."
    @delegates(ConvLayer.__init__)
    def __init__(self, up_in_c, x_in_c, final_div=True, blur=False, act_cls=defaults.activation,
                 self_attention=False, init=nn.init.kaiming_normal_, norm_type=None, **kwargs):
        # `up_in_c`: channels coming up from the previous decoder stage.
        # `x_in_c`:  channels of the encoder skip connection concatenated in `forward`.
        # Upsample 2x via pixel shuffle, halving the channel count.
        self.shuf = PixelShuffle_ICNR(up_in_c, up_in_c//2, blur=blur, act_cls=act_cls, norm_type=norm_type)
        self.bn = BatchNorm(x_in_c)  # normalize the skip connection before the concat
        ni = up_in_c//2 + x_in_c     # channels after concatenating upsampled path + skip
        # NOTE(review): fastai's UnetBlock uses `ni//2` in the else-branch; `x_in_c`
        # here is load-bearing for SimpleUnet's channel bookkeeping (its head expects
        # l_szs[0] + l_szs[1]//2 + n_in inputs) — confirm before changing.
        nf = ni if final_div else x_in_c
        self.conv1 = ConvLayer(ni, nf, act_cls=act_cls, norm_type=norm_type, **kwargs)
        self.conv2 = ConvLayer(nf, nf, act_cls=act_cls, norm_type=norm_type,
                               xtra=SA(nf) if self_attention else None, **kwargs)
        self.relu = act_cls()
        apply_init(nn.Sequential(self.conv1, self.conv2), init)

    def forward(self, up_in, s):
        """Upsample `up_in`, size-align it to the skip `s`, concat, double conv."""
        up_out = self.shuf(up_in)
        ssh = s.shape[-2:]
        if ssh != up_out.shape[-2:]:
            # Sizes can drift by a pixel after pooling/upsampling on odd inputs;
            # reuse the cached `ssh` instead of recomputing `s.shape[-2:]`.
            up_out = F.interpolate(up_out, ssh, mode='nearest')
        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv2(self.conv1(cat_x))
class SimpleUnet(Module):
    def __init__(self, n_in=3, n_classes=7, act_cls=defaults.activation,
                 norm_type=None, sa=False, l_szs=(32, 48, 96, 128)):
        "A simpler UNET torch.jit compatible"
        # `l_szs` default is a tuple (was a mutable list default; it is only read,
        # so this is behavior-compatible).
        # NOTE(review): `sa` is currently unused — self-attention is always enabled
        # on the deepest decoder block below; confirm intended behavior.
        # Stem: 4x spatial downsample (stride-2 conv + max-pool) to l_szs[0] channels.
        self.stem = nn.Sequential(ConvLayer(n_in, 16, ks=3, stride=2, act_cls=act_cls),
                                  ConvLayer(16, l_szs[0], ks=3, stride=1, act_cls=act_cls),
                                  nn.MaxPool2d(3, stride=2, padding=1))
        assert len(l_szs) == 4, '3 layers supported (4 values)'
        # Encoder: three res-blocks, each halving spatial size and widening channels.
        szs = zip(l_szs[:-1], l_szs[1:])
        self.l1, self.l2, self.l3 = [self.make_resblock(ni, nf) for (ni, nf) in szs]
        ni = l_szs[-1]
        # Bottleneck: expand to 2*ni channels and back down to ni.
        self.middle_conv = nn.Sequential(BatchNorm(ni), nn.ReLU(),
                                         ConvLayer(ni, ni * 2, act_cls=act_cls, norm_type=norm_type),
                                         ConvLayer(ni * 2, ni, act_cls=act_cls, norm_type=norm_type)
                                         )
        # Decoder: mirror the encoder widths; only the last block gets final_div=True,
        # self-attention only on the deepest block (i == 0).
        l_szs_r = l_szs[::-1]
        r_szs = zip(l_szs_r[:-1], l_szs_r[1:])
        unet_blocks = []
        for i, (ni, nf) in enumerate(r_szs):
            unet_blocks += [MyUnetBlock(ni, nf, final_div=(i == len(l_szs) - 2), act_cls=act_cls,
                                        norm_type=norm_type, self_attention=(i == 0))]
        self.u3, self.u2, self.u1 = unet_blocks
        # Head: `forward` concats the raw input (n_in channels) onto the decoder
        # output, hence l_szs[0] + l_szs[1]//2 + n_in input channels here.
        self.head = nn.Sequential(ResBlock(1, l_szs[0] + l_szs[1]//2 + n_in, 32, stride=1,
                                           act_cls=act_cls, norm_type=norm_type),
                                  ConvLayer(32, n_classes, ks=1, act_cls=None, norm_type=norm_type))

    @classmethod
    def unet_out_sz(cls, up_in_c, x_in_c, final_div):
        "Output channel count of a decoder block. Fix: `cls` was missing, so any call shifted every argument by one."
        ni = up_in_c//2 + x_in_c
        # NOTE(review): MyUnetBlock's else-branch uses `x_in_c`, not `ni//2` — this
        # helper disagrees with the actual blocks; confirm which formula is intended.
        nf = ni if final_div else ni//2
        return nf

    def make_resblock(self, ni, nf):
        "One encoder stage: a residual block at `ni` channels, then a stride-2 conv to `nf`."
        return nn.Sequential(ResBlock(1, ni, ni, stride=1), ConvLayer(ni, nf, stride=2))

    def forward(self, x):
        # encode
        out = self.stem(x)
        out1 = self.l1(out)
        out2 = self.l2(out1)
        out3 = self.l3(out2)
        # middle
        res = self.middle_conv(out3)
        # decode, fusing the matching encoder activation at each stage
        res = self.u3(res, out2)
        res = self.u2(res, out1)
        res = self.u1(res, out)
        # interp back to the input resolution (stem downsampled by 4)
        if res.shape[-2:] != x.shape[-2:]:
            res = F.interpolate(res, x.shape[-2:], mode='nearest')
        res = torch.cat([x, res], dim=1)
        res = self.head(res)
        return res
def simple_unet_split(m):
    "Parameter groups for `SimpleUnet` `m`: encoder layers first, then bottleneck/decoder/head."
    encoder = L(m.l1, m.l2, m.l3)
    decoder = L(m.middle_conv, m.u3, m.u2, m.u1, m.head)
    return list(encoder.map(params)) + list(decoder.map(params))
You would have to
- edit the hardcoded 2s (the 2D layers and the `//2` channel math) and replace them with their 3D equivalents.
- Add the 3d param to the ConvLayer
- replace the
nn.PixelShuffle
as it does not support 3d tensors