Speeding Up fastai2 Inference - And A Few Things Learned

A little update to this: I found the bottleneck (duh). Batches are built on the fly, so doing dl = test_dl only sets up the pipeline.

I’ve gotten it down to just a hair under a second.

Here were my steps:

  1. Build the Pipeline manually. There was a chunk of overhead being done in the background that I could avoid. My particular problem used PILImage.create, Resize, ToTensor, IntToFloatTensor, and Normalize. As such I made a few pipelines (also notice the Normalize, which I couldn't quite get to work in the Pipeline for some reason):
type_pipe = Pipeline(PILImage.create)
item_pipe = Pipeline(Resize(224), ToTensor())
norm = Normalize.from_stats(*imagenet_stats)
i2f = IntToFloatTensor()
  2. Next comes actually applying it in batches:
batches, batch, k = [], [], 0
for im in im_names:
    batch.append(item_pipe(type_pipe(im)))
    k += 1
    if k == 50:
        batches.append(torch.cat([norm(i2f(b.cuda())) for b in batch]))
        batch = []
        k = 0

Now this single step right here is what shaved off the most time for me. The rest is the usual for predictions (so we just got rid of the dl = learner.dls.test_dl)
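For anyone whose image count is not an exact multiple of the batch size, here is a minimal sketch of the same chunking idea in plain Python. `chunked` is a hypothetical helper, not part of fastai2 or PyTorch, and the file names are made up. The one difference from the loop above is that it also yields the final partial batch instead of leaving it out.

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def chunked(items: Iterable[T], size: int) -> Iterator[List[T]]:
    """Yield lists of up to `size` items, including a final partial one."""
    if size < 1:
        raise ValueError("size must be at least 1")
    batch: List[T] = []
    for item in items:
        batch.append(item)
        if len(batch) == size:
            yield batch
            batch = []
    if batch:  # flush whatever is left after the last full batch
        yield batch

# Hypothetical file names standing in for im_names
names = [f"img_{i}.jpg" for i in range(120)]
sizes = [len(b) for b in chunked(names, 50)]
print(sizes)  # [50, 50, 20]
```

Each chunk could then go through the same item_pipe(type_pipe(im)) and torch.cat step as in the loop above.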

How much time did I shave? We went from our last time of 1.3 seconds down to 937ms for 100 images, so I was able to shave off even more time. I should also note that half of this time is just grabbing the data via PILImage.create :)
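If you want to reproduce timings like these, here is a rough sketch using only the standard library. `time_call` is a hypothetical helper, and the dummy workload is a stand-in for the real preprocessing or prediction function. One caveat, assuming the work runs on a GPU: CUDA calls are asynchronous, so you would want to call torch.cuda.synchronize() before reading the clock, otherwise the numbers can look better than they really are.

```python
import time
from typing import Callable, List

def time_call(fn: Callable[[], object], repeats: int = 5) -> List[float]:
    """Run `fn` `repeats` times and return each wall-clock duration in ms."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append((time.perf_counter() - start) * 1000.0)
    return durations

def dummy_workload() -> int:
    # Stand-in for the real work (e.g. building batches and running the model)
    return sum(i * i for i in range(100_000))

timings = time_call(dummy_workload, repeats=3)
print(f"best of {len(timings)}: {min(timings):.2f} ms")
```

Taking the best of a few runs helps smooth out one-off costs such as the first disk read or CUDA warm-up.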

Here’s the entire script:

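# Assumes fastai2 is installed and that `learner` and `im_names` already exist
from fastai2.vision.all import *

# Manual pipelines that replace learner.dls.test_dl
type_tfms = [PILImage.create]
item_tfms = [Resize(224), ToTensor()]
type_pipe = Pipeline(type_tfms)
item_pipe = Pipeline(item_tfms)
# Applied per item below, outside the Pipeline
norm = Normalize.from_stats(*imagenet_stats)
i2f = IntToFloatTensor()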


batches = []
batch = []
outs = []
inps = []
k = 0
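for im in im_names:
    batch.append(item_pipe(type_pipe(im)))
    k += 1
    if k == 50:
        # Move to the GPU, convert to float, normalize, and concatenate into one batch tensor
        batches.append(torch.cat([norm(i2f(b.cuda())) for b in batch]))
        batch = []
        k = 0
# Note: images left over after the last full batch of 50 are not added to batches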

learner.model.eval()
with torch.no_grad():
    for b in batches:
        outs.append(learner.model(b))
        inps.append(b)

inp = torch.stack(inps)
out = torch.stack(outs)
dec = learner.dls.decode_batch((*tuplify(inp), *tuplify(out)))

(PS if this needs further explanation let me know, these are just the mad ramblings of a man at 2am…)

(PPS, on a resnet18 you can get close to real time with this too on a GPU: with 3 images I clocked it at ~39ms)
