Loading Torch Tensors into the DataBlock API

Hello,

I’m working on a segmentation problem and attempting to repeat earlier work done in PyTorch and Tensorflow using the fast.ai framework. As I need to handle the data in a very particular way, the easiest way to ensure proper handling has been to save out the individual tensors for the images and the masks, then read them in during training. For my torch implementation I saved these out in the recommended .pt format. I do not believe my application is compatible with conversion to standard imaging formats (.png etc.).

My first attempt was to reuse my torch DataLoaders, but this seems to cause difficult-to-diagnose errors under the hood of unet_learner. Specifically, unet_learner’s default parameter pretrained is being passed as a kwarg into a function that doesn’t recognize it (my fast.ai version is 2.0.13; the stack trace is unavailable right now). I’d like to use fast.ai’s DataBlock API to set up the DataLoader instead, but I’m having trouble figuring out how to do that with non-standard formats.

Does anyone have an example of building a DataLoader using the DataBlock API and data saved as torch tensors?

There’s likely a better way to do this but here’s the implementation I came up with to load some dummy tensors.

from fastai.vision.all import *


def get_mask(fname, masklabel, tensorlabel):
    """Return the mask path that mirrors the tensor path `fname`.

    Every component of `fname` that equals `tensorlabel` is replaced by
    `masklabel`; all other components are kept unchanged.

    fname: a `Path` (anything exposing `.parts`) pointing at a tensor file.
    masklabel: folder name to substitute in.
    tensorlabel: folder name to substitute out.
    Returns a `Path` built from the substituted components.
    """
    swapped = (masklabel if part == tensorlabel else part for part in fname.parts)
    # Rebuild from components rather than joining with `os.sep`: joining the
    # root component ('/') with a separator produces a '//'-prefixed string
    # on POSIX, which pathlib keeps distinct from the '/'-rooted path.
    return Path(*swapped)


def MaskGetter(masklabel=None, tensorlabel=None):
    """Create a getter that maps a tensor path to its mask path.

    The returned callable takes a path `o` and returns
    `get_mask(o, masklabel, tensorlabel)`, i.e. the same path with the folder
    `tensorlabel` replaced by the folder `masklabel`.
    """
    # Local import keeps this block self-contained.
    from functools import partial
    # A `partial` over the module-level `get_mask` can be pickled, whereas a
    # function nested inside this factory cannot; pickling is needed when the
    # getter has to be serialized (e.g. exported or sent to worker processes).
    return partial(get_mask, masklabel=masklabel, tensorlabel=tensorlabel)


def _parent_idxs(items, name):
    """Collect the indices of `items` whose parent folder is named `name`.

    `name` is wrapped with `L`, so one index list is produced per wrapped
    entry and the lists are concatenated in that order.
    """
    def _matching(folder):
        # One boolean per item (parent folder name equals `folder`), handed to `mask2idxs`.
        return mask2idxs(Path(entry).parent.name == folder for entry in items)
    return [idx for folder in L(name) for idx in _matching(folder)]


def ParentSplitter(train_name='train', valid_name='valid'):
    "Build a splitter that assigns `items` to train/valid by parent folder name (`train_name` / `valid_name`)."
    def _inner(o):
        # Index lists are computed independently for each split, training first.
        train_idxs = _parent_idxs(o, train_name)
        valid_idxs = _parent_idxs(o, valid_name)
        return train_idxs, valid_idxs
    return _inner


class TensorLoad(Transform):
    """Transform that turns a file name into the object stored in that file.

    No `__init__` is defined on purpose: the previous empty override replaced
    the base-class constructor without calling it, so the inherited
    initializer is used instead.
    """

    def encodes(self, o):
        "Load and return the object saved at path `o` with `torch.load`."
        # NOTE: `torch.load` unpickles the file; only point this at files you trust.
        return torch.load(o)


def TensorBlock():
    "Build a `TransformBlock` that loads items with `TensorLoad` and applies `IntToFloatTensor` per batch."
    block_kwargs = dict(type_tfms=TensorLoad, batch_tfms=IntToFloatTensor)
    return TransformBlock(**block_kwargs)


def TensorMaskBlock(codes=None):
    "Build a `TransformBlock` for masks: `TensorLoad` per item, `AddMaskCodes(codes)` as item transform, `IntToFloatTensor` per batch."
    add_codes = AddMaskCodes(codes=codes)
    return TransformBlock(
        type_tfms=TensorLoad,
        item_tfms=add_codes,
        batch_tfms=IntToFloatTensor,
    )


def torch_tensor_creator(root=r"C:\Users\Joe\Downloads\torch_tensor_test"):
    """Write a small dummy segmentation dataset of random tensors under `root`.

    Eight files named `tensor_1.pt` ... `tensor_8.pt` are saved into each of
    `tensors/train`, `tensors/valid`, `masks/train` and `masks/valid`.
    Images are uniform random float32 tensors of shape (128, 128); masks are
    float32 tensors of the same shape holding random 0/1 values.

    root: dataset root folder; defaults to the original hard-coded location.
    Returns None.
    """
    root = Path(root)
    tensor_list = ["tensor_1", "tensor_2", "tensor_3", "tensor_4", "tensor_5", "tensor_6", "tensor_7", "tensor_8"]
    splits = ("train", "valid")

    # `torch.save` does not create missing folders, so make them up front.
    for kind in ("tensors", "masks"):
        for split in splits:
            (root / kind / split).mkdir(parents=True, exist_ok=True)

    for tensor_name in tensor_list:
        fname = tensor_name + '.pt'
        for split in splits:
            image = torch.rand((128, 128), dtype=torch.float32)
            # `high` is exclusive: high=2 yields labels 0 and 1 (high=1 would give all zeros).
            mask = torch.randint(high=2, size=(128, 128), dtype=torch.float32)
            torch.save(image, str(root / "tensors" / split / fname))
            torch.save(mask, str(root / "masks" / split / fname))

    return None


if __name__ == '__main__':
    # Root folder that is searched for the `.pt` files.
    path = r"C:\Users\Joe\Downloads\torch_tensor_test"

    # Inputs are the files found under `tensors`; the target for each input is
    # the file at the same location with `tensors` swapped for `masks`.
    block_pair = (TensorBlock, TensorMaskBlock(codes=['ones']))
    item_getter = FileGetter(extensions='.pt', folders=['tensors'])
    mask_getter = MaskGetter(masklabel='masks', tensorlabel='tensors')
    folder_splitter = ParentSplitter(train_name='train', valid_name='valid')

    datablock = DataBlock(blocks=block_pair,
                          get_items=item_getter,
                          get_y=mask_getter,
                          splitter=folder_splitter)

    dataloaders = datablock.dataloaders(Path(path), bs=2)
    # Print a single batch as a smoke test.
    first_batch = dataloaders.one_batch()
    print(first_batch)

Here’s a revision that will actually train a model. The tensors I saved out to test this are meaningless but it might open up loading generic data structures as long as you can convert the data into torch tensors of a shape that’s compatible with the model you are using. I still need to test how it works with FastAI’s data augmentation routines and multiprocessing.

from fastai.vision.all import *

def _parent_idxs(items, name):
    """Return the indices of `items` whose parent folder is named `name`.

    `name` is wrapped with `L`; the index lists obtained for each wrapped
    entry are concatenated in order.
    """
    idxs = []
    for folder in L(name):
        # One boolean per item (parent folder name equals `folder`), handed to `mask2idxs`.
        idxs.extend(mask2idxs(Path(entry).parent.name == folder for entry in items))
    return idxs

def ParentSplitter(train_name='train', valid_name='valid'):
    "Build a splitter returning (train indices, valid indices) based on each item's parent folder name."
    def _inner(o):
        # Training indices are computed first, then validation indices.
        return tuple(_parent_idxs(o, folder) for folder in (train_name, valid_name))
    return _inner

# these two classes and the block definitions afterward are the main important/new part
class TorchTensorImage(TensorImage):
    "`TensorImage` subclass whose instances are created from a file read with `torch.load`."

    @classmethod
    def create(cls, fn: (Path, str)) -> None:
        "Load the object stored at `fn` and wrap it in `cls`."
        # NOTE(review): the `-> None` annotation is kept exactly as written even
        # though an instance is returned; presumably it is a fastcore dispatch
        # convention rather than a real return type -- confirm before changing.
        return cls(torch.load(fn))

class TorchTensorMask(TensorMask):
    "`TensorMask` subclass whose instances are created from a file read with `torch.load`."

    @classmethod
    def create(cls, fn: (Path, str)) -> None:
        "Load the object stored at `fn` and wrap it in `cls`."
        # NOTE(review): the `-> None` annotation is kept exactly as written even
        # though an instance is returned; presumably it is a fastcore dispatch
        # convention rather than a real return type -- confirm before changing.
        return cls(torch.load(fn))

def TensorImageBlock():
    "Build a `TransformBlock` that creates items with `TorchTensorImage.create` and applies `IntToFloatTensor` per batch."
    loader = TorchTensorImage.create
    return TransformBlock(type_tfms=loader, batch_tfms=IntToFloatTensor)

def TensorMaskBlock(codes):
    "Build a `TransformBlock` for masks: `TorchTensorMask.create` per item, `AddMaskCodes(codes)` as item transform, `IntToFloatTensor` per batch."
    add_codes = AddMaskCodes(codes=codes)
    return TransformBlock(
        type_tfms=TorchTensorMask.create,
        item_tfms=add_codes,
        batch_tfms=IntToFloatTensor,
    )

def torch_tensor_creator(root=r"C:\Users\Joe\Downloads\torch_tensor_test"):
    """Write a small dummy segmentation dataset of random tensors under `root`.

    Eight files named `tensor_1.pt` ... `tensor_8.pt` are saved into each of
    `tensors/train`, `tensors/valid`, `masks/train` and `masks/valid`.
    Images are uniform random float32 tensors of shape (128, 128); masks are
    float32 tensors of the same shape holding random 0/1 values.

    root: dataset root folder; defaults to the original hard-coded location.
    Returns None.
    """
    root = Path(root)
    tensor_list = ["tensor_1", "tensor_2", "tensor_3", "tensor_4", "tensor_5", "tensor_6", "tensor_7", "tensor_8"]
    splits = ("train", "valid")

    # `torch.save` does not create missing folders, so make them up front.
    for kind in ("tensors", "masks"):
        for split in splits:
            (root / kind / split).mkdir(parents=True, exist_ok=True)

    for tensor_name in tensor_list:
        fname = tensor_name + '.pt'
        for split in splits:
            image = torch.rand((128, 128), dtype=torch.float32)
            # `high` is exclusive, so high=2 yields the labels 0 and 1.
            mask = torch.randint(high=2, size=(128, 128), dtype=torch.float32)
            torch.save(image, str(root / "tensors" / split / fname))
            torch.save(mask, str(root / "masks" / split / fname))

    return None

if __name__ == '__main__':
    # Uncomment to (re)generate the dummy dataset before the first run.
    # torch_tensor_creator()
    path = r"C:\Users\Joe\Downloads\torch_tensor_test"

    def get_msk(o):
        "Return the mask path for tensor path `o` by swapping the 'tensors' folder for 'masks'."
        # Rebuilding from components (instead of joining with `os.sep`) keeps
        # absolute POSIX paths rooted at '/' rather than '//'.
        return Path(*('masks' if part == 'tensors' else part for part in o.parts))

    datablock = DataBlock(
                    blocks=(TensorImageBlock, TensorMaskBlock(codes=['ones'])),
                    get_items=FileGetter(extensions='.pt', folders=['tensors']),
                    get_y=get_msk,
                    splitter=ParentSplitter(train_name='train', valid_name='valid'))

    dataloaders = datablock.dataloaders(Path(path), bs=1, num_workers=0)

    # Base folder for the learner and the sub-folder (relative to it) for models.
    # Previously the base folder was defined but never used, and the rooted
    # fragment r"\model\\" was passed for both `path` and `model_dir`.
    modelsavepath = r"C:\Users\Joe\Downloads\torch_tensor_test"
    modelsavedir = "model"

    learn = unet_learner(dls=dataloaders, arch=resnet34, n_out=2, loss_func=CrossEntropyLossFlat(axis=1), path=modelsavepath, model_dir=modelsavedir)

    learn.fit(10)

Probably the final revision for a while. I added a show method that should coerce arbitrary tensors into displayable images, as long as you manage the channels. It looks like it’s OK with multiprocessing as long as your data is not saved as CUDA tensors. (Change the loading function and you should be fine loading numpy files or whatever else, though you’ll probably need to explicitly convert to tensors before the create method returns the data.) Some augmentations seem to work, like Rotate, but aug_transforms gives me some weird results, and I haven’t dug deep enough to figure out which augmentations are OK and which are not.

Just pasting the relevant code for the TransformBlocks since that’s where the only meaningful changes are.

def _fig_bounds(x):
    "Return `x // 32` clamped to the range 1..5 (mirrors the helper fast.ai uses in its `show_image`)."
    lower_clamped = max(1, x // 32)
    return min(5, lower_clamped)


def scale_tensor_to_image(im_tensor):
    """Linearly rescale `im_tensor` to the range 0-255 and return it as `uint8`.

    The minimum of the tensor maps to 0 and the maximum to 255; values in
    between are floor-divided onto that range. A constant tensor (maximum
    equal to minimum) has no range to stretch, so an all-zero `uint8` tensor
    of the same shape is returned instead of dividing by zero.
    """
    im_min = torch.min(im_tensor)
    im_max = torch.max(im_tensor)
    span = im_max - im_min
    if span == 0:
        return torch.zeros_like(im_tensor, dtype=torch.uint8)
    return (255*(im_tensor - im_min)//span).type(torch.uint8)


def show_tensor_image(im, ax=None, figsize=None, title=None, ctx=None, **kwargs):
    """Show a PIL, NumPy or PyTorch image on `ax`, rescaling its values for display.

    Basically fast.ai's `show_image` with an extra step that stretches the
    values of arbitrary tensors onto 0-255 (`uint8`) before plotting.

    im: image to show; tensors are moved to the CPU, anything else is
        converted to an array and then to a tensor.
    ax / ctx: axis to draw on (`ctx` is used when `ax` is None); a new figure
        is created when both are None.
    figsize: figure size for a newly created figure; derived from the image
        shape via `_fig_bounds` when None.
    title: optional axis title.
    kwargs: forwarded to `ax.imshow`.
    Returns the axis that was drawn on.
    """
    # Handle pytorch axis order
    if hasattrs(im, ('data','cpu','permute')):
        im = im.data.cpu()
        # Only 3-D tensors with a small leading dimension are reordered;
        # permuting a 2-D tensor with three indices would raise.
        if im.ndim == 3 and im.shape[0]<5: im=im.permute(1,2,0)
    else:
        if not isinstance(im,np.ndarray): im=array(im)
        # The scaling below uses torch operations, so arrays must become tensors first.
        im = torch.as_tensor(im)
    # Handle 1-channel images
    if im.shape[-1]==1: im=im[...,0]

    # A constant image has no range to stretch; just cast it.
    if torch.max(im) != torch.min(im):
        im = scale_tensor_to_image(im)
    else:
        im = im.type(torch.uint8)

    ax = ifnone(ax,ctx)
    if figsize is None: figsize = (_fig_bounds(im.shape[0]), _fig_bounds(im.shape[1]))
    if ax is None: _,ax = plt.subplots(figsize=figsize)
    ax.imshow(im, **kwargs)
    if title is not None: ax.set_title(title)
    ax.axis('off')
    return ax


class TorchTensorImage(TensorImage):
    "Image type created from a file read with `torch.load`, for use in a fast.ai `TransformBlock`."
    # `TensorImage` seems to work best as the base class; `Tensor` and
    # `TensorBase` were also tried and only sort of work.
    _show_args = {'cmap':'viridis'}

    @classmethod
    def create(cls, fn: (Path, str)) -> None:
        "Load the tensor saved at file name `fn` and return it stacked three times, wrapped in `cls`."
        # Adapt this method so the data is loaded in a form your model can handle.
        # Loading strategies that were tried before (kept for reference):
        #   cls(tens)                                   -- the tensor unchanged
        #   tens[1:4,:,:].clone()                       -- entries 1..3 of the first dimension
        #   tens[2,:,:].clone() stacked three times     -- entry 2 of the first dimension, repeated
        tens = torch.load(fn)
        stacked = torch.stack([tens, tens, tens])
        return cls(stacked)

    def show(self, ctx=None, **kwargs):
        "Display the tensor via `show_tensor_image` using `merge(self._show_args, kwargs)`."
        show_kwargs = merge(self._show_args, kwargs)
        return show_tensor_image(self, ctx=ctx, **show_kwargs)


class TorchTensorMask(TensorMask):
    "Mask type created from a file read with `torch.load`, for use in a fast.ai `TransformBlock`."
    _show_args = {'alpha':0.5, 'cmap':'tab20'}

    @classmethod
    def create(cls, fn: (Path, str)) -> None:
        "Load the mask saved at file name `fn` and wrap it in `cls`."
        # Adapt this method to control how your mask is loaded.
        return cls(torch.load(fn))

    def show(self, ctx=None, **kwargs):
        "Display the mask via `show_tensor_image` using `merge(self._show_args, kwargs)`."
        show_kwargs = merge(self._show_args, kwargs)
        return show_tensor_image(self, ctx=ctx, **show_kwargs)


def TensorImageBlock():
    "Build the `TransformBlock` for inputs: `TorchTensorImage.create` per item, `IntToFloatTensor(div=1)` per batch."
    to_float = IntToFloatTensor(div=1)
    return TransformBlock(type_tfms=TorchTensorImage.create, batch_tfms=to_float)


def TensorMaskBlock(codes):
    "Build the `TransformBlock` for masks: `TorchTensorMask.create` per item plus `AddMaskCodes(codes)` as item transform."
    add_codes = AddMaskCodes(codes=codes)
    # No batch transform is passed here; `batch_tfms=IntToFloatTensor(div=1)` was tried and left disabled.
    return TransformBlock(type_tfms=TorchTensorMask.create, item_tfms=add_codes)

Edit to fix some zero div errors. Also here’s a colab link to play with the test code:

Assuming the link works, everything should run as-is in Colab if you can get the import to work; unfortunately, that’s beyond my familiarity with Colab. It might be easier to save the notebook and run it directly in Jupyter Notebook.

1 Like