DeviceDataLoader is slow for the first iteration

I have a PyTorch Dataset train_ds consisting of 161 items. Each item contains two float32 tensors of shapes [4, 512, 512] and [3, 1024, 1024].

If I pass that Dataset into a vanilla PyTorch DataLoader:

train_dl = DataLoader(train_ds)

and iterate through the batches (bs = 1):

%%timeit
for x, y in iter(train_dl):
    pass

It takes about 8.9 seconds.

However, if I pass my dataset train_ds to a fastai databunch and learner:

databunch = DataBunch.create(train_ds, val_ds, device=device, bs=1)
learn = Learner(databunch, net, callback_fns=[ShowGraph], loss_func=criterion)

Then iterate through the learner’s train_dl:

%%time
for x, y in iter(learn.data.train_dl):
    pass

This takes 22 seconds.

Notably, the first iteration of that loop takes most of the time:

%%time
for x, y in iter(learn.data.train_dl):
    break

This takes 20 seconds.

My question is: what happens on the first iteration of the dataloader that is so slow? I tried looking at the DeviceDataLoader source code but couldn’t figure it out.

The reason I’m bringing this up is because my fastai training loops have that same 20 second gap at the beginning of each epoch (and something similar before the validation part at the end of the epoch).

Thanks for your help!

It’s because that’s the first time you put a batch on the GPU, which is also when PyTorch performs one-time setup such as initializing the CUDA context, and that takes some time.
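As a minimal sketch of how to check this yourself (the helper below is hypothetical, not a fastai or PyTorch API; it assumes torch is installed and returns None when CUDA is unavailable), you can time the first host-to-GPU copy against a later one:

```python
import time

def time_first_vs_later_transfer():
    """Time the first host-to-GPU copy (which triggers one-time CUDA
    context setup) against a later copy of the same tensor.
    Returns (first_s, later_s), or None if torch/CUDA is unavailable."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    x = torch.zeros(4, 512, 512)  # same shape as one of the dataset tensors
    t0 = time.perf_counter()
    x.cuda()
    torch.cuda.synchronize()  # wait for the copy (and any context init) to finish
    first = time.perf_counter() - t0
    t0 = time.perf_counter()
    x.cuda()
    torch.cuda.synchronize()
    later = time.perf_counter() - t0
    return first, later
```

If CUDA context initialization is the culprit, the first measurement should be much larger than the second. Note this explains a one-time cost at process startup, not a cost repeated every epoch.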


Thanks for your reply. I wonder if there is something more to it, because it happens at the beginning of each epoch.
I wonder if there’s a process that iterates through the entire dataloader when loading the first batch.

I’ll try to dig more into it.

I think I’ve figured it out. I was using the default (16) for num_workers for DataBunch.create() which turned out to be suboptimal.
By reducing it to 3, I divided the duration of an epoch by 2.

On top of that, it may be worth disabling worker processes entirely by setting num_workers=0. I am doing transfer learning to build a classifier for a small text dataset, and the latency between epochs went from 1 min with 8 workers to 2 s with num_workers=0.
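The per-epoch gap is consistent with worker startup cost: the DataLoader re-creates its worker processes at the start of each epoch, so more workers can mean more startup latency before the first batch arrives. A stdlib-only sketch of that effect (pool_startup_time is a hypothetical helper, not a DataLoader API):

```python
import multiprocessing as mp
import time

def _noop(_):
    return 0

def pool_startup_time(num_workers):
    """Rough cost of spinning up num_workers worker processes and running a
    trivial job, analogous to what a DataLoader pays at the start of each
    epoch when its workers are re-created."""
    t0 = time.perf_counter()
    if num_workers == 0:
        list(map(_noop, range(4)))  # run in the main process, no workers
    else:
        with mp.Pool(num_workers) as pool:
            pool.map(_noop, range(4))  # process startup dominates the cost
    return time.perf_counter() - t0
```

In newer PyTorch versions, passing persistent_workers=True to DataLoader keeps the workers alive between epochs, which can remove this per-epoch cost while still using multiple workers.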