ImageDataBunch from folder with saved torch tensors

(Sam) #1

My data is set up in the usual train/label folder layout; however, the images are saved as .pt torch tensors because they have more than three channels. Does anyone have experience creating an ImageDataBunch from torch tensors with more than three channels?

This is the error I get when I simply call `ImageDataBunch.from_folder(path, ...)`:
IndexError: index 0 is out of bounds for axis 0 with size 0

0 Likes

(Janne Mäyrä) #2

I’ve worked with those, and this should work. You can use this like

# Build the DataBunch: take file names from df['filename'], split on the
# 'is_valid' column, label from 'species', batch with bs=128, then normalize.
# `stats` is presumably a (2, n_channels) mean/std array, so stats[:, chans]
# keeps only the selected channels — confirm against how stats was computed.
data = (
    HSImageItemList
    .from_df(df, path=datapath, folder=None, dims=3, chans=chans, cols='filename', suffix='')
    .split_from_df('is_valid')
    .label_from_df('species')
    .databunch(bs=128)
    .normalize(stats[:, chans])
)

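You can most likely get by with subclassing an existing ItemList class and overriding its `open` and/or `get` methods. (The IndexError most likely means `from_folder` found no items: it only collects files with standard image extensions, so the .pt files are skipped.) The custom ItemList tutorial in the docs helps a lot: https://docs.fast.ai/tutorial.itemlist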

Anyway, here’s the code (method bodies shown as `...` or `#stuff` are omitted):

def open_hsimage(fn:'PathOrStr', cls:type=None, dims=3, chans=None)->'HsImage':
    """Load a saved multi-channel array, keep the channels in `chans`, and wrap it in `cls`.

    Args:
        fn: path to the saved data. Files ending in ``.pt``/``.pth`` are read with
            ``torch.load`` (mapped to CPU); anything else is read with ``np.load``.
        cls: wrapper applied to the resulting tensor; defaults to ``HsImage``.
        dims: when 3, a leading singleton axis is prepended (the ItemList uses
            dims=3 for volumetric data); any other value leaves the shape as is.
        chans: indices selected along the first axis of the loaded array;
            defaults to ``list(range(461))``.

    Returns:
        ``cls(tensor)`` built from the selected channels.
    """
    # Resolve defaults at call time. HsImage is defined after this function,
    # so using it directly as a default would raise NameError when the `def`
    # executes. A list default would also be a single object shared by every call.
    if cls is None: cls = HsImage
    if chans is None: chans = list(range(461))
    if str(fn).lower().endswith(('.pt', '.pth')):
        # NOTE: torch.load unpickles its input; only load files from a trusted source.
        im = torch.load(fn, map_location='cpu')
    else:
        im = torch.from_numpy(np.load(fn))
    im = im[chans]
    if dims == 3: im = im[None]
    return cls(im)

class HsImage(Image):
    "fastai Image subclass for multi-channel (spectral) data; only the show method is customized."
    # Pass-through constructor: it only forwards `px` to Image.__init__.
    def __init__(self, px:Tensor):
        super().__init__(px)

    # Signature and body are elided in the original post ("..."); not valid Python as written.
    def show(...):
    """Visualization options such as average spectra and RGB renders
        (implementation omitted in the post).
    """

class HSDataBunch(ImageDataBunch):
    """
    Subclasses ImageDataBunch so that its `normalize` method is available.
    Also overrides show_batch so that batches of HsImage items can be displayed.
    """
    # Signature and body are elided in the original post ("..."); not valid Python as written.
    def show_batch(...):
         #Stuff

class HSImageItemList(ItemList):
    """
    Custom ItemList for N-dimensional images either as image or volumetric data. 
    Plotting utilities also added, as well as option to specify which channels to use

    """
    _bunch = HSDataBunch
    _square_show = True
    def __init__(self, items, dims=3, chans=list(range(461)), **kwargs):
        super().__init__(items, **kwargs)
        self.dims = dims
        self.chans = chans
        self.copy_new.append('dims')
        self.copy_new.append('chans')

    def open(self, fn)->HsImage:
        return open_hsimage(fn, dims=self.dims, chans=self.chans)

    def get(self, i)->HsImage:
        fn = super().get(i)
        res = self.open(fn)
        return res

    def reconstruct(self, t:Tensor)->HsImage:
        return HsImage(t.float())

    @classmethod
    def from_df(cls, df:DataFrame, path:PathOrStr, cols:IntsOrStrs=0, folder:PathOrStr=None, suffix:str='.npy', **kwargs)->ItemList:
        suffix = suffix or ''
        res = super().from_df(df, path=path, cols=cols, **kwargs)
        pref = f'{res.path}{os.path.sep}'
        if folder is not None: pref += f'{folder}{os.path.sep}'
        res.items = np.char.add(np.char.add(pref, res.items.astype(str)), suffix)
        return res

    def show_xys(...):
          #stuff

    def show_xyzs(...):
          #stuff
0 Likes