Hello there,
After finishing the amazing fast.ai course and getting myself familiar with all the “basic” problems that the library is designed for, I went a bit further and decided to try to use it for problems that are not covered by current api.
I stumbled upon an issue when creating custom item list for my targets. The source is a single image but targets consist of several masks (six to be precise). I tried to replace get method from standard ImageList with my own implementation where I have a custom method (“get_maps”) which does the following:
given the path to a single annotation file - a *.txt file with coordinates for creating masks, it creates and returns 6 masks.
I followed https://docs.fast.ai/tutorial.itemlist.html tutorial for creating a custom itemList but got stuck on the following error message.
I would appreciate any help in pointing me into the right direction.
My custom base item:
class MapsItem(ItemBase):
def __init__(self, scmap, geo_b, geo_l, geo_t, geo_r, geo_th):
self.scmap, self.geo_b, self.geo_l, self.geo_t, self.geo_r, self.geo_th = scmap, geo_b, geo_l, geo_t, geo_r, geo_th
self.obj, self.data = (scmap, geo_b, geo_l, geo_t, geo_r, geo_th), [scmap.data, geo_b.data, geo_l, geo_t.data, geo_r.data, geo_th.data]
# apply data augmentation, that's done by writing and apply_tfms method
def apply_tfms(self, tfms, **kwargs):
self.score_map = self.scmap.apply_tfms(tfms, **kwargs)
self.geo_map_off_bottom = self.geo_b.apply_tfms(tfms, **kwargs)
self.geo_map_off_left = self.geo_l.apply_tfms(tfms, **kwargs)
self.geo_map_off_top = self.geo_t.apply_tfms(tfms, **kwargs)
self.geo_map_off_right = self.geo_r.apply_tfms(tfms, **kwargs)
self.geo_map_theta = self.geo_th.apply_tfms(tfms, **kwargs)
self.data = [scmap.data, geo_b.data, geo_l.data, geo_t.data, geo_r.data, geo_th.data]
return self
# method to stack the masks next ot each other,
# which we will use later for a customized show_batch/ show_results behavior
def to_one(self):
# normalize for vis
data_vis = [scmap.data.type(torch.FloatTensor), transforms.Normalize([0.5], torch.max(geo_b.data.type(torch.FloatTensor))])
(geo_b.data.type(torch.FloatTensor)), transforms.Normalize([0.5], [torch.max(geo_l.data.type(torch.FloatTensor))])
(geo_l.data.type(torch.FloatTensor)), transforms.Normalize([0.5], [torch.max(geo_t.data.type(torch.FloatTensor))])
(geo_t.data.type(torch.FloatTensor)), transforms.Normalize([0.5], [torch.max(geo_r.data.type(torch.FloatTensor))])
(geo_r.data.type(torch.FloatTensor)),
geo_th.data.type(torch.FloatTensor)
]
return ImageSegment(torch.cat(data_vis,2))
My custom target list:
class TargetMapList(ImageList):
def get(self, i):
filename = self.items[i]
scmap, geo_b, geo_l, geo_t, geo_r, geo_th = get_maps(filename)
return MapsItem(scmap, geo_b, geo_l, geo_t, geo_r, geo_th)
def reconstruct(self, t:Tensor):
return MapsItem(ImageSegment(t[0]),ImageSegment(t[1]),ImageSegment(t[2]),ImageSegment(t[3]), ImageSegment(t[4]),ImageSegment(t[5]))
Finally when I call:
map_list = TargetMapList(txt_fnames)
map_list
where txt_fnames is a list of filenames of txt files
this results in an error:
---------------------------------------------------------------------------
RecursionError Traceback (most recent call last)
c:\users\lukszam\env_artworks\lib\site-packages\IPython\core\formatters.py in __call__(self, obj)
700 type_pprinters=self.type_printers,
701 deferred_pprinters=self.deferred_printers)
--> 702 printer.pretty(obj)
703 printer.flush()
704 return stream.getvalue()
c:\users\lukszam\env_artworks\lib\site-packages\IPython\lib\pretty.py in pretty(self, obj)
400 if cls is not object \
401 and callable(cls.__dict__.get('__repr__')):
--> 402 return _repr_pprint(obj, self, cycle)
403
404 return _default_pprint(obj, self, cycle)
c:\users\lukszam\env_artworks\lib\site-packages\IPython\lib\pretty.py in _repr_pprint(obj, p, cycle)
695 """A pprint that just redirects to the normal repr function."""
696 # Find newlines and replace them with p.break_()
--> 697 output = repr(obj)
698 for idx,output_line in enumerate(output.splitlines()):
699 if idx:
c:\users\lukszam\env_artworks\lib\site-packages\fastai\data_block.py in __repr__(self)
67 def __repr__(self)->str:
68 items = [self[i] for i in range(min(5,len(self.items)))]
---> 69 return f'{self.__class__.__name__} ({len(self.items)} items)\n{show_some(items)}\nPath: {self.path}'
70
71 def process(self, processor:PreProcessors=None):
c:\users\lukszam\env_artworks\lib\site-packages\fastai\core.py in show_some(items, n_max, sep)
341 "Return the representation of the first `n_max` elements in `items`."
342 if items is None or len(items) == 0: return ''
--> 343 res = sep.join([f'{o}' for o in items[:n_max]])
344 if len(items) > n_max: res += '...'
345 return res
c:\users\lukszam\env_artworks\lib\site-packages\fastai\core.py in <listcomp>(.0)
341 "Return the representation of the first `n_max` elements in `items`."
342 if items is None or len(items) == 0: return ''
--> 343 res = sep.join([f'{o}' for o in items[:n_max]])
344 if len(items) > n_max: res += '...'
345 return res
c:\users\lukszam\env_artworks\lib\site-packages\fastai\core.py in __repr__(self)
154 "Base item type in the fastai library."
155 def __init__(self, data:Any): self.data=self.obj=data
--> 156 def __repr__(self)->str: return f'{self.__class__.__name__} {str(self)}'
157 def show(self, ax:plt.Axes, **kwargs):
158 "Subclass this method if you want to customize the way this `ItemBase` is shown on `ax`."
... last 1 frames repeated, from the frame below ...
c:\users\lukszam\env_artworks\lib\site-packages\fastai\core.py in __repr__(self)
154 "Base item type in the fastai library."
155 def __init__(self, data:Any): self.data=self.obj=data
--> 156 def __repr__(self)->str: return f'{self.__class__.__name__} {str(self)}'
157 def show(self, ax:plt.Axes, **kwargs):
158 "Subclass this method if you want to customize the way this `ItemBase` is shown on `ax`."
RecursionError: maximum recursion depth exceeded while calling a Python object
Cheers,
Marcin