Custom targets itemList with multiple masks

(Marcin) #1

Hello there,

After finishing the amazing fast.ai course and getting familiar with all the “basic” problems the library is designed for, I went a bit further and decided to try using it for problems that are not covered by the current API.
I ran into an issue when creating a custom item list for my targets. The source is a single image, but the targets consist of several masks (six, to be precise). I tried to replace the get method of the standard ImageList with my own implementation, which calls a custom function (“get_maps”) that does the following:
given the path to a single annotation file (a *.txt file with coordinates for creating masks), it creates and returns 6 masks.
I followed the https://docs.fast.ai/tutorial.itemlist.html tutorial for creating a custom ItemList but got stuck on the error message below.
I would appreciate any help pointing me in the right direction.

My custom base item:

    import torch
    from torchvision import transforms
    from fastai.vision import *  # fastai v1: ItemBase, ImageList, ImageSegment

    class MapsItem(ItemBase):
        def __init__(self, scmap, geo_b, geo_l, geo_t, geo_r, geo_th):
            self.scmap, self.geo_b, self.geo_l = scmap, geo_b, geo_l
            self.geo_t, self.geo_r, self.geo_th = geo_t, geo_r, geo_th
            self.obj = (scmap, geo_b, geo_l, geo_t, geo_r, geo_th)
            self.data = [scmap.data, geo_b.data, geo_l.data, geo_t.data, geo_r.data, geo_th.data]

        # Apply data augmentation; that's done by writing an apply_tfms method
        def apply_tfms(self, tfms, **kwargs):
            self.scmap = self.scmap.apply_tfms(tfms, **kwargs)
            self.geo_b = self.geo_b.apply_tfms(tfms, **kwargs)
            self.geo_l = self.geo_l.apply_tfms(tfms, **kwargs)
            self.geo_t = self.geo_t.apply_tfms(tfms, **kwargs)
            self.geo_r = self.geo_r.apply_tfms(tfms, **kwargs)
            self.geo_th = self.geo_th.apply_tfms(tfms, **kwargs)
            self.data = [self.scmap.data, self.geo_b.data, self.geo_l.data,
                         self.geo_t.data, self.geo_r.data, self.geo_th.data]
            return self

        # Method to stack the masks next to each other,
        # which we will use later for a customized show_batch/show_results behavior
        def to_one(self):
            # Scale a geometry map by its own maximum for visualization
            def norm(mask):
                t = mask.data.type(torch.FloatTensor)
                return transforms.Normalize([0.5], [torch.max(t)])(t)

            data_vis = [self.scmap.data.type(torch.FloatTensor),
                        norm(self.geo_b), norm(self.geo_l), norm(self.geo_t), norm(self.geo_r),
                        self.geo_th.data.type(torch.FloatTensor)]
            return ImageSegment(torch.cat(data_vis, 2))

My custom target list:

    class TargetMapList(ImageList):

        def get(self, i):
            filename = self.items[i]
            scmap, geo_b, geo_l, geo_t, geo_r, geo_th = get_maps(filename)
            return MapsItem(scmap, geo_b, geo_l, geo_t, geo_r, geo_th)

        def reconstruct(self, t:Tensor):
            return MapsItem(ImageSegment(t[0]), ImageSegment(t[1]), ImageSegment(t[2]),
                            ImageSegment(t[3]), ImageSegment(t[4]), ImageSegment(t[5]))

Finally, when I call:

    map_list = TargetMapList(txt_fnames)
    map_list

where txt_fnames is a list of paths to the .txt annotation files, I get the following error:

    ---------------------------------------------------------------------------
RecursionError                            Traceback (most recent call last)
c:\users\lukszam\env_artworks\lib\site-packages\IPython\core\formatters.py in __call__(self, obj)
    700                 type_pprinters=self.type_printers,
    701                 deferred_pprinters=self.deferred_printers)
--> 702             printer.pretty(obj)
    703             printer.flush()
    704             return stream.getvalue()

c:\users\lukszam\env_artworks\lib\site-packages\IPython\lib\pretty.py in pretty(self, obj)
    400                         if cls is not object \
    401                                 and callable(cls.__dict__.get('__repr__')):
--> 402                             return _repr_pprint(obj, self, cycle)
    403 
    404             return _default_pprint(obj, self, cycle)

c:\users\lukszam\env_artworks\lib\site-packages\IPython\lib\pretty.py in _repr_pprint(obj, p, cycle)
    695     """A pprint that just redirects to the normal repr function."""
    696     # Find newlines and replace them with p.break_()
--> 697     output = repr(obj)
    698     for idx,output_line in enumerate(output.splitlines()):
    699         if idx:

c:\users\lukszam\env_artworks\lib\site-packages\fastai\data_block.py in __repr__(self)
     67     def __repr__(self)->str:
     68         items = [self[i] for i in range(min(5,len(self.items)))]
---> 69         return f'{self.__class__.__name__} ({len(self.items)} items)\n{show_some(items)}\nPath: {self.path}'
     70 
     71     def process(self, processor:PreProcessors=None):

c:\users\lukszam\env_artworks\lib\site-packages\fastai\core.py in show_some(items, n_max, sep)
    341     "Return the representation of the first  `n_max` elements in `items`."
    342     if items is None or len(items) == 0: return ''
--> 343     res = sep.join([f'{o}' for o in items[:n_max]])
    344     if len(items) > n_max: res += '...'
    345     return res

c:\users\lukszam\env_artworks\lib\site-packages\fastai\core.py in <listcomp>(.0)
    341     "Return the representation of the first  `n_max` elements in `items`."
    342     if items is None or len(items) == 0: return ''
--> 343     res = sep.join([f'{o}' for o in items[:n_max]])
    344     if len(items) > n_max: res += '...'
    345     return res

c:\users\lukszam\env_artworks\lib\site-packages\fastai\core.py in __repr__(self)
    154     "Base item type in the fastai library."
    155     def __init__(self, data:Any): self.data=self.obj=data
--> 156     def __repr__(self)->str: return f'{self.__class__.__name__} {str(self)}'
    157     def show(self, ax:plt.Axes, **kwargs):
    158         "Subclass this method if you want to customize the way this `ItemBase` is shown on `ax`."

... last 1 frames repeated, from the frame below ...

c:\users\lukszam\env_artworks\lib\site-packages\fastai\core.py in __repr__(self)
    154     "Base item type in the fastai library."
    155     def __init__(self, data:Any): self.data=self.obj=data
--> 156     def __repr__(self)->str: return f'{self.__class__.__name__} {str(self)}'
    157     def show(self, ax:plt.Axes, **kwargs):
    158         "Subclass this method if you want to customize the way this `ItemBase` is shown on `ax`."

RecursionError: maximum recursion depth exceeded while calling a Python object

Cheers,
Marcin


#2

You need to define a __str__ method for your MapsItem class. It can be something as simple as:

    def __str__(self):
        return f'{self.obj}, {self.data}'
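
For anyone wondering why a missing __str__ ends in infinite recursion: the traceback shows that ItemBase.__repr__ calls str(self), and Python's default __str__ falls back to __repr__, so the two keep calling each other. Below is a minimal sketch of that in plain Python, without fastai. ItemBaseLike, BrokenItem, FixedItem and safe_repr are made-up names for illustration; only the __repr__ line is copied from the traceback.

```python
class ItemBaseLike:
    """Stand-in for illustration only; not the real fastai ItemBase."""

    def __init__(self, data):
        self.data = self.obj = data

    def __repr__(self):
        # Same expression as in the traceback (fastai/core.py, line 156)
        return f'{self.__class__.__name__} {str(self)}'


class BrokenItem(ItemBaseLike):
    # No __str__ here: str(self) uses object.__str__, which calls
    # repr(self), which calls str(self) again, and so on.
    pass


class FixedItem(ItemBaseLike):
    def __str__(self):
        # Any __str__ that does not go back through __repr__ breaks the cycle
        return f'{self.obj}'


def safe_repr(item):
    """Return repr(item), or the string 'RecursionError' if the cycle is hit."""
    try:
        return repr(item)
    except RecursionError:
        return 'RecursionError'


print(safe_repr(BrokenItem([1, 2])))
print(safe_repr(FixedItem([1, 2])))
```

The first call reports the RecursionError and the second prints FixedItem [1, 2], which is the same difference the suggested __str__ makes for MapsItem.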

(Marcin) #3

This solved my problem! Thanks a lot, @florobax! :grinning:
I wonder why it’s not mentioned in the custom item tutorial.


#4

This is actually mentioned, though briefly:

Those are the more important attributes your custom ItemBase needs as they’re used everywhere in the fastai library:

  • ItemBase.data is the thing that is passed to pytorch when you want to create a DataLoader. This is what needs to be fed to your model. Note that it might be different from the representation of your item since you might want something that is more understandable.
  • __str__ representation: if applicable, this is what will be displayed when the fastai library has to show your item.
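
To illustrate the split between the two attributes quoted above, here is a rough sketch in plain Python, where data keeps the full masks while __str__ returns only a short summary. ShapeSummaryItem is a hypothetical stand-in, not a fastai class, and nested lists play the role of the mask tensors. The mask names are borrowed from the MapsItem in the first post.

```python
class ShapeSummaryItem:
    """Hypothetical multi-mask item; nested lists stand in for tensors."""

    names = ('scmap', 'geo_b', 'geo_l', 'geo_t', 'geo_r', 'geo_th')

    def __init__(self, *masks):
        self.obj = masks
        self.data = list(masks)  # the full content that would go to the model

    def __str__(self):
        # Summarize each mask as name(rows x cols) instead of dumping its values
        parts = [f'{n}({len(m)}x{len(m[0])})' for n, m in zip(self.names, self.obj)]
        return ', '.join(parts)

    def __repr__(self):
        return f'{self.__class__.__name__} {str(self)}'


mask = [[0, 1, 0], [1, 1, 0]]  # 2 rows x 3 columns
item = ShapeSummaryItem(*[mask] * 6)
print(item)
```

With real tensors the same idea would presumably read the size from each mask's tensor (for example its .shape) rather than using len. A summary like this keeps the item list's printout short, since it shows the first few items every time it is displayed.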

(Marcin) #5

You are right, I must have missed that!
