Fastai for 3D data, 2d slicing, with custom dataloader


I have been trying to train a fastai model on 3D data, typically a 3-dimensional NumPy array of float32 values.

There has been some past effort to implement this in this package, which uses fastai1.

To train the 2D fastai model, we currently extract slices from the data and save them in a specified folder. The dataloaders object is then set up to point to the slice image files. We wanted to avoid writing out all these slice files and instead create dataloader(s) that produce the slices directly from the 3D data and labels.

I would also like to upgrade to fastai2 and run training on the 3D data directly from RAM rather than from generated files.

This is how far I got (see below), but I am having a few problems. I hope you can give me some advice on how to improve it.

  1. In this example the data is split by hand into train and validation datasets. Is there a way to do the same with the DataBlock class, which should do it automatically?
  2. show_batch() is not working. I think it is because the dataset images are not in PIL image format.
  3. Training seems to be working, but prediction is not, and I’m not sure why.
  4. Is it possible to work with the images in float format rather than converting to uint8 and 3-channel RGB (this seems to be the only way it will work to some extent)?

Thank you for reading and for any help.

import numpy as np
import h5py

# Load data
with h5py.File("data.h5") as f:
    data= np.array(f['data'])

# Convert to uint8, rescaling to the 0-255 range

data_uint8 = ((data- data.min())/(data.max()-data.min())*255).astype(np.uint8)

# Load labels
with h5py.File("labels.h5") as f:
    labels= np.array(f['data'])

# fix labels as they have value 255 for mask, rather than 1
labels= np.where(labels>0, 1, 0).astype(np.uint8)

import torch
from torch.utils.data import Dataset

class SegmentationDataset3D_toZslices(Dataset):
    def __init__(self, data, labels,transforms=None):
        self.data = data
        self.labels = labels
        self.transforms = transforms

    def __len__(self):
        # One item per z-slice
        return self.data.shape[0]

    def __getitem__(self, idx):
        image = self.data[idx,:,:]
        labels = self.labels[idx,:,:]

        if self.transforms:
            image = self.transforms(image)
            labels = self.transforms(labels)
        imgRGB = np.stack([image,image,image], axis=0) # unet learner seems to only accept RGB
        # img_torch = torch.from_numpy(image).float() 
        # lbl_torch = torch.from_numpy(labels).int()
        # return img_torch, lbl_torch
        return imgRGB, labels

#We need train and validation. Split it here into 2 datasets
N = data.shape[0] # number of z-slices
N0 = N//5 # 20% of the slices go to validation
list_z_ind = list(range(data.shape[0]))
import random
random.shuffle(list_z_ind) #shuffle inplace
list_train_ind = list_z_ind[N0:]
list_valid_ind = list_z_ind[0:N0]

data_train = data[list_train_ind,:,:]
labels_train = labels[list_train_ind,:,:]
data_valid= data[list_valid_ind,:,:]
labels_valid = labels[list_valid_ind,:,:]

my_segm_dset_train = SegmentationDataset3D_toZslices(data_train,labels_train)
my_segm_dset_valid = SegmentationDataset3D_toZslices(data_valid,labels_valid)

from fastai.vision.all import *

my_dls = SegmentationDataLoaders.from_dsets(my_segm_dset_train, my_segm_dset_valid, bs=8) # make data also the validation

my_dls.show_batch() #Does not work

#learn = unet_learner(my_dls, resnet18, n_out=2)
learn = unet_learner(my_dls, resnet18, n_out=2, loss_func=DiceLoss(),metrics=[Dice()])

# Runs, perhaps a bit slow, not sure if it is actually using the GPU
# Dice=0.977

# Now try to run prediction on the full stack, slice-by-slice

pred0 = learn.predict(np.stack([data_uint8[100,:,:], data_uint8[100,:,:],data_uint8[100,:,:]]))
# RuntimeError: expected scalar type Byte but found Float

d0 = np.expand_dims(np.stack([data_uint8[100,:,:], data_uint8[100,:,:],data_uint8[100,:,:]], axis=0).astype(np.float32), axis=0)
# For some reason I need to provide a 4dim array to run a prediction on a 2D image.
# But if I train by setting up learner+dataloaders with SegmentationDataLoaders.from_label_func(), using tiff files of the slices
# and then run the prediction just by passing the 2D image in numpy format, it works ok.

pred0 = learn.predict(d0)
# Runs prediction but result is weird. There is a small segmented region in the top-left region

import matplotlib.pyplot as plt
fig,ax = plt.subplots(ncols=3, figsize=(10,5))


Hi Luis, it’s possible to use a UNet with 1-channel images. To do so, you can use create_unet_model() (see the fastai Vision learner docs) to create a custom UNet architecture and specify n_in=1.
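
As a rough illustration of what a 1-channel input looks like, here is a minimal NumPy sketch that turns one z-slice of a float32 volume into the (batch, channel, height, width) layout that a 2D convolutional network with a single input channel takes. The helper name slice_to_input and the random stand-in volume are hypothetical, and standardizing by the volume mean and std is just one assumed normalization choice, not something fastai requires.

```python
import numpy as np

def slice_to_input(volume, z, eps=1e-8):
    """Return z-slice `z` of a 3D volume as a float32 array of shape (1, 1, H, W)."""
    # Standardize with statistics of the whole volume so all slices share one scale
    # (an assumed normalization choice, not something fastai requires).
    mean, std = volume.mean(), volume.std()
    img = (volume[z].astype(np.float32) - mean) / (std + eps)
    # Add the channel axis (1 channel, no RGB stacking) and the batch axis.
    return img[np.newaxis, np.newaxis, :, :].astype(np.float32)

# Stand-in for the real data: 16 slices of 64x48 float32 values.
rng = np.random.default_rng(0)
volume = rng.normal(loc=100.0, scale=20.0, size=(16, 64, 48)).astype(np.float32)

x = slice_to_input(volume, z=5)
print(x.shape, x.dtype)
```

A 2D convolutional model works on batches of shape (N, C, H, W), so a single 2D slice needs both a channel axis and a batch axis before it is passed to the network directly.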

When creating a DataBlock, you can pass splitter=RandomSplitter(valid_pct=...), where valid_pct is the fraction of the dataset to put in the validation set.
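
To make the splitter idea concrete, here is a minimal NumPy sketch of a random split over z-slice indices by valid_pct. It is not fastai's implementation: the function name random_split_indices is hypothetical and only mirrors what a random splitter does conceptually, which is to shuffle the indices and hold out a fraction for validation. The zero-filled arrays are stand-ins for the real data and labels.

```python
import numpy as np

def random_split_indices(n_items, valid_pct=0.2, seed=None):
    """Shuffle range(n_items) and return (train_idx, valid_idx)."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_items)
    cut = int(valid_pct * n_items)  # number of validation items
    return perm[cut:], perm[:cut]

# Stand-in volume and labels: 50 z-slices of 32x32.
data = np.zeros((50, 32, 32), dtype=np.float32)
labels = np.zeros((50, 32, 32), dtype=np.uint8)

train_idx, valid_idx = random_split_indices(data.shape[0], valid_pct=0.2, seed=42)
data_train, labels_train = data[train_idx], labels[train_idx]
data_valid, labels_valid = data[valid_idx], labels[valid_idx]
print(len(train_idx), len(valid_idx))
```

Indexing the data and labels arrays with the same index arrays keeps each slice paired with its mask, which is what the manual random.shuffle split in the question does by hand.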

It may be useful to work with the MONAI library if you want to do 3D segmentation directly. For example, to construct a 3D UNet, you can use the constructor documented under Network architectures in the MONAI 1.3.0 documentation, specifying spatial_dims=3.