Learn.export() error when using torch Dataset

Hi,

I have set up a fastai unet training notebook that uses a custom PyTorch Dataset because of specific augmentation requirements and file locations. However, saving the model with the export() function fails.

Here is the relevant code.

This is how the custom PyTorch dataset is set up:

import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import tifffile
import pandas as pd
from PIL import Image
from fastai.vision.all import *
from torch.utils.data import Dataset, DataLoader, random_split
import torch
import albumentations as alb
import albumentations.pytorch

class my_Dataset(Dataset):
    def __init__(self,csv_fn, dataroot_path=".", tfms=None):
        '''
        csv_fn is the csv file
        data in the csv file should have filenames of data and respective labels
        in columns "data" and "train"
        '''
        self.df_file_locs=None
        if csv_fn:
            self.df_file_locs= pd.read_csv(csv_fn)
        #print(self.df_file_locs)

        self.tfms=tfms
        self.dataroot_path=dataroot_path

    def __len__(self):
        return len(self.df_file_locs)

    def __getitem__(self, idx):

        # Read file
        drwow=self.df_file_locs.iloc[idx]
        #print(f"idx:{idx}, drwow:{drwow}")
        datafn = drwow['data']
        labelfn= drwow['train']

        data =tifffile.imread(Path(self.dataroot_path) / Path(datafn))
        labels=tifffile.imread(Path(self.dataroot_path) / Path(labelfn)).astype(np.uint8)

        assert data.shape == labels.shape

        # Apply transforms
        if self.tfms:
            res =self.tfms(image=data, mask=labels)
            data=res['image']
            labels=res['mask']

        #return a tuple data, mask
        return data, labels

The data is loaded with tifffile and is in uint16 format, so I set up these augmentations using albumentations:

train_tfms=alb.Compose([
            alb.ToFloat(max_value=65535.0),
            alb.Resize(*imgsize),
            alb.Lambda(name="normalize_by_mean_std_with_clip", image=normalize_by_mean_std_with_clip, always_apply=True),
            alb.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=5, p=0.3),
            #alb.HorizontalFlip(p=0.5),
            albumentations.pytorch.ToTensorV2()
            #alb.FromFloat(dtype=np.uint8, max_value=255, always_apply=True)
        ])

The data is split into train and test datasets using the PyTorch utility:

dset1, dset2 = torch.utils.data.random_split(dataset0, [0.8,0.2])

and then the fastai DataLoaders can be set up as follows:

fastai_dls = DataLoaders.from_dsets(dset1, dset2, bs=4, path="fastai_pt_model")
fastai_dls.cuda()

Now the unet can be created using these dataloaders:

learn0 = unet_learner(fastai_dls, models.resnet34, n_in=1, n_out=4, loss_func=CrossEntropyLossFlat(axis=1))

Train and fine-tune:

learn0.fit_one_cycle(10, lr_max=0.5e-4)
learn0.fine_tune(50)

I then test predictions, and all is good. But saving is problematic:

import dill
learn0.export("model.dpkl", pickle_module=dill)

and I get this error:

AttributeError                            Traceback (most recent call last)
Cell In[39], line 2
      1 import dill
----> 2 learn0.export("model.dpkl", pickle_module=dill)
      3 #Load it using load_lerner("filename.dpkl", pickle_module=dill, cpu=False)
      4 #This saving is not working for fastai learners setup with pytorch Datasets
      5 # It appears that some function is missing in the Dataset class

File ~/miniforge3/envs/p11_devc_els1/lib/python3.11/site-packages/fastai/learner.py:430, in export(self, fname, pickle_module, pickle_protocol)
    428 self._end_cleanup()
    429 old_dbunch = self.dls
--> 430 self.dls = self.dls.new_empty()
    431 state = self.opt.state_dict() if self.opt is not None else None
    432 self.opt = None

File ~/miniforge3/envs/p11_devc_els1/lib/python3.11/site-packages/fastai/data/core.py:214, in DataLoaders.new_empty(self)
    213 def new_empty(self):
--> 214     loaders = [dl.new(dl.dataset.new_empty()) for dl in self.loaders]
    215     return type(self)(*loaders, path=self.path, device=self.device)

File ~/miniforge3/envs/p11_devc_els1/lib/python3.11/site-packages/fastai/data/core.py:214, in <listcomp>(.0)
    213 def new_empty(self):
--> 214     loaders = [dl.new(dl.dataset.new_empty()) for dl in self.loaders]
    215     return type(self)(*loaders, path=self.path, device=self.device)

AttributeError: 'Subset' object has no attribute 'new_empty'

I have tried adding a new_empty() method to the my_Dataset class, but that does not help.

Also, if I check learn0.dls.loaders[0].dataset it gives me a torch Subset object.
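
The traceback and the Subset observation point to a likely cause: random_split wraps the dataset in torch Subset objects, so the loaders never call anything defined on my_Dataset. One possible workaround is to do the split with a small wrapper class that does offer new_empty(). The following is only a minimal sketch, assuming that export() needs nothing more from each dataset than the new_empty() call shown in the traceback. The names SplitDataset and split_dataset are hypothetical, and a plain list stands in for my_Dataset.

```python
import random


class SplitDataset:
    """Index-based view of a map-style dataset (hypothetical helper).

    Plays the role of torch.utils.data.Subset, but adds new_empty().
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        return self.dataset[self.indices[i]]

    def new_empty(self):
        # Return a view with no items and no reference to the original data,
        # so the data itself is not pickled along with the learner.
        return type(self)([], [])


def split_dataset(dataset, valid_frac=0.2, seed=0):
    """Shuffle the indices with a fixed seed and split them into train/valid views."""
    idxs = list(range(len(dataset)))
    random.Random(seed).shuffle(idxs)
    n_valid = int(round(len(idxs) * valid_frac))
    return SplitDataset(dataset, idxs[n_valid:]), SplitDataset(dataset, idxs[:n_valid])


# Stand-in for my_Dataset: any object with __len__ and __getitem__ works here.
toy = [(i, i * i) for i in range(10)]
train_ds, valid_ds = split_dataset(toy, valid_frac=0.2, seed=0)
print(len(train_ds), len(valid_ds), len(train_ds.new_empty()))  # prints: 8 2 0
```

The two views could then be passed to DataLoaders.from_dsets in place of dset1 and dset2. Whether export() then runs to completion has not been verified here, since everything else attached to the learner still has to be picklable.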

With similar training that instead uses fastai’s ImageBlock(cls=PILImageBW), learn.dls.loaders[0].dataset gives me a long line:

(#41) [(PILImageBW mode=L size=512x512, PILMask mode=L size=512x512),(PILImageBW mode=L size=512x512, PILMask mode=L size=512x512),(PILImageBW mode=L size=512x512, PILMask mode=L size=512x512),(PILImageBW mode=L size=512x512, PILMask mode=L size=512x512),(PILImageBW mode=L size=512x512, PILMask mode=L size=512x512),(PILImageBW mode=L size=512x512, PILMask mode=L size=512x512),(PILImageBW mode=L size=512x512, PILMask mode=L size=512x512),(PILImageBW mode=L size=512x512, PILMask mode=L size=512x512),(PILImageBW mode=L size=512x512, PILMask mode=L size=512x512),(PILImageBW mode=L size=512x512, PILMask mode=L size=512x512)...]

I am not sure what the best way to save the model is.

I thought about using learn.save(), but inference elsewhere would then require recreating the unet learner, which in turn requires setting up the dataloaders.

I have also tried PyTorch’s JIT export:
torch.jit.script(learn.model)
but I get another error.

Any help is appreciated. Thanks

Hi,

Is there no need for this in fastai?

unet_learner(fastai_dls, models.resnet34, n_in=1, n_out=4, loss_func=CrossEntropyLossFlat(axis=1)).to('cuda')

I don’t want to pose as an expert, because I’m not one. Maybe this will help you:

https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html

Also pay attention to this tutorial:

https://pytorch.org/tutorials/recipes/recipes/module_load_state_dict_tips.html?highlight=loading%20nn%20module%20from%20checkpoint
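
The first tutorial linked above describes saving only the model weights via state_dict() and loading them into a freshly built model of the same architecture. A minimal sketch of that pattern follows. The tiny stand-in network and the in-memory buffer are assumptions made for illustration; in practice the model would be learn0.model and the target would be a file path.

```python
import io

import torch
from torch import nn

# Stand-in for learn0.model (hypothetical; a real run would use the trained unet).
model = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3, padding=1))

# Save only the weights. A file path such as "weights.pth" works in place of the buffer.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Elsewhere: rebuild the same architecture first, then load the weights into it.
restored = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3, padding=1))
buffer.seek(0)
restored.load_state_dict(torch.load(buffer, map_location="cpu"))
restored.eval()

# Both models should now produce identical outputs for the same input.
x = torch.rand(1, 1, 8, 8)
with torch.no_grad():
    same = torch.equal(model(x), restored(x))
print(same)
```

This avoids pickling the dataloaders altogether, but the architecture still has to be rebuilt before load_state_dict() is called. For the fastai unet, that means recreating the same model at inference time.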

I hope you find a solution :)