Custom Collate Function for Image Points

Hello,

I’m looking to write my own custom collate function for single image points. For example, each image in my dataset contains an object whose center is marked with a single red dot. Since the fastai library appears to have a collate function only for bounding boxes, I want to write my own. I’m trying to understand the source code for bb_pad_collate, but I’m having some trouble seeing how the whole function comes together and what each line does. Here is the source code for bb_pad_collate:

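from fastai.vision import *  # fastai v1: provides torch, Tuple, FloatTensor, LongTensor, BatchSamples

def bb_pad_collate(samples:BatchSamples, pad_idx:int=0) -> Tuple[FloatTensor, Tuple[LongTensor, LongTensor]]:
    "Function that collects `samples` of labelled bboxes and adds padding with `pad_idx`."
    # Largest number of boxes held by any sample; every row is padded to this length.
    max_len = max([len(s[1].data[1]) for s in samples])
    bboxes = torch.zeros(len(samples), max_len, 4)                  # padded boxes are all zeros
    labels = torch.zeros(len(samples), max_len).long() + pad_idx    # padded labels are pad_idx
    imgs = []
    for i,s in enumerate(samples):
        imgs.append(s[0].data[None])   # image tensor with a new leading batch dimension
        bbs, lbls = s[1].data          # (boxes [n,4], labels [n]) for this sample
        if len(lbls) > 0:              # guard: with n == 0, the slice -0: would select the whole row
            bboxes[i,-len(lbls):] = bbs    # right-align real boxes; padding stays at the front
            labels[i,-len(lbls):] = lbls
    return torch.cat(imgs,0), (bboxes,labels)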
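In short, the function above finds the largest number of boxes in the batch (max_len), allocates a zero-filled box tensor and a label tensor filled with pad_idx of that size, then copies each sample's boxes and labels into the end of its row, so samples with fewer boxes are padded at the front. The images are simply stacked along a new batch dimension. The sketch below reproduces that logic in plain PyTorch so it can be run without fastai. It assumes each sample is already a (image_tensor, (boxes, labels)) tuple of tensors rather than fastai Image/ImageBBox items, and the name pad_collate_boxes is hypothetical.

```python
import torch
from typing import List, Tuple

# Assumed sample layout: (image [C,H,W], (boxes [n,4], labels [n])), n varies per sample.
Sample = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

def pad_collate_boxes(samples: List[Sample], pad_idx: int = 0):
    """Hypothetical plain-PyTorch mirror of fastai v1's bb_pad_collate."""
    # The longest box list in the batch decides the padded length.
    max_len = max(len(lbls) for _, (_, lbls) in samples)
    bboxes = torch.zeros(len(samples), max_len, 4)
    labels = torch.full((len(samples), max_len), pad_idx, dtype=torch.long)
    imgs = []
    for i, (img, (bbs, lbls)) in enumerate(samples):
        imgs.append(img[None])  # add a batch dimension: [1,C,H,W]
        n = len(lbls)
        if n > 0:  # with n == 0 the slice -0: would mean the whole row
            # Right-align the real boxes; the front of the row stays as padding.
            bboxes[i, -n:] = bbs
            labels[i, -n:] = lbls
    return torch.cat(imgs, 0), (bboxes, labels)

samples = [
    (torch.rand(3, 8, 8), (torch.tensor([[0., 0., 1., 1.]]), torch.tensor([3]))),
    (torch.rand(3, 8, 8), (torch.tensor([[0., 0., .5, .5], [.2, .2, .9, .9]]), torch.tensor([1, 2]))),
]
x, (b, l) = pad_collate_boxes(samples)
print(x.shape, b.shape)  # torch.Size([2, 3, 8, 8]) torch.Size([2, 2, 4])
print(l.tolist())        # [[0, 3], [1, 2]]
```

The first sample has one box, so its row starts with one padding slot (label 0, box of zeros). Filling the label tensor with pad_idx lets downstream code, such as the loss, recognize and skip the padded slots.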

I’d appreciate it if anyone could give a brief summary of how this function works and, if possible, some steps for creating an image points collate function based on bb_pad_collate.

Thanks!
Wayde

Just curious, did you figure it out? 🙂 I was looking at the collate function and saw your post.

Sort of. For my collate function I just removed images where the point ended up outside the actual image so the model wouldn’t get confused. Here’s the function that I wrote:

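from fastai.vision import *  # fastai v1: provides torch, to_data, BatchSamples

def ip_pad_collate(samples:BatchSamples):
    "Collate image-point samples, dropping any sample whose point fell outside the image."
    samples = to_data(samples)  # unwrap fastai Image/ImagePoints items into raw tensors
    # A point outside the image leaves an empty points tensor (shape[0] == 0).
    # Filtering into a new list avoids deleting from a list while iterating over it.
    samples = [s for s in samples if s[1].shape[0] > 0]
    imgs = torch.stack([s[0] for s in samples])  # [B,C,H,W]
    pts = torch.stack([s[1] for s in samples])   # [B,1,2]; needs the same point count per sample
    return imgs, pts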

Basically, I go through every sample in the batch and check whether its image point ended up outside the image (in which case its points tensor is empty). If it did, I drop that sample from the batch. Here’s the working notebook of it in action.

Hope this helps!


Thanks! You’re awesome! 🙂

@waydegg I see that at one point you ran into an issue similar to mine: needing to modify the collate function to get the model working. Do you have any insight into how I could modify this function to allow for multiple points (or labels, in this case) per image? My problem is described in more detail here. I’m trying to build a model that identifies cars, and I nearly have it working, but the data block API crashes because my labels are of varying lengths (e.g. some images have 0 or 1 point and others have 10).
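One way to handle a varying number of points, borrowing the padding idea from bb_pad_collate, is to pad every image's points to the largest count in the batch and return a boolean mask marking which slots are real. The sketch below is a minimal plain-PyTorch version under stated assumptions: each sample is already an (image [C,H,W], points [n,2]) tuple of tensors (for fastai items you would call to_data first), n may be 0, and the name pad_collate_points is hypothetical. It only solves batching; the model still needs a fixed-size output and a loss that uses the mask.

```python
import torch
from typing import List, Tuple

def pad_collate_points(samples: List[Tuple[torch.Tensor, torch.Tensor]], pad_value: float = 0.0):
    """Hypothetical collate for a variable number of points per image.

    Each sample is (image [C,H,W], points [n,2]); n may be 0.
    Returns images [B,C,H,W] and (points [B,max_n,2], mask [B,max_n]).
    """
    max_len = max(len(pts) for _, pts in samples)
    points = torch.full((len(samples), max_len, 2), pad_value)
    mask = torch.zeros(len(samples), max_len, dtype=torch.bool)
    imgs = []
    for i, (img, pts) in enumerate(samples):
        imgs.append(img)
        n = len(pts)
        points[i, :n] = pts  # real points first, padding after (works for n == 0 too)
        mask[i, :n] = True   # True marks a real point, False a padded slot
    return torch.stack(imgs), (points, mask)

samples = [
    (torch.rand(3, 8, 8), torch.tensor([[0.1, 0.2]])),
    (torch.rand(3, 8, 8), torch.zeros(0, 2)),
    (torch.rand(3, 8, 8), torch.tensor([[0.3, 0.4], [-0.5, 0.6], [0.7, -0.8]])),
]
x, (pts, mask) = pad_collate_points(samples)
print(x.shape, pts.shape)    # torch.Size([3, 3, 8, 8]) torch.Size([3, 3, 2])
print(mask.sum(1).tolist())  # [1, 0, 3]

# Masked squared error: padded slots never contribute to the loss.
pred = torch.zeros_like(pts)  # stand-in for a model's output
loss = ((pred - pts) ** 2).sum(-1)[mask].mean()
```

Unlike bb_pad_collate, this puts the real points at the start of each row; with an explicit mask their position does not matter, and slicing with :n sidesteps the -0 pitfall for images with no points.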

Thanks