Hi,
I have been trying to train a fastai model on 3D data, typically a 3-dimensional NumPy array of float32 values.
There has been some past effort to implement this in the following package, which uses fastai v1:
https://github.com/rosalindfranklininstitute/UnetSegmentation
To train the 2D fastai model, we currently extract slices from the data and save them to a specified folder. The DataLoaders object is then set up to point to these slice image files. We want to avoid creating all these slice files and instead build DataLoaders that generate the slices directly from the 3D data and labels.
I would also like to upgrade to fastai v2 and train directly from the 3D data held in RAM rather than from generated files.
Below is how far I got, but I am having a few problems. I hope you can give me some advice on how to improve it:
- In this example I split the data into training and validation datasets manually. Is there a way to do this with the DataBlock class, which should handle the split automatically?
- show_batch() does not work. I think it is because the dataset images are not in PIL image format.
- Training seems to work, but prediction does not, and I'm not sure why.
- Is it possible to work with the images in float format rather than converting them to uint8 (0–255) and 3-channel RGB? That seems to be the only way it works, at least to some extent.
Thank you for reading and for any help.
import numpy as np
import h5py
# Load the raw 3D volume. Open explicitly read-only: older h5py versions
# defaulted to append mode, which can modify or lock the source file.
with h5py.File("data.h5", "r") as f:
    data = f["data"][()]

# Min-max rescale the whole volume linearly to 0-255 and store as uint8.
# This is a rescale, not a clip. A constant volume (max == min) would divide
# by zero, so it maps to all zeros instead.
data_min, data_max = data.min(), data.max()
data_range = data_max - data_min
if data_range == 0:
    data_uint8 = np.zeros(data.shape, dtype=np.uint8)
else:
    data_uint8 = ((data - data_min) / data_range * 255).astype(np.uint8)

# Load the labels (stored under the same "data" key in their own file).
with h5py.File("labels.h5", "r") as f:
    labels = f["data"][()]
# The mask is stored with value 255 (any value > 0) for foreground rather
# than 1, so binarise it to 0/1 class indices.
labels = (labels > 0).astype(np.uint8)
import torch
from torch.utils.data import Dataset
class SegmentationDataset3D_toZslices(Dataset):
    """Torch Dataset that serves the z-slices of a 3D volume as 2D samples.

    Item ``idx`` is ``(image, label)``:

    - ``image`` is slice ``data[idx]`` repeated along a new leading channel
      axis, giving shape ``(n_channels, H, W)``. The default of 3 matches the
      observation that ``unet_learner`` with a pretrained encoder expects RGB
      input.
    - ``label`` is ``labels[idx]``.

    Args:
        data: 3D array indexed as ``data[z, y, x]``.
        labels: 3D array with the same number of slices as ``data``.
        transforms: optional callable applied to the image and to the label
            *separately* (two independent calls). It must therefore be
            deterministic; a random flip/crop here would desynchronise image
            and mask. Use ``joint_transforms`` for random augmentation.
        n_channels: how many times the slice is repeated along the channel
            axis. Presumably use 1 together with
            ``unet_learner(..., n_in=1)`` to train on single-channel input.
        joint_transforms: optional callable ``(image, label) -> (image, label)``
            applied once to the pair, after ``transforms``.

    Raises:
        ValueError: if ``data`` and ``labels`` have a different number of
            slices, or if ``n_channels < 1``.
    """

    def __init__(self, data, labels, transforms=None, n_channels=3,
                 joint_transforms=None):
        if len(data) != len(labels):
            raise ValueError(
                f"data has {len(data)} slices but labels has {len(labels)}"
            )
        if n_channels < 1:
            raise ValueError(f"n_channels must be >= 1, got {n_channels}")
        self.data = data
        self.labels = labels
        self.transforms = transforms
        self.n_channels = n_channels
        self.joint_transforms = joint_transforms

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        image = self.data[idx, :, :]
        label = self.labels[idx, :, :]
        if self.transforms is not None:
            image = self.transforms(image)
            label = self.transforms(label)
        if self.joint_transforms is not None:
            # Applied once to both arrays so random parameters stay in sync.
            image, label = self.joint_transforms(image, label)
        # Replicate the single-channel slice: (H, W) -> (n_channels, H, W).
        image = np.repeat(np.asarray(image)[np.newaxis], self.n_channels, axis=0)
        return image, label
# Split the z-slices into training and validation sets: a random ~20% of the
# slices (at least one) are held out for validation. A seeded NumPy generator
# makes the split reproducible between runs. (With a fastai DataBlock the
# same split is available via RandomSplitter(valid_pct=0.2, seed=...).)
n_slices = data.shape[0]
n_valid = max(1, n_slices // 5)  # N // 5 is 0 for < 5 slices -> empty valid set
rng = np.random.default_rng(seed=42)
shuffled_ind = rng.permutation(n_slices)
list_valid_ind = shuffled_ind[:n_valid]
list_train_ind = shuffled_ind[n_valid:]

data_train = data[list_train_ind, :, :]
labels_train = labels[list_train_ind, :, :]
data_valid = data[list_valid_ind, :, :]
labels_valid = labels[list_valid_ind, :, :]

my_segm_dset_train = SegmentationDataset3D_toZslices(data_train, labels_train)
my_segm_dset_valid = SegmentationDataset3D_toZslices(data_valid, labels_valid)
from fastai.vision.all import *
# Wrap the training and validation torch Datasets in fastai DataLoaders
# (batch size 8); training therefore sees the raw `data` values, not data_uint8.
my_dls = SegmentationDataLoaders.from_dsets(my_segm_dset_train, my_segm_dset_valid, bs=8) # train dataset, then validation dataset
# NOTE(review): show_batch presumably fails because the datasets yield plain
# numpy arrays rather than fastai-typed TensorImage/TensorMask items, so there
# is no type-dispatched show method for them -- TODO confirm. When run as a
# plain script (not notebook cells), the exception here stops execution
# before training starts.
my_dls.show_batch() #Does not work
#learn = unet_learner(my_dls, resnet18, n_out=2)
# U-Net with a resnet18 encoder and 2 output channels (one per label value
# 0/1), trained with Dice loss and reporting the Dice metric.
learn = unet_learner(my_dls, resnet18, n_out=2, loss_func=DiceLoss(),metrics=[Dice()])
learn.fine_tune(10)
# Runs, perhaps a bit slow, not sure if it actually using GPU
# To check the device, inspect my_dls.device or next(learn.model.parameters()).device.
# Dice=0.977
# Predict one z-slice and show it next to the image and the ground truth.
#
# The model input must be prepared exactly like the training items. The
# DataLoaders above were built from slices of the raw float `data` volume,
# replicated to 3 channels. The earlier attempt fed the 0-255 rescaled
# `data_uint8` copy instead, an intensity range the model never saw during
# training, which is a likely cause of the odd prediction.
import matplotlib.pyplot as plt


def predict_slice(learner, img2d, n_channels=3):
    """Return the predicted (H, W) class-index mask for one 2D slice.

    Runs ``learner.model`` directly, on the model's own device, instead of
    ``Learner.predict``. The DataLoaders were built with ``from_dsets`` from
    plain torch Datasets, so fastai's predict/decode path presumably cannot
    rebuild a batch from a raw numpy slice.

    NOTE(review): if batch transforms (e.g. Normalize) were ever applied to
    training batches, apply the same ones here.
    """
    model = learner.model
    device = next(model.parameters()).device
    x = np.repeat(np.asarray(img2d, dtype=np.float32)[np.newaxis], n_channels, axis=0)
    x = torch.from_numpy(x)[None].to(device)  # add batch dim -> (1, C, H, W)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(x)  # assumes (1, n_out, H, W) logits -- confirm
    finally:
        model.train(was_training)  # leave the model in the mode we found it
    return logits.argmax(dim=1)[0].cpu().numpy()


z_index = 100
pred_mask = predict_slice(learn, data[z_index])

fig, ax = plt.subplots(ncols=3, figsize=(10, 5))
panels = (data_uint8[z_index], labels[z_index], pred_mask)
for axis, panel, title in zip(ax, panels, ("image", "label", "prediction")):
    axis.imshow(panel)
    axis.set_title(title)
    axis.set_axis_off()