Custom data loader for 3D data

I’m interested in using the fastai library to learn a regression function from the intensities of one MR image to another MR or CT image (example paper: Deep MR to CT Synthesis using Unpaired Data). That is, I want to load in a 3D volume, pass it through a neural network, where the target is another 3D volume of the same dimension/shape, minimizing some loss like L2-norm.
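
For reference, an "L2-norm" objective for this kind of volume-to-volume regression is usually the voxelwise mean squared error: the squared L2 norm of the difference divided by the number of voxels N, i.e. MSE = (1/N) Σᵢ (ŷᵢ - yᵢ)². Below is a minimal NumPy sketch of that formula. The helper name `voxelwise_mse` is hypothetical. In PyTorch you would normally use `torch.nn.MSELoss` or `torch.nn.functional.mse_loss` instead.

```python
import numpy as np


def voxelwise_mse(pred: np.ndarray, target: np.ndarray) -> float:
    """Mean squared error over every voxel of two same-shaped volumes."""
    if pred.shape != target.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {target.shape}")
    # Promote to float64 so integer-typed volumes don't overflow when squared
    diff = pred.astype(np.float64) - target.astype(np.float64)
    return float(np.mean(diff ** 2))
```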

I can do this in TensorFlow, PyTorch, or Keras, but this library has additional functionality and testing/visualization capabilities that seem great. However, I am having trouble figuring out how to set up a dataset loader appropriate for the files I would like to load, specifically NIfTI (.nii) files. I believe I just need to set up a custom dataset loader that uses nibabel to open the NIfTI files. Can someone point me to the right function/class to modify? I see the ImageDataset class and an example of how to modify it, but I’m not sure whether this class is appropriate for my task.

Also, are 3D volumes compatible with the library? Or would I just have to extract 2D patches from the 3D volumes and fit to those?
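
If you did go the 2D route, the simplest version is to cut each source/target pair into matching 2D slices along one axis, so every training example is an aligned (x, y) pair. Here is a minimal NumPy sketch. `paired_slices` is a hypothetical helper, and it takes whole slices rather than arbitrary patches.

```python
from typing import List, Tuple

import numpy as np


def paired_slices(src: np.ndarray, tgt: np.ndarray,
                  axis: int = 2) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Split two same-shaped 3D volumes into matching 2D slices along `axis`."""
    if src.shape != tgt.shape or src.ndim != 3:
        raise ValueError("expected two 3D volumes of the same shape")
    # np.take with a scalar index drops `axis`, leaving a 2D slice
    return [(np.take(src, i, axis=axis), np.take(tgt, i, axis=axis))
            for i in range(src.shape[axis])]
```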

Thank you for the help.


You can just create a regular PyTorch Dataset - subclassing any of the existing fastai classes is a fine way to do so, if that helps you. Then everything else should “just work” 🙂

3D should work fine, although you’ll need to build your own PyTorch model - ConvLearner can’t do it for you.
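A "regular PyTorch Dataset" for this task can be quite small. Below is a minimal sketch of a map-style dataset that returns (source, target) volume pairs. Assumptions: the class name `PairedVolumeDataset` is hypothetical, sources and targets are matched by list position, and the file reader is passed in as `load_fn`. For NIfTI that could be `lambda p: nibabel.load(p).get_fdata()`. The class only implements `__len__` and `__getitem__`, which is the interface `torch.utils.data.DataLoader` uses. In real code you would likely also subclass `torch.utils.data.Dataset`.

```python
from typing import Callable, List, Optional, Tuple

import numpy as np


class PairedVolumeDataset:
    """Map-style dataset of (source, target) 3D volumes matched by index."""

    def __init__(self, src_paths: List[str], tgt_paths: List[str],
                 load_fn: Callable[[str], np.ndarray],
                 transform: Optional[Callable] = None):
        if len(src_paths) != len(tgt_paths):
            raise ValueError("need exactly one target per source")
        self.src_paths, self.tgt_paths = src_paths, tgt_paths
        self.load_fn = load_fn        # e.g. a nibabel-based reader for .nii files
        self.transform = transform    # optional callable applied to the (src, tgt) pair

    def __len__(self) -> int:
        return len(self.src_paths)

    def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray]:
        # Cast to float32, the usual dtype for PyTorch model inputs
        src = np.asarray(self.load_fn(self.src_paths[idx]), dtype=np.float32)
        tgt = np.asarray(self.load_fn(self.tgt_paths[idx]), dtype=np.float32)
        sample = (src, tgt)
        return self.transform(sample) if self.transform else sample
```

Keeping the reader injectable means the same class works for .nii, .npy, or anything else. When batching, PyTorch's default collate function converts NumPy arrays into tensors.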


Thanks for the quick reply. For posterity, I was able to get fastai working with a custom PyTorch Dataset. Since I didn’t find any NIfTI-specific data loaders that were very general, I created my own with a corresponding GitHub repo [link] that is hopefully easy to install and use (at least for the use case I described in my initial question).

I have a notebook with a corresponding (silly) example of how to get it working with fastai which is also in the repo [link]. Thanks for the help and great work with the package!


@jcreinhold that’s great! You might really enjoy using the new fastai transforms too. To see how, just follow the steps we use in ImageDataBunch.create:

Let me know if you give it a go, and if you need any help. I’d love to be able to show folks a really nice role model of how to create a custom dataset.


Thanks for the feedback. I am now trying to get the fastai transforms working per Jeremy’s suggestion. My work so far can be seen in this notebook (if that link fails, here is an alternate link).

I was able to get transforms working by first converting my images to fastai.Image objects. Should the transforms work on torch.Tensor objects? I believe the message I was getting was something along the lines of a missing set_sample attribute.

Also, I was only able to get the transforms working on the source (x) and target (y) images separately. That is, the same transform parameters were not applied to both x and y. While this is useful, I would also like to apply the transforms to both simultaneously. Does the ImageDataBunch class support this functionality with transforms as is? Or would I need to create my own transforms for this use case?
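
The key idea behind a paired transform is that the random parameters are drawn once per sample and then applied to both x and y, so the two volumes stay spatially aligned. Below is a minimal NumPy sketch of that idea using random flips. It is not fastai's transform API: the class `RandomPairedFlip` and its `p`/`seed` parameters are hypothetical. It also assumes channel-less (X, Y, Z) volumes. If a channel axis is present, exclude it from the flip axes.

```python
from typing import Optional, Tuple

import numpy as np


class RandomPairedFlip:
    """Flip source and target along the same randomly chosen axes."""

    def __init__(self, p: float = 0.5, seed: Optional[int] = None):
        self.p = p                               # per-axis flip probability
        self.rng = np.random.default_rng(seed)

    def __call__(self, sample: Tuple[np.ndarray, np.ndarray]
                 ) -> Tuple[np.ndarray, np.ndarray]:
        src, tgt = sample
        assert src.shape == tgt.shape
        # Decide once per axis, then reuse the same decision for both volumes
        axes = tuple(ax for ax in range(src.ndim) if self.rng.random() < self.p)
        if axes:
            src, tgt = np.flip(src, axis=axes), np.flip(tgt, axis=axes)
        # Copy so the result is contiguous; torch.from_numpy rejects negative strides
        return src.copy(), tgt.copy()
```

Because the transform takes and returns the whole (src, tgt) pair, it can be passed as the `transform` of a paired dataset rather than being applied to x and y separately.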


I have similar questions. Do you have any plans to develop a 3D ConvLearner in the near future? Keras has Conv3D(). Since the fastai library has many great additional functions, will you develop some APIs for Keras users? Many thanks.

Hi @xuzhang. For your information, I was able to get the Learner class to work with a 3D model. See here for an example.

If you write a valid 3D PyTorch model and set up the appropriate DataBunch class (for instance, like the one I developed in the niftidataset repo), then you can call learner.lr_find() or any other Learner method. Like Jeremy stated previously, it does “just work.”

Not sure about Keras support though, someone else will have to chime in regarding that.


Hi @jcreinhold,
Thank you so much for your fast response.
My dataset is not images. The shape of the data is (nb_example, nb_channels (>1), x, y, z). I modified WideResnet into a 3D convolutional NN using the Keras library and trained my model smoothly. So, if I want to use the fastai library, I first have to convert WideResnet into a 3D NN in PyTorch, and then I can use fastai's training methods. Am I right?
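
A note on array layout when porting from Keras: Keras's Conv3D defaults to `data_format="channels_last"`, i.e. (batch, dim1, dim2, dim3, channels), while PyTorch's `nn.Conv3d` expects channels-first input, (N, C, D, H, W). The shape described above is already channels-first, so nothing needs to change if the arrays are stored that way. If they were saved for a channels-last Keras model, a minimal NumPy conversion might look like this (`channels_last_to_first` is a hypothetical helper):

```python
import numpy as np


def channels_last_to_first(batch: np.ndarray) -> np.ndarray:
    """(N, X, Y, Z, C) -> (N, C, X, Y, Z), the layout torch.nn.Conv3d expects."""
    if batch.ndim != 5:
        raise ValueError(f"expected a 5D batch, got shape {batch.shape}")
    # moveaxis returns a view; make it contiguous before handing it to torch
    return np.ascontiguousarray(np.moveaxis(batch, -1, 1))
```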

You are correct. If you create a WideResnet PyTorch model, then you should be able to pass it to the Learner class, as in the example in my last post. However, you will also (probably) have to create a custom dataset class, which you can then use to instantiate a fastai DataBunch.

To put it another way: if you have your PyTorch WideResnet model in a variable named model, and you create a DataBunch instance called data from your custom dataset, then you should be able to call

learner = Learner(data, model, loss_func=..., metrics=[...])

and have a working fastai learner that supports .fit, .lr_find, etc. Someone can (please) correct me if I’m wrong about any of what I said, or chime in if there is an easier way to do this.


Again, for posterity: I was able to get parallel transforms working on both the source and target images. See this notebook for an example.

The part I missed was passing the tfm_y=True flag when creating the DataBunch, which produces the desired behavior.

If anyone has additional feedback on how to improve the custom dataset or use more of the fastai API, comments are appreciated.

Thanks for the help and the great work!


Great work @jcreinhold!
I am struggling with the data loader. I am trying to read in an ndarray (img) with shape torch.Size([86, 85, 70]). The moment I call faiv.Image(torch.Tensor(img)), it fails, I guess because the last dimension is >3:

~/anaconda3/envs/fastai_dl_course/lib/python3.6/site-packages/matplotlib/ in _make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification, unsampled, round_to_pixel_border)
    484 A = _rgb_to_rgba(A)
    485 elif A.shape[2] != 4:
--> 486 raise ValueError("Invalid dimensions, got %s" % (A.shape,))
    488 output = np.zeros((out_height, out_width, 4), dtype=A.dtype)

ValueError: Invalid dimensions, got (85, 86, 70)

What are the dimensions of your data?


Hey Fabian,

Thanks! For future reference, that notebook was back in fastai version 1.0.15 and it’s currently at version 1.0.39, with many breaking changes. So the success of your code will definitely depend on what version you are running.

Looking at the stack trace, the problem appears to be with plotting the image, since the error is raised from matplotlib. If you are running this in a notebook and calling faiv.Image(torch.tensor(img)) in a cell by itself, then fastai will try to plot the image. However, your data is not in a format that can be displayed without some additional processing. In the function where I call faiv.Image(...), I use it as a preprocessing step, and the result is only ever consumed inside a data loader for NN training; it is never displayed on its own.

Just to make this more concrete, if I open up a Jupyter notebook and type


I will receive the same error you are receiving, but that is because the (100,100,100) data is not an image that matplotlib knows how to show. So if you are building a data loader for your (86,85,70) data, you should pass this result on to another function rather than running it in a cell alone.
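
If you do want to look at such a volume, matplotlib's `imshow` only accepts arrays of shape (M, N), (M, N, 3), or (M, N, 4), which is why a last axis of length 70 is rejected. Taking a single 2D slice first avoids the problem. Below is a minimal NumPy sketch. `middle_slice` is a hypothetical helper, and the choice of the central slice is arbitrary.

```python
import numpy as np


def middle_slice(volume: np.ndarray, axis: int = 2) -> np.ndarray:
    """Return the central 2D slice of a 3D volume along `axis`."""
    if volume.ndim != 3:
        raise ValueError(f"expected a 3D volume, got shape {volume.shape}")
    return np.take(volume, volume.shape[axis] // 2, axis=axis)
```

For an (86, 85, 70) volume this gives an (86, 85) array, which can then be shown with something like `plt.imshow(middle_slice(img), cmap="gray")`.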

However, I think my notebook example is now completely out of date. If you have data that isn't natively supported, you should write your own ItemList for loading it. Here is the official tutorial. FWIW, I also wrote about how to create a custom data loader for a specific type of 3D data here (see the Experiments section near the bottom for some code snippets).

Hopefully that was helpful! Let me know if you need any clarifications.



Thanks for the quick and detailed reply. That is really useful. I was able to fix my problem already and will have a look at your linked blog post!

Hi Jake,

Just wanted to share a link to a GitHub page:

That guy has also been looking into using fastai for 3D images. Pretty cool!



@jcreinhold thanks for the niftidataset. I am having an issue defining a class AddChannel. Do you know why this is?

NameError                                 Traceback (most recent call last)
<ipython-input-7-416698beb24e> in <module>()
----> 1 class AddChannel:
      2     """ Add channel dimension to sample """
      3     def __call__(self, sample: Tuple[np.ndarray, np.ndarray]):
      4         src, tgt = sample
      5         assert src.shape == tgt.shape

<ipython-input-7-416698beb24e> in AddChannel()
      1 class AddChannel:
      2     """ Add channel dimension to sample """
----> 3     def __call__(self, sample: Tuple[np.ndarray, np.ndarray]):
      4         src, tgt = sample
      5         assert src.shape == tgt.shape

NameError: name 'Tuple' is not defined
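
The NameError comes from the type annotation: `Tuple` is used in the `__call__` signature, and function annotations are evaluated when the `def` runs inside the class body, so `Tuple` must be imported from `typing` first (and `np` from NumPy). Here is a sketch of the class with the imports in place. Only the first few lines of the class are visible in the traceback, so the `return` line is an assumption about what the rest of the method does: it prepends a length-1 channel axis to both volumes.

```python
from typing import Tuple  # the missing import behind the NameError

import numpy as np


class AddChannel:
    """ Add channel dimension to sample """
    def __call__(self, sample: Tuple[np.ndarray, np.ndarray]
                 ) -> Tuple[np.ndarray, np.ndarray]:
        src, tgt = sample
        assert src.shape == tgt.shape
        # Assumed behavior: (X, Y, Z) -> (1, X, Y, Z) for both source and target
        return src[np.newaxis, ...], tgt[np.newaxis, ...]
```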