Visualize top losses with custom Transform

I’ve created a custom transform and am using it in a DataBlock. I can visualize the data with dls.show_batch(max_n=9, figsize=(8, 8)), so I expected to be able to use an Interpretation object to visualize the top losses as well.

# this works
interp = Interpretation(learn)
interp.top_losses(10)

But calling plot_top_losses raises an error:

i.plot_top_losses(10)

---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
<ipython-input-25-16c9792f7fe0> in <module>
----> 1 i.plot_top_losses(10)

~/nbs/venv/lib/python3.8/site-packages/fastai/interpret.py in plot_top_losses(self, k, largest, **kwargs)
     43         x1,y1,outs = self.dl._pre_show_batch(b_out, max_n=k)
     44         if its is not None:
---> 45             plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), self.preds[idx], losses,  **kwargs)
     46         #TODO: figure out if this is needed
     47         #its None means that a batch knows how to show itself as a whole, so we pass x, x1

~/nbs/venv/lib/python3.8/site-packages/fastcore/dispatch.py in __call__(self, *args, **kwargs)
    116         elif self.inst is not None: f = MethodType(f, self.inst)
    117         elif self.owner is not None: f = MethodType(f, self.owner)
--> 118         return f(*args, **kwargs)
    119 
    120     def __get__(self, inst, owner):

~/nbs/venv/lib/python3.8/site-packages/fastai/interpret.py in plot_top_losses(x, y, *args, **kwargs)
     12 @typedispatch
     13 def plot_top_losses(x, y, *args, **kwargs):
---> 14     raise Exception(f"plot_top_losses is not implemented for {type(x)},{type(y)}")
     15 
     16 # Cell

Exception: plot_top_losses is not implemented for <class 'fastai.torch_core.TensorImage'>,<class 'fastai.vision.core.TensorBBox'>

My guess is that I need to register a function somewhere so the type dispatcher picks it up, but I haven’t dug into the type dispatch system yet.

Any ideas, or tips on what to look into?

Thanks!

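from fastai.vision.all import *

#custom Transform
class NoLabelBBoxLabeler(Transform):
    """ Bounding box labeler with no label """
    def setups(self, x): noop  # nothing to set up
    def decode(self, x, **kwargs):
        # reset state, then dispatch to the type-specific `decodes` below
        self.bbox,self.lbls = None,None
        return self._call('decodes', x, **kwargs)

    def decodes(self, x:TensorBBox):
        # lbls is never set in this class, so this returns the plain TensorBBox
        self.bbox = x
        return self.bbox if self.lbls is None else LabeledBBox(self.bbox, self.lbls)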
#DataBlock (NoLabelBBoxBlock, BBoxTruth and df are defined elsewhere, not shown)
block = DataBlock(
    blocks=(ImageBlock, NoLabelBBoxBlock),
    get_items=get_image_files,
    get_y=[BBoxTruth(df)],
    n_inp=1,
    item_tfms=[Resize(224)])
#learner (dls is presumably built with block.dataloaders; dls and iou are not shown)
learn = cnn_learner(dls, resnet18,
                    metrics=[iou],
                    loss_func=MSELossFlat())

I figured it out after reading the documentation on type dispatch. I needed to register a new plot_top_losses implementation for the specific types I was using (TensorImage, TensorBBox). Interpretation.plot_top_losses then dispatches to that function, which receives everything it needs to create the plot.

In other news, reading the documentation really helps, especially for fastcore.

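@typedispatch
def plot_top_losses(x: TensorImage, y: TensorBBox, samples, outs, raws, losses, **kwargs):
    # picked by the dispatcher for (TensorImage, TensorBBox) batches;
    # draw the images with their target/predicted boxes here
    pass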
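The registered function sits at the end of a longer pipeline. Roughly, Interpretation.plot_top_losses takes the k items with the largest losses (in fastai, top_losses is essentially a topk over the stored per-item losses), gathers their inputs, targets, decoded predictions and raw predictions, and passes all of that to the type-dispatched function, in the argument order visible in the traceback. The stdlib-only sketch below mimics that flow with plain lists; top_k_losses, call_with_top_losses and show_titles are hypothetical helpers, and in fastai x and y are batched tensors and samples/outs are decoded items, not strings.

```python
def top_k_losses(losses, k, largest=True):
    """Return (values, indices) of the k largest (or smallest) losses, like a topk."""
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=largest)[:k]
    return [losses[i] for i in order], order


def call_with_top_losses(handler, k, inputs, targets, decoded, preds, losses, **kwargs):
    # Gather everything for the k worst items, then hand it to the handler
    # using the same argument order as plot_top_losses(x, y, samples, outs, raws, losses).
    vals, idx = top_k_losses(losses, k)
    x = [inputs[i] for i in idx]       # inputs of the worst items
    y = [targets[i] for i in idx]      # their ground-truth targets
    samples = list(zip(x, y))          # one (input, target) pair per item
    outs = [decoded[i] for i in idx]   # decoded predictions
    raws = [preds[i] for i in idx]     # raw model outputs
    return handler(x, y, samples, outs, raws, vals, **kwargs)


def show_titles(x, y, samples, outs, raws, losses, **kwargs):
    # A real handler would draw each image with its target and predicted boxes;
    # this one only builds the per-item titles.
    return [f"{name}: loss {loss:.2f}" for (name, _), loss in zip(samples, losses)]


titles = call_with_top_losses(
    show_titles, 2,
    inputs=["img0", "img1", "img2"],
    targets=["box0", "box1", "box2"],
    decoded=["pred0", "pred1", "pred2"],
    preds=[0.1, 0.9, 0.4],
    losses=[0.05, 0.80, 0.30],
)
print(titles)  # ['img1: loss 0.80', 'img2: loss 0.30']
```

Keeping the selection step separate from the handler mirrors why only the last function needs to be registered per type pair: the gathering logic is the same for every data type, while the drawing depends on what x and y are.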