There’s likely a better way to do this, but here’s the implementation I came up with to load some dummy tensors.
from fastai.vision.all import *
def get_mask(fname, masklabel, tensorlabel):
    """Return the mask path that corresponds to a tensor file.

    Every path component equal to ``tensorlabel`` is replaced with
    ``masklabel``; all other components are kept unchanged. For example,
    ``root/tensors/train/x.pt`` becomes ``root/masks/train/x.pt`` when
    ``tensorlabel='tensors'`` and ``masklabel='masks'``.

    Args:
        fname: Path (or path string) of the tensor file.
        masklabel: Folder name to substitute in.
        tensorlabel: Folder name to be replaced.

    Returns:
        A ``Path`` pointing at the corresponding mask file.
    """
    # Rebuild with Path(*parts) rather than os.sep.join(parts): for an
    # absolute path the first part already carries the separator ('/' or
    # 'C:\\'), so joining with os.sep doubled it (e.g. '//data/...'), and
    # POSIX pathlib keeps a leading '//' as a distinct root.
    parts = (masklabel if part == tensorlabel else part for part in Path(fname).parts)
    return Path(*parts)
def MaskGetter(masklabel=None, tensorlabel=None):
    """Build a ``get_y`` callable that maps a tensor path to its mask path.

    The returned function takes a path ``o`` and forwards it to ``get_mask``
    together with the two folder labels captured here. The labels are also
    keyword defaults of the returned function, so a caller can still override
    them per call.
    """
    # NOTE(review): the returned function is a local closure, which the
    # standard pickle module cannot serialize; confirm this is fine if
    # DataLoader workers are started with the "spawn" method.
    def _tensor_to_mask_path(o, masklabel=masklabel, tensorlabel=tensorlabel):
        mask_path = get_mask(o, masklabel, tensorlabel)
        return mask_path

    return _tensor_to_mask_path
def _parent_idxs(items, name):
    """Return indices of ``items`` whose immediate parent folder is ``name``.

    ``name`` may be a single folder name or a collection of names (it is
    normalised with ``L``). The indices for each name are collected in the
    order the names are given, and within one name in item order.
    """
    matches = []
    for folder in L(name):
        for idx, item in enumerate(items):
            if Path(item).parent.name == folder:
                matches.append(idx)
    return matches
def ParentSplitter(train_name='train', valid_name='valid'):
    """Build a splitter that assigns items by the name of their parent folder.

    The returned callable takes the full item list and returns a pair
    ``(train_idxs, valid_idxs)``. These are the indices of items whose
    immediate parent folder is named ``train_name`` and ``valid_name``
    respectively.
    """
    def _split_by_parent(o):
        train_idxs = _parent_idxs(o, train_name)
        valid_idxs = _parent_idxs(o, valid_name)
        return train_idxs, valid_idxs

    return _split_by_parent
class TensorLoad(Transform):
    """Type transform that loads the object stored at path ``o`` with ``torch.load``.

    There is deliberately no ``__init__`` override. The previous one was a bare
    ``pass`` that skipped ``Transform.__init__``. Inheriting the base
    constructor keeps ``TensorLoad()`` working and lets the standard
    ``Transform`` keyword arguments (e.g. ``split_idx``, ``order``) through.
    """

    def encodes(self, o):
        # NOTE(review): torch.load unpickles the file, so only load .pt files
        # from trusted sources. For plain tensors, consider
        # torch.load(o, weights_only=True) where the installed torch supports it.
        return torch.load(o)
def TensorBlock():
    """Return a ``TransformBlock`` for inputs stored as saved tensor files.

    Each item is loaded by ``TensorLoad``, which is given as the class rather
    than an instance. ``IntToFloatTensor`` is applied as the batch transform.
    """
    # NOTE(review): IntToFloatTensor's encodes are type-dispatched in fastai
    # (TensorImage/TensorMask). The plain tensors returned by torch.load may
    # therefore pass through unchanged; confirm that is intended.
    loader = TensorLoad
    return TransformBlock(batch_tfms=IntToFloatTensor, type_tfms=loader)
def TensorMaskBlock(codes=None):
    """Return a ``TransformBlock`` for segmentation masks stored as tensor files.

    Masks are loaded by ``TensorLoad``. ``codes`` is forwarded to
    ``AddMaskCodes``, which is used as the item transform, and
    ``IntToFloatTensor`` is the batch transform.
    """
    mask_codes = AddMaskCodes(codes=codes)
    return TransformBlock(
        type_tfms=TensorLoad,
        item_tfms=mask_codes,
        batch_tfms=IntToFloatTensor,
    )
def torch_tensor_creator(root=r"C:\Users\Joe\Downloads\torch_tensor_test",
                         n_tensors=8, size=(128, 128), mask_high=1):
    """Write random dummy tensors and masks for a train/valid segmentation test.

    Writes ``tensor_1.pt`` ... ``tensor_<n_tensors>.pt`` into each of
    ``<root>/tensors/train``, ``<root>/tensors/valid``, ``<root>/masks/train``
    and ``<root>/masks/valid``. Missing directories are created first.
    Inputs are ``torch.rand`` float32 tensors, with values in [0, 1). Masks are
    float32 tensors of random integers drawn with ``torch.randint``.

    Args:
        root: Dataset root folder. Defaults to the originally hard-coded path.
        n_tensors: Number of files written to each of the four folders.
        size: Shape of every generated tensor and mask.
        mask_high: Exclusive upper bound for ``torch.randint``. The default of
            1 reproduces the original behavior, which yields all-zero masks.
            Pass 2 for random 0/1 masks.
    """
    train_tensor_path = os.path.join(root, "tensors", "train")
    valid_tensor_path = os.path.join(root, "tensors", "valid")
    train_mask_path = os.path.join(root, "masks", "train")
    valid_mask_path = os.path.join(root, "masks", "valid")
    # torch.save does not create parent folders, so make sure they exist.
    for folder in (train_tensor_path, valid_tensor_path, train_mask_path, valid_mask_path):
        os.makedirs(folder, exist_ok=True)

    for i in range(1, n_tensors + 1):
        fname = f"tensor_{i}.pt"
        # Same generation order as before, so seeded runs reproduce the same data.
        train_tensor = torch.rand(size, dtype=torch.float32)
        valid_tensor = torch.rand(size, dtype=torch.float32)
        train_mask = torch.randint(high=mask_high, size=size, dtype=torch.float32)
        valid_mask = torch.randint(high=mask_high, size=size, dtype=torch.float32)
        torch.save(train_tensor, os.path.join(train_tensor_path, fname))
        torch.save(valid_tensor, os.path.join(valid_tensor_path, fname))
        torch.save(train_mask, os.path.join(train_mask_path, fname))
        torch.save(valid_mask, os.path.join(valid_mask_path, fname))
if __name__ == '__main__':
    # Build train/valid DataLoaders over the dummy dataset and print one batch.
    # NOTE(review): presumably torch_tensor_creator() must have been run
    # beforehand so the tensors/ and masks/ folders exist under data_root.
    data_root = Path(r"C:\Users\Joe\Downloads\torch_tensor_test")
    get_tensor_files = FileGetter(extensions='.pt', folders=['tensors'])
    get_mask_for = MaskGetter(masklabel='masks', tensorlabel='tensors')
    split_by_parent = ParentSplitter(train_name='train', valid_name='valid')
    dblock = DataBlock(
        blocks=(TensorBlock, TensorMaskBlock(codes=['ones'])),
        get_items=get_tensor_files,
        get_y=get_mask_for,
        splitter=split_by_parent,
    )
    dls = dblock.dataloaders(data_root, bs=2)
    print(dls.one_batch())