Loading Torch Tensors into the DataBlock API

There’s likely a better way to do this, but here’s the implementation I came up with to load some dummy tensors.

from fastai.vision.all import *


def get_mask(fname, masklabel, tensorlabel):
    """Return the mask path that mirrors the tensor path `fname`.

    Every path component equal to `tensorlabel` is replaced by `masklabel`;
    all other components are kept in their original order.

    Args:
        fname: Path of the tensor file (a `Path` or a string path).
        masklabel: Folder name that holds the masks.
        tensorlabel: Folder name that holds the tensors.

    Returns:
        A `Path` with each `tensorlabel` component swapped for `masklabel`.
    """
    swapped = [masklabel if part == tensorlabel else part
               for part in Path(fname).parts]
    # Rebuild from components instead of `os.sep.join(...)`: joining the root
    # component ("/") with the separator yields "//...", which pathlib keeps
    # as a distinct root on POSIX, so the result would not equal the real path.
    return Path(*swapped)


def MaskGetter(masklabel=None, tensorlabel=None):
    """Build a callable that maps a tensor path to its mask path.

    The returned function forwards its argument to `get_mask`, using the
    `masklabel` and `tensorlabel` given here unless the caller overrides them.
    """
    def _inner(o, masklabel=masklabel, tensorlabel=tensorlabel):
        # The defaults capture the factory arguments at creation time.
        mask_path = get_mask(o, masklabel=masklabel, tensorlabel=tensorlabel)
        return mask_path

    return _inner


def _parent_idxs(items, name):
    """Return the indices of `items` whose parent folder is named `name`.

    `name` is wrapped with `L(...)` and iterated, and the indices found for
    each entry are concatenated in that order.
    """
    idxs = []
    for folder in L(name):
        # Boolean mask over `items`: True where the parent folder matches.
        in_folder = [Path(item).parent.name == folder for item in items]
        idxs.extend(mask2idxs(in_folder))
    return idxs


def ParentSplitter(train_name='train', valid_name='valid'):
    "Build a splitter returning (train, valid) indices based on each item's parent folder name."
    def _inner(o):
        train_idxs = _parent_idxs(o, train_name)
        valid_idxs = _parent_idxs(o, valid_name)
        return train_idxs, valid_idxs

    return _inner


class TensorLoad(Transform):
    "Transform whose `encodes` loads the given file with `torch.load`."
    def __init__(self):
        # Run the base initializer rather than bypassing it, so the instance
        # is set up the same way as any other `Transform`.
        super().__init__()

    def encodes(self, o):
        # `o` is whatever item the pipeline passes in; it is handed to
        # `torch.load` unchanged.
        # NOTE(review): `torch.load` is pickle-based; depending on the torch
        # version and its `weights_only` default this can execute arbitrary
        # code, so only load files from a trusted source.
        return torch.load(o)


def TensorBlock():
    "Return a `TransformBlock` using `TensorLoad` as type transform and `IntToFloatTensor` as batch transform."
    block_kwargs = dict(type_tfms=TensorLoad, batch_tfms=IntToFloatTensor)
    return TransformBlock(**block_kwargs)


def TensorMaskBlock(codes=None):
    "Return a `TransformBlock` like `TensorBlock` that also applies `AddMaskCodes(codes=codes)` as item transform."
    add_codes = AddMaskCodes(codes=codes)
    return TransformBlock(
        type_tfms=TensorLoad,
        item_tfms=add_codes,
        batch_tfms=IntToFloatTensor,
    )


def torch_tensor_creator(root=r"C:\Users\Joe\Downloads\torch_tensor_test", n_tensors=8, size=(128, 128)):
    """Write dummy tensor and mask files for the train and valid splits.

    For each name `tensor_1` .. `tensor_<n_tensors>` one `.pt` file is saved
    into `tensors/train`, `tensors/valid`, `masks/train` and `masks/valid`
    below `root`. Missing folders are created first, so the function also
    works on a fresh `root`.

    Args:
        root: Dataset root folder. The default is the location the script
            was originally written against.
        n_tensors: Number of files written into each folder.
        size: Shape of every generated tensor and mask.
    """
    root = Path(root)
    train_tensor_dir = root / "tensors" / "train"
    valid_tensor_dir = root / "tensors" / "valid"
    train_mask_dir = root / "masks" / "train"
    valid_mask_dir = root / "masks" / "valid"

    # `torch.save` does not create parent folders, so make sure they exist.
    for folder in (train_tensor_dir, valid_tensor_dir, train_mask_dir, valid_mask_dir):
        folder.mkdir(parents=True, exist_ok=True)

    for idx in range(1, n_tensors + 1):
        fname = f"tensor_{idx}.pt"

        # Generation order is kept (train tensor, valid tensor, train mask,
        # valid mask) so a seeded run produces the same data as before.
        train_tensor = torch.rand(size, dtype=torch.float32)
        valid_tensor = torch.rand(size, dtype=torch.float32)
        # `high` is exclusive, so with `high=1` every mask value is 0.
        train_mask = torch.randint(high=1, size=size, dtype=torch.float32)
        valid_mask = torch.randint(high=1, size=size, dtype=torch.float32)

        torch.save(train_tensor, train_tensor_dir / fname)
        torch.save(valid_tensor, valid_tensor_dir / fname)
        torch.save(train_mask, train_mask_dir / fname)
        torch.save(valid_mask, valid_mask_dir / fname)


if __name__ == '__main__':
    # Dataset root; items are collected from its `tensors` folder.
    data_root = Path(r"C:\Users\Joe\Downloads\torch_tensor_test")

    # Build the individual DataBlock pieces first, then assemble them.
    item_getter = FileGetter(extensions='.pt', folders=['tensors'])
    mask_getter = MaskGetter(masklabel='masks', tensorlabel='tensors')
    split_by_parent = ParentSplitter(train_name='train', valid_name='valid')

    datablock = DataBlock(
        blocks=(TensorBlock, TensorMaskBlock(codes=['ones'])),
        get_items=item_getter,
        get_y=mask_getter,
        splitter=split_by_parent,
    )

    dataloaders = datablock.dataloaders(data_root, bs=2)
    # Print one batch as a smoke test of the whole pipeline.
    print(dataloaders.one_batch())