I’m looking to write my own custom collate function for single image points. For example, each image in my dataset has an object whose center is marked with a single red dot. The fastai library appears to only have a collate function for bounding boxes, which is why I want to write my own. I’m trying to understand the source code for bb_pad_collate, but I’m having trouble seeing how the whole function comes together and what each line does. Here is the source code for bb_pad_collate:
def bb_pad_collate(samples:BatchSamples, pad_idx:int=0) -> Tuple[FloatTensor, Tuple[LongTensor, LongTensor]]:
    "Function that collect `samples` of labelled bboxes and adds padding with `pad_idx`."
    max_len = max([len(s[1].data[1]) for s in samples])
    bboxes = torch.zeros(len(samples), max_len, 4)
    labels = torch.zeros(len(samples), max_len).long() + pad_idx
    imgs = []
    for i,s in enumerate(samples):
        imgs.append(s[0].data[None])
        bbs, lbls = s[1].data
        bboxes[i,-len(lbls):] = bbs
        labels[i,-len(lbls):] = lbls
    return torch.cat(imgs,0), (bboxes,labels)
I’d appreciate it if anyone could give a brief summary of how this function works, and if possible some steps to take for creating an image points collate function based off of bb_pad_collate.
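Reading the source line by line: max_len is the largest number of labels any sample in the batch has; bboxes and labels are allocated as (batch, max_len, 4) and (batch, max_len) tensors, with labels pre-filled with pad_idx; inside the loop each image gets a leading batch axis via [None], and each sample's boxes and labels are written into the last len(lbls) slots of its row, so the padding sits at the front; finally the images are concatenated along dimension 0. The padding step can be reproduced without fastai or PyTorch. This is only a sketch: nested lists stand in for tensors, the image part is left out, and pad_boxes and the sample data are made up for illustration.

```python
# Plain-Python stand-in for the padding step of bb_pad_collate.
# Nested lists play the role of tensors; the sample data is made up.
def pad_boxes(samples, pad_idx=0):
    """Right-align each sample's boxes and labels in fixed-size rows."""
    max_len = max(len(lbls) for _, lbls in samples)
    bboxes = [[[0.0] * 4 for _ in range(max_len)] for _ in samples]
    labels = [[pad_idx] * max_len for _ in samples]
    for i, (bbs, lbls) in enumerate(samples):
        n = len(lbls)
        if n:  # guard: a slice starting at -0 would cover the whole row
            bboxes[i][max_len - n:] = [list(b) for b in bbs]
            labels[i][max_len - n:] = list(lbls)
    return bboxes, labels

samples = [
    ([[0.1, 0.1, 0.5, 0.5], [0.2, 0.3, 0.9, 0.8]], [1, 2]),  # two boxes
    ([[0.0, 0.0, 1.0, 1.0]], [3]),                            # one box
]
bboxes, labels = pad_boxes(samples)
print(labels)     # [[1, 2], [0, 3]]
print(bboxes[1])  # [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]]
```

For image points, the same pattern should carry over with 2 coordinates per row instead of 4, and without the labels tensor if every point belongs to the same class.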
Sort of. For my collate function I just removed images where the point ended up outside the actual image so the model wouldn’t get confused. Here’s the function that I wrote:
def ip_pad_collate(samples:BatchSamples):
    "Collate image-point `samples`, dropping any sample whose point tensor is empty."
    samples = to_data(samples)
    # Build a new list instead of deleting while iterating, which can skip
    # the element that follows each deleted one.
    samples = [s for s in samples if s[1].shape[0] != 0]
    return tensor([s[0].numpy() for s in samples]), tensor([s[1].numpy() for s in samples])
Basically, I look through every sample in the batch and check whether its point tensor is empty, which is what happens when the image point ends up outside the image. If it is, I delete that sample from the batch. Here’s the working notebook of it in action.
@waydegg I see that at one point you ran into an issue similar to mine: needing to modify the collate function to get your model working. Would you have any insight into how I can modify this function to allow for multiple points (or labels, in this instance) per image? My problem is described in more detail here. I’m trying to build a model that identifies cars and nearly have it working, but the data block API is crashing because my labels are of varying lengths (i.e. some images have 0 or 1 points and others have 10 points, etc.).
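One possible direction for varying label lengths is to pad every image's points to the longest count in the batch, the way bb_pad_collate does for boxes, and return a mask marking which slots hold real points. The sketch below uses plain PyTorch rather than fastai item types, and it rests on assumptions: points_pad_collate, pad_value and the mask are hypothetical names that are not part of fastai, each sample is assumed to be an (image, points) pair of tensors, and all images are assumed to share one size.

```python
import torch

def points_pad_collate(samples, pad_value=0.0):
    """Hypothetical collate for a variable number of points per image.

    Each sample is assumed to be (image, points): image is a (C, H, W) tensor
    of a shared size, points is an (n, 2) tensor, and n may be 0.
    """
    # Keep at least one slot so a batch with no points still has a valid shape.
    max_len = max(max(len(pts) for _, pts in samples), 1)
    points = torch.full((len(samples), max_len, 2), pad_value)
    mask = torch.zeros(len(samples), max_len, dtype=torch.bool)
    imgs = []
    for i, (img, pts) in enumerate(samples):
        imgs.append(img)
        n = len(pts)
        if n:  # skip empty samples; a slice starting at -0 would span the row
            points[i, max_len - n:] = pts   # right-aligned, as in bb_pad_collate
            mask[i, max_len - n:] = True    # True marks a real point
    return torch.stack(imgs), (points, mask)

batch = [
    (torch.zeros(3, 8, 8), torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])),
    (torch.zeros(3, 8, 8), torch.empty(0, 2)),            # image with no points
    (torch.zeros(3, 8, 8), torch.tensor([[-0.5, 0.5]])),
]
imgs, (points, mask) = points_pad_collate(batch)
print(imgs.shape, points.shape)
print(mask)
```

With this layout the loss function would have to ignore the padded slots, for example by using the mask. Wiring the function into a DataBunch (for instance through a collate_fn argument) is left out here and would need checking against your fastai version.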