I wrote a transform block for a single-class object detection task. Since there’s only one class, the bounding boxes don’t need any class labels. I couldn’t get this working out of the box with fastai’s application-level API, but once I learned the transform system in more depth, I was able to figure it out. Sharing it here in case it’s useful to someone else:
from fastcore.transform import *
from fastai.data.all import *
from fastai.vision.all import *
from fastai.test_utils import synth_learner
import pandas as pd
from pathlib import Path
from nbdev import show_doc
import os
from chessocr import *
#export
class NoLabelBBoxLabeler(Transform):
    """Bounding-box labeler for single-class detection, where boxes carry no label.

    fastai's labeled-box pipeline pairs decoded boxes with decoded class labels
    into a ``LabeledBBox``. With only one class there are no labels to pair, so
    decoding a ``TensorBBox`` hands the boxes back unchanged. No per-call state
    is kept, so ``decodes`` is safe to call directly as well as via the
    ``decode`` inherited from ``Transform``.
    """

    def setups(self, x):
        # Nothing to learn from the training items.
        pass

    def decodes(self, x: TensorBBox):
        # No labels exist to combine with the boxes: return them as-is.
        return x
#export
class BBoxTruth:
    """Get an image's ground-truth bounding box from a DataFrame.

    The image file ``o`` is mapped to row ``int(o.stem) - 1`` of ``df`` *by
    position* (``iloc``). File stems must therefore be 1-based integers that
    match the row order. NOTE(review): confirm this against the CSV; ``df`` is
    read with ``index_col=0``, so ``df.loc[int(o.stem)]`` may be the intended
    lookup.

    The first three columns of that row are read as ``size, x, y``. Any further
    columns are ignored.

    Returns a list containing one box ``[x, y, x + size, y + size]``: a square
    of side ``size`` starting at ``(x, y)``.
    """

    def __init__(self, df):
        self.df = df

    def __call__(self, o):
        row = self.df.iloc[int(o.stem) - 1]
        # Only the leading three columns matter. Slicing (instead of unpacking
        # the whole row) keeps this working if the CSV gains or loses
        # trailing columns.
        size, x, y = row.iloc[:3]
        return [[x, y, x + size, y + size]]
#export
def iou(pred, target):
    """Vectorized mean Intersection-over-Union between predicted and target boxes.

    Boxes are ``(x1, y1, x2, y2)``. ``pred`` is presumably ``(bs, 4)`` and
    ``target`` ``(bs, 1, 4)`` (one box per image); dimension 1 of ``target`` is
    squeezed before comparison. Both must be torch tensors; they are detached
    and moved to the CPU.

    Edge cases:
    - A box with ``x2 < x1`` or ``y2 < y1`` (e.g. an inverted prediction) is
      treated as having zero area.
    - A pair whose union is zero (both boxes empty) scores 0 instead of NaN.

    Returns the IoU averaged over the batch, as a NumPy scalar.
    """
    # detach() so tensors that still require grad can be converted to NumPy.
    target = np.asarray(target.detach().cpu().squeeze(1))
    pred = np.asarray(pred.detach().cpu())
    boxes = np.stack([pred, target])            # (2, bs, 4)
    lo, hi = boxes[..., 0:2], boxes[..., 2:4]   # (x1, y1) and (x2, y2) corners
    # Overlap extent along each axis, clipped at 0 when the boxes don't overlap.
    inter = np.clip(hi.min(axis=0) - lo.max(axis=0), 0, None).prod(axis=-1)
    # Clip side lengths so an inverted box cannot contribute negative area.
    areas = np.clip(hi - lo, 0, None).prod(axis=-1)  # (2, bs)
    union = areas.sum(axis=0) - inter
    # union is 0 only when both boxes are empty (then inter is 0 too):
    # divide by inf so the pair scores 0 rather than NaN.
    union = np.where(union > 0, union, np.inf)
    return (inter / union).mean()
#export
# Target block for single-class detection: items become TensorBBox via
# TensorBBox.create, then pass through PointScaler and NoLabelBBoxLabeler
# (which decodes boxes without attaching any class label).
NoLabelBBoxBlock = TransformBlock(
    type_tfms=TensorBBox.create,
    item_tfms=[
        PointScaler,
        NoLabelBBoxLabeler,
    ],
)
# Build DataLoaders for the chess dataset: each image is paired with one
# unlabeled bounding box, looked up in annotations.csv by BBoxTruth.
data_url = Path.home() / ".fastai/data/chess"
df = pd.read_csv(data_url / 'annotations.csv', index_col=0)

block = DataBlock(
    blocks=(ImageBlock, NoLabelBBoxBlock),
    get_items=get_image_files,
    get_y=[BBoxTruth(df)],      # one getter per target block
    n_inp=1,                    # only the image is an input
    item_tfms=[Resize(224)],
)

dls = block.dataloaders(data_url, batch_size=64)
# Visual sanity check: a 3x3 grid of images with their boxes.
dls.show_batch(max_n=9, figsize=(8, 8))