So today I went through and looked at how to do inference efficiently. The first thing you’ll notice is that there is a decent gain in speed to be had here. The main reason, I believe, is that fastai is doing a lot on the back end to make its various functions and utilities easy to use, as efficiently as it can, and that overhead is where some of the speed goes. (Also, if I am wrong here, Jeremy or Sylvain, please don’t hesitate to call me out!) With that out of the way, let’s get into a few bits:
Trying to Speed Up with Jit Scripting
Something done often in Kaggle (thanks @DrHB!) to speed up inference is saving the model away as a jit script. I found that this made no noticeable change in time (not even 0.01 ms).
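For reference, that looks something like the following. This is only a sketch, assuming torch.jit.trace and a 224x224 RGB input; the shape and file name are placeholders you would adjust to your own model:

import torch

learner.model.eval()
# hypothetical example input; the shape depends on your transforms (224x224 RGB assumed here)
example = torch.randn(1, 3, 224, 224).cuda()
with torch.no_grad():
    jit_model = torch.jit.trace(learner.model, example)
jit_model.save('model_jit.pt')

# later, load it back and use it in place of learner.model
jit_model = torch.jit.load('model_jit.pt')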
Doing fastai2 efficiently
Now let’s get into the juicy stuff. Let’s presume the following pipeline:
- Make a test_dl (which for the record is super efficient, on the order of microseconds! For 1,000 images I clocked it in at 669 microseconds; see the quick timing cell just below this list)
- Run predictions
- Decode your predictions (such as if we have keypoints, get back the actual values)
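For reference, that test_dl number comes from a notebook timing magic; a cell like this (assuming imgs is your list of image files) is all it takes:

%%timeit
dl = learner.dls.test_dl(imgs)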
We’re then going to look at three different bits of code for this: first, using fastai regularly, then two other ways of writing similar code. (Again, this is not to bash the library; this is an extreme edge case for people who need to get predictions as fast as possible!)
- Note: baselines were taken on 100 images with a batch size of 50 (I also looked at 64; it wasn’t as fast)
fastai straight
Let’s presume we have the following prediction script, which is pretty standard for fastai:
dl = learner.dls.test_dl(imgs)
# get_preds returns the inputs, raw predictions, targets (unused for a test_dl), and decoded predictions
inp, preds, _, dec_preds = learner.get_preds(dl=dl, with_input=True, with_decoded=True)
full_dec = learner.dls.decode_batch((*tuplify(inp), *tuplify(dec_preds)))
Timing this with %%time, it takes approximately 1.98 seconds in total, which is about 0.02 seconds per image. If for many this is fast enough, that’s fine; I’ll be trying to make it faster.
Getting rid of get_preds
The next bit we’ll try is getting rid of get_preds. Here’s what this new code involves:
dec_batches = []
dl = learn.dls.test_dl(imgs)
learn.model.eval()
with torch.no_grad():
    for batch in dl:
        # Call the model directly on the batch of inputs
        res = learn.model(batch[0])
        inp = batch[0]
        # Decode each batch as it comes out
        dec_batches.append(learn.dls.decode_batch((*tuplify(inp), *tuplify(res))))
Notice specifically that we’re calling learn.model directly, grabbing the batches from the dataloader in succession, and decoding each batch as it comes out. This brings our time down to 1.7 seconds, so we shaved off almost 0.3 seconds.
A better way
We can shave off even more time: instead of decoding each batch as it comes out, we combine them all into one big batch and decode it at the end. Remember, since we’re not sending this combined batch to the model, we’re okay combining it after the fact!
What does this look like? Something like so:
outs = []
inps = []
dl = learner.dls.test_dl(imgs)
learner.model.eval()
with torch.no_grad():
    for batch in dl:
        # Store the raw model outputs and inputs; no decoding inside the loop
        outs.append(learner.model(batch[0]))
        inps.append(batch[0])
# Combine everything and decode in one go at the end
outs = torch.stack(outs)
inps = torch.stack(inps)
dec = learner.dls.decode_batch((*tuplify(inps), *tuplify(outs)))
So we can see we combine all the batches via torch.stack and pass the result into decode_batch. But what does this time gain look like? It comes out to 1.3 seconds, so we decreased the total time by roughly a third!
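To make the shapes concrete, here’s a toy sketch with made-up sizes (not part of the pipeline above): torch.stack keeps each batch separate along a new leading dimension (and requires every batch to be the same size), while torch.cat would give you one flat batch if that’s what you need instead.

import torch

outs = [torch.randn(50, 10), torch.randn(50, 10)]  # two fake batches of 50 predictions

stacked = torch.stack(outs)  # shape [2, 50, 10]: batches kept separate
flat = torch.cat(outs)       # shape [100, 10]: one flat batch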
There are probably ways to decrease this further. Remember, though, that this is an advanced idea in the sense that you’re walking away from purely fastai code and instead combining the library with other code to take it further.
Hope this helps
Also, this was just a find after a quick day; as I mentioned, there are probably better ways of doing this, so if you find them, post them in this thread!
Edit:
The new best way: skip test_dl entirely, build the batches by hand with Pipelines of the type and item transforms, and decode everything once at the end:
# Build the type and item transform pipelines by hand instead of using a DataLoader
type_tfms = [PILImage.create]
item_tfms = [Resize(224), ToTensor()]
type_pipe = Pipeline(type_tfms)
item_pipe = Pipeline(item_tfms)
norm = Normalize.from_stats(*imagenet_stats)
i2f = IntToFloatTensor()
batches = []
batch = []
outs = []
inps = []
k = 0
for im in im_names:
    # Open and resize each image, then convert it to a tensor
    batch.append(item_pipe(type_pipe(im)))
    k += 1
    if k == 50:
        # Move each item to the GPU, scale to floats, normalize, then combine into a batch
        batches.append(torch.cat([norm(i2f(b.cuda())) for b in batch]))
        batch = []
        k = 0
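# (Extra check, an addition not included in the timings above: if len(im_names) isn't
#  an exact multiple of 50, the loop above leaves a final partial batch behind.)
if len(batch) > 0:
    batches.append(torch.cat([norm(i2f(b.cuda())) for b in batch]))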
learner.model.eval()
with torch.no_grad():
    for b in batches:
        outs.append(learner.model(b))
        inps.append(b)
inp = torch.stack(inps)
out = torch.stack(outs)
dec = learner.dls.decode_batch((*tuplify(inp), *tuplify(out)))