Image classifier that takes bounding boxes and classifies each one

Hello,

I’ve used fast.ai to train an image classifier, and it works really well. Thanks for this great software!

My question is about running the image classifier more efficiently.

My application has a two-stage process: first an object detector (YOLO) finds all of the objects in the image, and then a classifier runs on each detected object. The classifier is given a subset of the original image, cropped to include only the detected object.

(You may ask: why doesn’t the object detector do the classification? It’s because my experimentation has found the system performs better if object detection and object classification use separate models.)

At the moment, I’m doing classification serially: I simply loop over all the detected objects, pass a cropped image to fast.ai, and run the classification inference. The problem is, this is horribly inefficient, especially for images with many detected objects. I’m wondering whether I can classify all detected objects in a single inference.

So my question is: how can I design a fast.ai image classifier so that it takes an input image, plus a list of bounding boxes, and returns the classification of each bounding box — in a single model?
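One way to avoid the per-object loop, sketched here under assumptions rather than as an established fast.ai recipe, is to crop every box from the image tensor, resize the crops to a common size, stack them into one batch, and run a single forward pass. The sketch uses plain PyTorch. The helper names `crop_and_batch` and `classify_boxes`, the 224x224 crop size, and the tiny stand-in model are all illustrative. With a trained fast.ai Learner, the underlying module would be `learn.model`, and the same normalization used during training would have to be applied to the batch first.

```python
import torch
import torch.nn.functional as F


def crop_and_batch(image, boxes, size=(224, 224)):
    """Crop each (x1, y1, x2, y2) pixel box from a float C x H x W image tensor,
    resize every crop to `size`, and stack the results into an N x C x h x w batch."""
    crops = []
    for x1, y1, x2, y2 in boxes:
        crop = image[:, y1:y2, x1:x2].unsqueeze(0)  # shape: 1 x C x h x w
        crops.append(
            F.interpolate(crop, size=size, mode="bilinear", align_corners=False)
        )
    return torch.cat(crops, dim=0)


@torch.no_grad()
def classify_boxes(model, image, boxes, size=(224, 224)):
    """Classify all boxes with one forward pass instead of one pass per box."""
    model.eval()
    batch = crop_and_batch(image, boxes, size)
    logits = model(batch)        # shape: N x num_classes
    return logits.argmax(dim=1)  # one predicted class index per box


# Stand-in classifier (assumption): 3-channel input, 5 classes.
model = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(3, 5),
)

image = torch.rand(3, 480, 640)                     # fake RGB image, values in [0, 1]
boxes = [(10, 20, 110, 220), (300, 100, 500, 400)]  # hypothetical detector output
preds = classify_boxes(model, image, boxes)
print(preds.shape)
```

Within fast.ai itself, passing the list of crops to `learn.dls.test_dl(...)` and then calling `learn.get_preds(dl=...)` should give a similar batched result while reusing the training-time transforms. A true single-model design that pools box features from a shared feature map (for example with `torchvision.ops.roi_align`) is also possible, but it means retraining the classifier on pooled features rather than on image crops.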

Surely this isn’t an uncommon situation, right?

Thanks!
Adrian

Hello,
I’m working on a very similar project. I tried playing around with the DataBlock API but didn’t really achieve anything. My experience is that the type_tfms of the ImageBlock ‘force’ you to have only the image file name as input, which makes the second input (the bounding boxes) useless.

I then followed this guide (fastai - Mid-tier data API - Pets) to create the dataloader using only TfmdLists, which lets you customize the data entry point. My input is a zip of three lists:

  • image filename
  • bounding boxes associated
  • label of the image (in my case I want the label of the whole image to be used as the label of each image crop; this is specific to my project and may not apply to yours)
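As a rough illustration of that entry point, the three lists can be flattened into one item per bounding box before they reach the transforms, so that each crop carries its parent image’s label. This is a minimal plain-Python sketch, not the actual pipeline from the project; `expand_to_crops`, the file names, and the labels are hypothetical.

```python
from typing import List, Sequence, Tuple

Box = Tuple[int, int, int, int]  # (x1, y1, x2, y2) in pixels


def expand_to_crops(
    filenames: Sequence[str],
    boxes_per_image: Sequence[Sequence[Box]],
    image_labels: Sequence[str],
) -> List[Tuple[str, Box, str]]:
    """Turn per-image records into per-box records.

    Every box inherits the label of the image it came from, so each
    returned item is (filename, box, label) and describes exactly one crop.
    """
    items = []
    for fname, boxes, label in zip(filenames, boxes_per_image, image_labels):
        for box in boxes:
            items.append((fname, tuple(box), label))
    return items


# Hypothetical data: two images, three boxes in total.
filenames = ["img_001.jpg", "img_002.jpg"]
boxes_per_image = [[(0, 0, 50, 50), (10, 20, 80, 90)], [(5, 5, 40, 60)]]
image_labels = ["cat", "dog"]

items = expand_to_crops(filenames, boxes_per_image, image_labels)
for item in items:
    print(item)
```

Each resulting item could then be handed to a transform that opens the file and crops the box, and a second transform that encodes the label. Images with no boxes contribute no items, which is worth checking if a split ends up unexpectedly small.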

The dataloader looks like it’s working, with show_batch() giving the right output. However, when I try to feed it to a Learner, the valid_loss and metric_error both come back as None. I currently have no clue where this comes from.

Please let me know if you have found a simpler solution and/or if you have encountered a similar issue with the Learner returning None values for the validation set.