Support for 3D images in fastai and options for image generators similar to Keras

I created a 3D network (PyTorch) for classification. Is it possible to use fastai with a 3D image generator? I want to use a part of the whole 3D image and create a sliding-window image generator instead of using the whole 3D image.

You can do anything in fastai.
This will require you to write a custom ItemList in the data block API. See this tutorial or check the source code of ImageItemList.
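For the sliding-window part of the question, here is a minimal sketch in plain PyTorch. It is not the fastai ItemList API: `SlidingWindow3D` is a hypothetical name, the volume is assumed to be a `[C, D, H, W]` tensor, and windows that would run past an edge are simply skipped (no padding). Wrapping something like this in a custom ItemList, as suggested above, is not shown.

```python
import itertools

import torch
from torch.utils.data import Dataset


class SlidingWindow3D(Dataset):
    """Yield fixed-size sub-volumes of one 3D volume, scanned with a given stride.

    Hypothetical helper, not part of fastai: `volume` is a [C, D, H, W] tensor,
    `window` and `stride` are (depth, height, width) tuples.
    """

    def __init__(self, volume, window, stride):
        self.volume, self.window = volume, window
        spatial = volume.shape[1:]
        # Valid start indices along each spatial axis; windows that would
        # run past the edge are skipped (no padding).
        starts = [range(0, size - win + 1, st)
                  for size, win, st in zip(spatial, window, stride)]
        # Every combination of (d, h, w) start positions is one window.
        self.origins = list(itertools.product(*starts))

    def __len__(self):
        return len(self.origins)

    def __getitem__(self, idx):
        d, h, w = self.origins[idx]
        wd, wh, ww = self.window
        return self.volume[:, d:d + wd, h:h + wh, w:w + ww]


vol = torch.randn(1, 64, 128, 128)  # [C, D, H, W]
ds = SlidingWindow3D(vol, window=(32, 64, 64), stride=(16, 32, 32))
print(len(ds), ds[0].shape)  # 27 torch.Size([1, 32, 64, 64])
```

A stride smaller than the window gives overlapping patches; a stride equal to the window tiles the volume without overlap.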


Sure, will look into it. Thank you.

Hi @sgugger,
I’ve been working for a while on using fastai for MRI image segmentation, following the tutorial.
A single MRI image is a stack of 2D images (slices), so there is no one-to-one association between file and image: one file holds N images.

So I built an “MRI2DSlice” item class that encapsulates a list of fastai Image objects, giving MRI2DSlice the shape [number_slices, 3, height, width]. To show an MRI2DSlice object, I used the to_one method from the tutorial (ImageTupleList) to stack all slices into one Image object, so that its show method can be used.
It also reimplements the whole Image/ItemBase interface
by applying each method to every slice (Image object) individually.
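A minimal sketch of the two shape manipulations described above, assuming a grayscale volume stored as an `[N, H, W]` tensor. `volume_to_slices` and `slices_to_one` are hypothetical helper names; the second follows the idea of the tutorial's `to_one` by placing the slices side by side along the width, so that a single image can be shown.

```python
import torch


def volume_to_slices(volume):
    """[N, H, W] grayscale volume -> [N, 3, H, W] stack of 3-channel slices.

    Hypothetical helper: the single channel is repeated so that each slice
    matches the 3-channel input of a typical pretrained 2D encoder.
    """
    return volume.unsqueeze(1).repeat(1, 3, 1, 1)


def slices_to_one(slices):
    """[N, C, H, W] -> [C, H, N*W]: all slices concatenated along the width."""
    return torch.cat(list(slices), dim=2)


vol = torch.rand(10, 100, 100)
slices = volume_to_slices(vol)
one = slices_to_one(slices)
print(slices.shape)  # torch.Size([10, 3, 100, 100])
print(one.shape)     # torch.Size([3, 100, 1000])
```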

I also want collate_fn not to add an outer dimension for the batch size,
but instead to concatenate along the first dimension, so that it holds the total number of slices of all MRI2DSlice objects
in the batch. The output shape is then $[\sum_{i=1}^{bs} n_i, 3, H, W]$, where $n_i$ is the number of slices of the $i$-th MRI and $bs$ is the batch size.
In effect there is an internal batch size (total slices) and a user-defined one (number of MRIs). This shape is suitable for unet_learner.
However, when I want to show the results, this mismatch between the internal and external batch sizes causes problems.
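A minimal sketch of such a collate function, assuming each item is an `(x, y)` pair of tensors shaped `[n_slices_i, 3, H, W]` and `[n_slices_i, 1, H, W]`. `concat_slices_collate` is a hypothetical name; it uses `torch.cat` along dim 0 where a default collate would stack items into a new outer dimension, which is exactly why the first dimension stops matching the number of MRIs.

```python
import torch


def concat_slices_collate(batch):
    """Collate (x, y) pairs by concatenating along dim 0 instead of stacking.

    Hypothetical sketch: each x is [n_slices_i, 3, H, W] and each y is
    [n_slices_i, 1, H, W]; the result has sum(n_slices_i) rows.
    """
    xs, ys = zip(*batch)
    return torch.cat(xs, dim=0), torch.cat(ys, dim=0)


batch = [(torch.rand(10, 3, 100, 100), torch.rand(10, 1, 100, 100)),
         (torch.rand(11, 3, 100, 100), torch.rand(11, 1, 100, 100))]
x, y = concat_slices_collate(batch)
print(x.shape, y.shape)  # torch.Size([21, 3, 100, 100]) torch.Size([21, 1, 100, 100])
```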

For example, with batch size = 2 and the following MRI2DSlice objects:
mri1: shape [10, 3, 100, 100], mri2: shape [11, 3, 100, 100]
label_mri1: shape [10, 1, 100, 100], label_mri2: shape [11, 1, 100, 100]
After collate_fn, we have:
mri_batch_tensor : shape [21, 3, 100, 100]
mri_label_tensor: shape [21, 1, 100, 100]
After training, the input, prediction and target tensors have the same shapes as above.
When reconstructed, each input item has shape [3, 100, 100] and each target/prediction item has shape [1, 100, 100], instead of the shapes of mri1/label_mri1.

So I need a way to keep track of the number of slices of each MRI2DSlice object in the batch to reconstruct them correctly, but that would require modifying the learner.
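One way to keep track of the slice counts, sketched under the assumption that the counts can be carried alongside the batch (how to thread them through the fastai learner is the open question here and is not shown): record each item's first dimension before concatenating, then call `torch.split` with that list of sizes to recover the per-MRI tensors. `split_by_counts` is a hypothetical helper.

```python
import torch


def split_by_counts(flat, counts):
    """Undo a dim-0 concatenation: [sum(counts), ...] -> tuple of [counts[i], ...]."""
    return torch.split(flat, counts, dim=0)


# Same shapes as the example above: two MRIs with 10 and 11 slices.
mri1 = torch.rand(10, 3, 100, 100)
mri2 = torch.rand(11, 3, 100, 100)
counts = [t.shape[0] for t in (mri1, mri2)]  # recorded before concatenating
flat = torch.cat([mri1, mri2], dim=0)        # [21, 3, 100, 100], what the model sees
parts = split_by_counts(flat, counts)
print(counts, [tuple(p.shape) for p in parts])
# [10, 11] [(10, 3, 100, 100), (11, 3, 100, 100)]
```

The same counts would split the model's predictions back into one `[n_i, 1, H, W]` tensor per MRI, even when the slice counts differ within a batch.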

My latest solution:
It uses callbacks, so I no longer use a custom collate_fn. I wrote this:

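import torch
from fastai.callback import Callback

# Hypothetical stand-in for the reshape_tensors helper (its code wasn't posted):
# regroup [bs * n_slices, ...] into [bs, n_slices, ...].
def reshape_tensors(t, bs):
    return t.reshape(bs, -1, *t.shape[1:])

class AggregateSlicesCB(Callback):
    def on_batch_begin(self, **kwargs):
        xb = kwargs['last_input']  # shape [batch_size, n_slices, 3, H, W]
        self.bs = len(xb)
        xb = torch.cat([*xb])      # -> [batch_size * n_slices, 3, H, W]
        return xb, kwargs['last_target']

    def on_loss_begin(self, **kwargs):
        out = reshape_tensors(kwargs['last_output'], self.bs)
        # Rebinding a key of the local kwargs dict does not update the
        # CallbackHandler state, so last_input stays unchanged there.
        kwargs['last_input'] = reshape_tensors(kwargs['last_input'], self.bs)
        return out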


Now the slices are aggregated in on_batch_begin for the model to consume. Note: this only works when all MRI2DSlice objects in the batch have the same number of slices; for corner cases such as the example above, I have no solution yet.
Then I reshape the model output using the batch size so that it matches the target's shape before the loss is computed. Along the way I adjusted the axis value of the loss function, using 2 instead of 1 or 0.
It then trains correctly, but again, when showing the results I noticed that it reads the last input and output from the CallbackHandler state dict. Since I transformed the input in on_batch_begin, the shapes mismatch again. I tried to reset last_input in on_loss_begin, but only the output gets updated.

My current solution is to patch the RecordOnCPU callback directly so that it reshapes the input after reading it from the CallbackHandler state dict. So I’m asking: why not allow us to update any key of the state dict whenever we want? It seems too restrictive.

Sorry for the long post. I wanted it to be clear enough for you to suggest some ideas, or completely different workarounds to reach the same goal.
Finally, bravo on the flexibility of the framework; it is easy to navigate through it.

Thank you so much.

I have decided against it because it wouldn’t catch typos. For instance, if you type kwargs['last_ipnut'] = ..., nothing in your code would fail, which would make debugging hard. I know this lacks a bit of flexibility; maybe we should add an argument somewhere, possibly an attribute of the learner, to specify other authorized keys. I didn’t get which key you wanted to add, though?
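The trade-off described above can be illustrated with a toy state dict that only accepts known keys. This is a hypothetical sketch, not fastai's CallbackHandler code: `CallbackState`, its abbreviated key list and the `extra_keys` argument (standing in for the suggested learner attribute) are all assumptions.

```python
class CallbackState:
    """Toy callback state that rejects updates to unknown keys.

    Hypothetical sketch: a typo such as 'last_ipnut' raises instead of being
    silently ignored; `extra_keys` lets the user authorize additional keys.
    """

    BASE_KEYS = ("last_input", "last_output", "last_target", "last_loss")

    def __init__(self, extra_keys=()):
        self.state = {key: None for key in (*self.BASE_KEYS, *extra_keys)}

    def update(self, new):
        # A callback may return None (no change) or a dict of updates.
        for key, value in (new or {}).items():
            if key not in self.state:
                raise KeyError(f"{key} isn't a valid key in the callback state")
            self.state[key] = value


state = CallbackState(extra_keys=("n_slices",))
state.update({"last_input": "xb", "n_slices": [10, 11]})
try:
    state.update({"last_ipnut": "typo"})
except KeyError as err:
    print("rejected:", err)
```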

Yes, that would be great indeed!

“last_input”.
And I think only “last_input”, “last_output” and “last_target” deserve this flexibility, as they move between the data block API side and the training loop side, while the others are more specific and only used within the training loop.

I don’t understand, last_input is among the keys you can change.

What I do now is modify last_input, but its new shape is unsuitable for display, so I want to restore it somewhere other than on_batch_begin… But now I understand the problem better. In my first post I wanted to restore last_input in on_loss_begin. That means that, even if it were possible, the error would persist, because RecordOnCPU reads last_input in on_batch_begin, not later in on_loss_begin. So I don’t know how to solve this problem.
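The ordering problem can be reproduced with a toy, framework-free simulation. `Recorder` stands in for a recorder like RecordOnCPU that snapshots `last_input` at batch begin, and `Reshaper` for a callback like AggregateSlicesCB; both, and `run_batch`, are hypothetical and only assume that callbacks run in list order. Whatever the recorder copies in on_batch_begin is what it keeps, so restoring `last_input` later in on_loss_begin does not reach it.

```python
class Recorder:
    """Stands in for a recorder that snapshots last_input at batch begin."""

    def on_batch_begin(self, state):
        self.input = state["last_input"]


class Reshaper:
    """Stands in for a callback that changes last_input for the model."""

    def on_batch_begin(self, state):
        state["last_input"] = "flattened"

    def on_loss_begin(self, state):
        state["last_input"] = "restored"  # too late for the recorder


def run_batch(callbacks, xb):
    """Toy training step: callbacks run in list order at each event."""
    state = {"last_input": xb}
    for cb in callbacks:
        cb.on_batch_begin(state)
    for cb in callbacks:
        if hasattr(cb, "on_loss_begin"):
            cb.on_loss_begin(state)
    return state


rec = Recorder()
final = run_batch([Reshaper(), rec], "original")
print(rec.input, final["last_input"])  # flattened restored
```

In this toy model, putting the recorder before the reshaping callback makes it keep the original input instead.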