Custom ItemList

I followed the guide "https://docs.fast.ai/tutorial.itemlist.html" to create my custom ImageTupleList. All the steps work fine except for the transforms. I use the same code, as follows:

    def apply_tfms(self, tfms, **kwargs):
        "Apply `tfms` to both images, rebuild `self.data` from the results and return `self`."
        # Debug output: show the first image before it is transformed.
        print('img1')
        print(self.img1)
        # Transform the first image, then the second, with the same `tfms` and `kwargs`.
        first = self.img1.apply_tfms(tfms, **kwargs)
        self.img1 = first
        second = self.img2.apply_tfms(tfms, **kwargs)
        self.img2 = second
        # Refresh the stored pair with the same `-1 + 2*x` rescaling of each image's `.data`.
        self.data = [-1+2*img.data for img in (self.img1, self.img2)]
        return self

But I get the error below. I also tried making my ImageTuple inherit from the Image class, but that doesn't seem to be the correct approach.
I built my own custom dataset in order to use fastai's transforms. How can I solve this problem?

I think it's related to how you are creating GA_list; please post the relevant code.

Here is my code:
First, I create a custom ItemBase subclass:

class ImageTuple(ItemBase):
    """Two images handled as a single item.

    `obj` holds the pair of image objects; `data` holds each image's `.data`
    rescaled with `-1 + 2*x`.
    """

    def __init__(self, img1, img2):
        # Compute the rescaled pair first, then store everything on the instance.
        scaled = [-1+2*img1.data, -1+2*img2.data]
        self.img1 = img1
        self.img2 = img2
        self.obj = (img1, img2)
        self.data = scaled
        # Debug output: show the image pair each time an item is built.
        print(self.obj)

    def __str__(self):
        "Return the string form of the underlying image pair."
        return f'{self.obj}'

    def apply_tfms(self, tfms, **kwargs):
        "Apply `tfms` to both images, rebuild `self.data` from the results and return `self`."
        first = self.img1.apply_tfms(tfms, **kwargs)
        self.img1 = first
        second = self.img2.apply_tfms(tfms, **kwargs)
        self.img2 = second
        self.data = [-1+2*img.data for img in (self.img1, self.img2)]
        return self

    def to_one(self):
        "Concatenate the two tensors in `data` along dim 2 and wrap `0.5 + x/2` in one `Image`."
        joined = torch.cat(self.data, 2)
        return Image(0.5+joined/2)

Second, I create a custom ItemList subclass:

class ImageTupleList(ImageList):
    """`ImageList` whose items are `ImageTuple`s: one image from `items`, one from `itemsB`.

    NOTE(review): `get` reads `self.itemsB[i]` with the same index used for `self.items`,
    while `from_custom` filters `items` only. Confirm that both sequences stay
    index-aligned after filtering and splitting.
    """
    _label_cls=pfmList

    def __init__(self, items, itemsB=None, **kwargs):
        "Store `itemsB` next to `items` and register it in `copy_new`."
        super().__init__(items, **kwargs)
        self.itemsB = itemsB
        self.copy_new.append('itemsB')

    def get(self, i):
        "Return the `ImageTuple` made of item `i` and the image opened from `itemsB[i]`."
        # ImageList.get returns the first image as an Image.
        img1 = super().get(i)
        fn = self.itemsB[i]
        return ImageTuple(img1, open_image(fn))

    def reconstruct(self, t:Tensor):
        "Rebuild an `ImageTuple` from `t`, mapping `t[0]` and `t[1]` through `x/2 + 0.5`."
        return ImageTuple(Image(t[0]/2+0.5),Image(t[1]/2+0.5))

    @classmethod
    def from_folders(cls, path, folderA, folderB, **kwargs):
        "Build the list from `path/folderA`, taking `itemsB` from the files in `path/folderB`."
        # Read the file names found in folderB.
        itemsB = ImageList.from_folder(path/folderB).items
        res = super().from_folder(path/folderA, itemsB=itemsB, **kwargs)
        res.path = path
        return res

    @classmethod
    def from_custom(cls, path, folderA, itemsB, **kwargs):
        "Build the list from `path/folderA`, keeping only files whose parent folder is 'left'."
        res = super().from_folder(path/folderA, itemsB=itemsB, **kwargs).filter_by_func(lambda o:o.parts[-2]=='left')
        res.path = path
        return res

    def show_xys(self, xs, ys, figsize:Tuple[int,int]=(12,6), **kwargs):
        "Show the `xs` on a square grid of size `figsize`. `kwargs` are passed to the show method."
        rows = int(math.sqrt(len(xs)))
        fig, axs = plt.subplots(rows,rows,figsize=figsize)
        # With a single cell `plt.subplots` returns a bare Axes rather than an array.
        for i, ax in enumerate(axs.flatten() if rows > 1 else [axs]):
            xs[i].to_one().show(ax=ax, **kwargs)
        plt.tight_layout()

    def show_xyzs(self, xs, ys, zs, figsize:Tuple[int,int]=None, **kwargs):
        """Show `xs` (inputs) next to `zs` (predictions) on a figure of `figsize`.

        `ys` is accepted but not drawn. `kwargs` are passed to the show method.
        """
        figsize = ifnone(figsize, (12,3*len(xs)))
        # squeeze=False keeps `axs` two-dimensional even for a single row, so the
        # `axs[i,0]` / `axs[i,1]` indexing below also works when `len(xs) == 1`.
        fig,axs = plt.subplots(len(xs), 2, figsize=figsize, squeeze=False)
        fig.suptitle('Ground truth / Predictions', weight='bold', size=14)
        for i,(x,z) in enumerate(zip(xs,zs)):
            x.to_one().show(ax=axs[i,0], **kwargs)
            z.to_one().show(ax=axs[i,1], **kwargs)

And my pfmList is:

class pfmList(ItemList):
    "`ItemList` whose items are `.pfm` file paths, each loaded as an `Image`."

    # A method's first parameter must be `self`: calling `self.pfm_imread(...)` passes
    # `self` automatically, so leaving it out makes the arguments mismatch.
    def pfm_imread(self,filename):
        """Read a `.pfm` file into a numpy array (used to load SceneFlow disparity files).

        Returns `(data, scale)`. `data` has shape `(height, width, 3)` for a 'PF' header
        and `(height, width)` for a 'Pf' header, flipped vertically with `np.flipud`.
        `scale` is the absolute value of the header's scale field; its sign only
        selects the byte order.

        Raises `Exception` when the header is neither 'PF' nor 'Pf', or when the
        dimension line is malformed.
        """
        # The context manager closes the file on every path, including the raises below.
        with open(filename, 'rb') as file:
            header = file.readline().decode('utf-8').rstrip()
            if header == 'PF':
                color = True
            elif header == 'Pf':
                color = False
            else:
                raise Exception('Not a PFM file.')

            dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
            if dim_match:
                width, height = map(int, dim_match.groups())
            else:
                raise Exception('Malformed PFM header.')

            scale = float(file.readline().rstrip())
            if scale < 0:  # little-endian
                endian = '<'
                scale = -scale
            else:
                endian = '>'  # big-endian

            # Everything after the three header lines is the raw float payload.
            data = np.fromfile(file, endian + 'f')

        shape = (height, width, 3) if color else (height, width)
        data = np.reshape(data, shape)
        data = np.flipud(data)
        return data, scale

    def get(self, i):
        "Load item `i` as an `Image` holding the disparity read from its `.pfm` file."
        fn = super().get(i)
        # Example path:
        # /data/home/xubin/dataset/dataset/GANet_dataset/disparity/TRAIN/15mm_focallength/scene_backwards/slow/left/0058.png
        # Read the (height x width) disparity map; the scale is not used here.
        disparity, scale = self.pfm_imread(str(fn))
        # Clamp negative values to zero.
        disparity = np.maximum(disparity, 0)
        # Add a leading dimension.
        disparity = np.expand_dims(disparity, 0)
        # Convert the numpy array to a tensor and wrap it in an Image.
        img1 = torch.from_numpy(disparity)

        return Image(img1)

    def reconstruct(self, t:Tensor):
        "Wrap `t`, converted to float, in an `Image`."
        return Image(t.float())

With the code above, I can get my GA_list as shown below:


Finally, I want to convert my GA_list to a DataBunch, but when I apply the transforms I get the error described earlier. I am confused because my img1 and img2 are indeed Image instances, and they have the apply_tfms attribute. I haven't found my problem yet. If you know what my mistake is, please tell me.

You are still not showing how you create GA_list. How are you passing the transforms?

Sorry, I thought you needed to look at the code of the item I built. The code I use to create GA_list is shown below:

# Collect every file under frames_finalpass and keep those whose parent folder is 'right'.
list_right = (
    ItemList
    .from_folder(GA_path/'frames_finalpass')
    .filter_by_func(lambda o: o.parts[-2] == 'right')
)
itemsB = list_right.items
# Build the tuple list from the same folder, then split it by the TRAIN / TEST folders.
GA_list = (
    ImageTupleList
    .from_custom(GA_path/'frames_finalpass', '', itemsB)
    .split_by_folder(train='TRAIN', valid='TEST')
)
def get_y_fn(x):
    """Map an input image path to the path of its `.pfm` label file.

    Every occurrence of 'frames_finalpass' in the path is replaced with 'disparity',
    and a '.png' file extension is replaced with '.pfm'. Returns a `Path`.
    """
    label = Path(str(x).replace('frames_finalpass', 'disparity'))
    # Swap only the extension: a blanket replace('png', 'pfm') would also rewrite any
    # directory or file name that merely contains 'png'.
    if label.suffix == '.png':
        label = label.with_suffix('.pfm')
    return label
# Label every item with the path that get_y_fn returns for it.
GA_list = GA_list.label_from_func(get_y_fn)

I can get the GA_list as shown below:
f3e28b2c79b67f6ab47a8137759cc8e3.pngf3e28b2c79b67f6ab47a8137759cc8e3.png
So I think the problem arises when I pass the transforms. I solved this problem yesterday just by changing the order of the calls.
The original code:

# Original call order from the post: databunch() first, then transform() on its result.
# NOTE(review): the author reports that this order raised the error; presumably the
# object returned by databunch() does not support .transform() — confirm in the fastai docs.
planet_tfms = get_transforms()
data = (GA_list
        .databunch(bs=2)
        .transform(planet_tfms)
        )

and now I have changed it to this:

# Call order the author reports as working: transform() on GA_list first, then databunch().
planet_tfms = get_transforms()
data = (GA_list.transform(planet_tfms)
        .databunch(bs=2)
        )

Now it works! Thank you for your tips, which helped me solve my problem.

1 Like