Probably the final revision for a while. I added a show method that should coerce arbitrary tensors into displayable images, as long as you manage the channels. It looks like it's OK with multiprocessing as long as your data is not saved as CUDA tensors (change the loading function and you should be fine loading numpy files or anything else, though you'll probably need to explicitly convert to tensors before the create method returns the data). Some augmentations seem to work, like Rotate, but aug_transforms gives me some weird results, and I haven't dug deep enough to figure out which augmentations are OK and which are not.
Just pasting the relevant code for the TransformBlocks, since that's where the only meaningful changes are.
def _fig_bounds(x):
# not sure why this is here but fast.ai uses it in their show_image method
r = x//32
return min(5, max(1,r))
def scale_tensor_to_image(im_tensor):
# scales the tensor values to fit between 0 and 1
im_min = torch.min(im_tensor)
im_max = torch.max(im_tensor)
return (255*(im_tensor - im_min)//(im_max - im_min)).type(torch.uint8)
def show_tensor_image(im, ax=None, figsize=None, title=None, ctx=None, **kwargs):
    """Show a PIL, numpy, or PyTorch image on `ax` and return the axis.

    Basically fast.ai's `show_image` plus value rescaling so tensors with
    arbitrary value ranges render sensibly.
    """
    # Handle pytorch axis order: move CHW -> HWC for matplotlib
    if hasattrs(im, ('data','cpu','permute')):
        im = im.data.cpu()
        if im.shape[0]<5: im=im.permute(1,2,0)
    elif not isinstance(im,np.ndarray): im=array(im)
    # Handle 1-channel images
    if im.shape[-1]==1: im=im[...,0]
    # Rescale only torch tensors; the original applied torch.max/torch.min
    # unconditionally, which crashes when `im` is a numpy array (branch above).
    if isinstance(im, torch.Tensor):
        if torch.max(im) != torch.min(im):
            im = scale_tensor_to_image(im)
        else:
            # constant image: nothing to scale, just cast for imshow
            im = im.type(torch.uint8)
    ax = ifnone(ax,ctx)
    if figsize is None: figsize = (_fig_bounds(im.shape[0]), _fig_bounds(im.shape[1]))
    if ax is None: _,ax = plt.subplots(figsize=figsize)
    ax.imshow(im, **kwargs)
    if title is not None: ax.set_title(title)
    ax.axis('off')
    return ax
class TorchTensorImage(TensorImage):
    # Image type for fast.ai DataBlocks backed by saved torch tensors.
    # TensorImage base class seems to work best but the other two sort of work:
    #class TorchTensorImage(Tensor):
    #class TorchTensorImage(TensorBase):
    _show_args = {'cmap':'viridis'}  # default kwargs merged into show()/imshow
    @classmethod
    def create(cls, fn: (Path, str)) -> None:
        # fn is the filename; edit this method to control how your data is loaded.
        # NOTE(review): despite the `-> None` annotation this returns a
        # TorchTensorImage instance. fast.ai uses annotations for transform
        # dispatch, so they are deliberately left untouched here.
        tens = torch.load(fn)
        #return cls(tens)
        # Replicate one channel into 3 so downstream code sees an RGB-shaped
        # input (assumes `tens` is 2-D HxW -- TODO confirm for your data).
        input_tensor = torch.stack((tens, tens, tens))
        # Alternative loading strategies that were experimented with:
        #input_tensor = tens[1:4,:,:].clone()
        #tens_center = tens[2,:,:].clone()
        #input_tensor = torch.stack((tens_center, tens_center, tens_center))
        return cls(input_tensor)
    def show(self, ctx=None, **kwargs):
        # How this tensor is displayed inside fast.ai's show_batch/show_results.
        "Show image using `merge(self._show_args, kwargs)`"
        return show_tensor_image(self, ctx=ctx, **merge(self._show_args, kwargs))
class TorchTensorMask(TensorMask):
    # Mask type for fast.ai DataBlocks backed by saved torch tensors.
    _show_args = {'alpha':0.5, 'cmap':'tab20'}  # semi-transparent overlay defaults
    @classmethod
    def create(cls, fn: (Path, str)) -> None:
        # fn is the filename; edit this method to control how your mask is loaded.
        # NOTE(review): despite the `-> None` annotation this returns a
        # TorchTensorMask instance; annotations drive fast.ai dispatch, so
        # they are left as-is.
        tens = torch.load(fn)
        return cls(tens)
    def show(self, ctx=None, **kwargs):
        # How the mask is displayed (overlaid on the image by fast.ai).
        "Show image using `merge(self._show_args, kwargs)`"
        return show_tensor_image(self, ctx=ctx, **merge(self._show_args, kwargs))
def TensorImageBlock():
    "TransformBlock that loads saved torch tensors as images via `TorchTensorImage`."
    # div=1 so the int->float batch transform only casts, without rescaling
    # (presumably the data is already in the desired range -- see create()).
    to_float = IntToFloatTensor(div=1)
    return TransformBlock(type_tfms=TorchTensorImage.create, batch_tfms=to_float)
def TensorMaskBlock(codes):
    "TransformBlock that loads saved torch tensors as masks via `TorchTensorMask`."
    # `codes` names the mask classes; attached per-item via AddMaskCodes.
    # batch_tfms=IntToFloatTensor(div=1) was experimented with but left disabled.
    mask_codes = AddMaskCodes(codes=codes)
    return TransformBlock(type_tfms=TorchTensorMask.create, item_tfms=mask_codes)
Edit to fix some zero-division errors. Also, here's a colab link to play with the test code:
Assuming the link works, everything should run as is in colab if you can get the import to work, unfortunately that’s beyond my familiarity with colab. It might be easier to save the notebook and run it directly in Jupyter Notebook.