I’ve created a custom transform and am using it in a DataBlock. I’m able to visualize the data with `dls.show_batch(max_n=9, figsize=(8, 8))`, so I thought I’d also be able to use an `Interpretation` object to visualize the losses.
# this works
# Build the interpretation via the `from_learner` classmethod rather than the
# bare constructor, which takes more than just the learner.
# NOTE(review): confirm against the installed fastai version.
interp = Interpretation.from_learner(learn)
# Ask for the 10 largest losses.
interp.top_losses(10)
But I get an error when I call `plot_top_losses` on the interpretation object (named `i` in my session):
i.plot_top_losses(10)
---------------------------------------------------------------------------
Exception Traceback (most recent call last)
<ipython-input-25-16c9792f7fe0> in <module>
----> 1 i.plot_top_losses(10)
~/nbs/venv/lib/python3.8/site-packages/fastai/interpret.py in plot_top_losses(self, k, largest, **kwargs)
43 x1,y1,outs = self.dl._pre_show_batch(b_out, max_n=k)
44 if its is not None:
---> 45 plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), self.preds[idx], losses, **kwargs)
46 #TODO: figure out if this is needed
47 #its None means that a batch knows how to show itself as a whole, so we pass x, x1
~/nbs/venv/lib/python3.8/site-packages/fastcore/dispatch.py in __call__(self, *args, **kwargs)
116 elif self.inst is not None: f = MethodType(f, self.inst)
117 elif self.owner is not None: f = MethodType(f, self.owner)
--> 118 return f(*args, **kwargs)
119
120 def __get__(self, inst, owner):
~/nbs/venv/lib/python3.8/site-packages/fastai/interpret.py in plot_top_losses(x, y, *args, **kwargs)
12 @typedispatch
13 def plot_top_losses(x, y, *args, **kwargs):
---> 14 raise Exception(f"plot_top_losses is not implemented for {type(x)},{type(y)}")
15
16 # Cell
Exception: plot_top_losses is not implemented for <class 'fastai.torch_core.TensorImage'>,<class 'fastai.vision.core.TensorBBox'>
My guess is that I have to register a function somewhere so the type dispatcher picks it up, but that’s a guess because I haven’t yet dug deep into the type dispatch system.
Any ideas, or tips of what to look into?
Thanks!
#custom Transform
class NoLabelBBoxLabeler(Transform):
    """Bounding box labeler with no label.

    Decodes `TensorBBox` targets on their own, without any class labels
    attached to the boxes.
    """

    def setups(self, x):
        """Do nothing: this transform has no state to set up from `x`."""

    def decode(self, x, **kwargs):
        """Clear the cached box/labels, then pass `x` to `self._call('decodes', ...)`."""
        self.bbox, self.lbls = None, None
        return self._call('decodes', x, **kwargs)

    def decodes(self, x: TensorBBox):
        """Cache the boxes on `self.bbox` and return them.

        The boxes are wrapped in a `LabeledBBox` only when `self.lbls` is set.
        """
        self.bbox = x
        # `decode` resets `self.lbls` to None and nothing in this class assigns
        # it afterwards, so when reached through `decode` the plain boxes are
        # returned; the `LabeledBBox` branch only runs if `lbls` is set externally.
        return self.bbox if self.lbls is None else LabeledBBox(self.bbox, self.lbls)
#DataBlock
# Image input paired with the custom no-label bounding-box block.
block = DataBlock(
    blocks=(ImageBlock, NoLabelBBoxBlock),
    get_items=get_image_files,
    # Targets are produced by BBoxTruth built from `df`.
    get_y=[BBoxTruth(df)],
    # Only the first block is treated as model input.
    n_inp=1,
    item_tfms=[Resize(224)],
)
#learner
# resnet18 learner on `dls`, using a flattened MSE loss and reporting `iou`.
learn = cnn_learner(
    dls,
    resnet18,
    loss_func=MSELossFlat(),
    metrics=[iou],
)