Ok, I’ve written a custom training loop that looks like this:
```python
def train(model, opt, phases, callbacks, epochs, device, loss_fn):
    model.to(device)
    cb = callbacks
    cb.training_started(phases=phases, optimizer=opt)

    for epoch in range(1, epochs + 1):
        cb.epoch_started(epoch=epoch)

        for phase in phases:
            n = len(phase.loader)
            cb.phase_started(phase=phase, total_batches=n)
            is_training = phase.grad
            model.train(is_training)

            for batch in phase.loader:
                phase.batch_index += 1
                cb.batch_started(phase=phase, total_batches=n)
                x, y = place_and_unwrap(batch, device)

                with torch.set_grad_enabled(is_training):
                    cb.before_forward_pass()
                    out = model(x)
                    cb.after_forward_pass()
                    loss = loss_fn(out, y)

                if is_training:
                    opt.zero_grad()
                    cb.before_backward_pass()
                    loss.backward()
                    cb.after_backward_pass()
                    opt.step()

                phase.batch_loss = loss.item()
                cb.batch_ended(phase=phase, output=out, target=y)

            cb.phase_ended(phase=phase)

        cb.epoch_ended(phases=phases, epoch=epoch)

    cb.training_ended(phases=phases)
```
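The loop relies on a `Phase` object, a callbacks container that broadcasts events, and a `place_and_unwrap` helper that are not shown above. The following is only a minimal sketch of what I mean by them; any fields or behaviour beyond what the loop itself uses are assumptions:

```python
from dataclasses import dataclass
from typing import Optional

from torch.utils.data import DataLoader


@dataclass
class Phase:
    """One pass over a loader; grad=True for training, False for validation."""
    name: str
    loader: DataLoader
    grad: bool = True
    batch_index: int = 0                 # incremented by the loop above
    batch_loss: Optional[float] = None   # last batch loss, read by callbacks


class CallbacksGroup:
    """Broadcasts every event (epoch_started, batch_ended, ...) to each callback."""

    def __init__(self, callbacks):
        self.callbacks = callbacks

    def __getattr__(self, event):
        def dispatch(**kwargs):
            for cb in self.callbacks:
                handler = getattr(cb, event, None)
                if handler is not None:
                    handler(**kwargs)
        return dispatch


def place_and_unwrap(batch, device):
    # Assumes a batch is an (inputs, targets) pair of tensors.
    x, y = batch
    return x.to(device), y.to(device)
```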
A couple of points about the implementation:
- Quick Draw Dataset
- No `fastai` dependencies
- No `Path` objects stored in memory (reading data directly from a `pd.DataFrame` and rendering images on the fly; see the sketch after this list)
- Direct usage of `torch.DataLoader` classes; the transformations are taken from `torchvision`
- `num_workers=12`
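To make the "render on the fly" point concrete, the idea is roughly the following. The column names (`drawing`, `label`), canvas size, and stroke-rendering details here are assumptions for illustration, not the actual dataset class:

```python
import ast

import pandas as pd
from PIL import Image, ImageDraw
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class QuickDrawDataset(Dataset):
    """Keeps only a DataFrame in memory and renders a drawing into an image
    inside __getitem__, so no image files or Path objects are stored."""

    def __init__(self, df: pd.DataFrame, img_size: int = 256, transform=None):
        self.df = df
        self.img_size = img_size
        self.transform = transform or transforms.ToTensor()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        strokes = ast.literal_eval(row['drawing'])  # [[xs, ys], [xs, ys], ...]
        image = Image.new('L', (self.img_size, self.img_size), color=255)
        canvas = ImageDraw.Draw(image)
        for xs, ys in strokes:
            canvas.line(list(zip(xs, ys)), fill=0, width=2)
        return self.transform(image), int(row['label'])


# loader = DataLoader(QuickDrawDataset(train_df), batch_size=256,
#                     shuffle=True, num_workers=12)
```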
Here are the memory usage plots:
The process is killed in the middle of a training epoch, so we can suppose that the problem is somewhere inside the torch package. (Unless my training loop contains the exact same bug as the one in the fastai library, which sounds like a very unlikely coincidence.)
I am going to roll back to `0.4.1` and see if the problem was introduced in the recent `master` or exists in the stable version as well.
I can't claim it for sure, but it seems that at least `pytorch-nightly` has a problem with data loaders leaking memory when `num_workers > 0`.
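To check this independently of the model and of my loop, something along these lines could be used to watch the main process's resident memory while iterating a multi-worker loader. The synthetic dataset, tensor shapes, and the use of `psutil` are my own additions here, not part of the original setup:

```python
import os

import psutil
import torch
from torch.utils.data import DataLoader, Dataset


class DummyDataset(Dataset):
    """Synthetic data, so the loader itself is the only suspect."""

    def __len__(self):
        return 100_000

    def __getitem__(self, index):
        return torch.randn(3, 128, 128), index % 10


def rss_mb():
    return psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2


if __name__ == '__main__':
    loader = DataLoader(DummyDataset(), batch_size=256, num_workers=12)
    for epoch in range(5):
        for i, (x, y) in enumerate(loader):
            if i % 100 == 0:
                print(f'epoch {epoch}, batch {i}: {rss_mb():.1f} MB')
```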
I am going to share the implementation of the training loop I have. It is a bit involved because it tries to mirror a couple of the features that fastai includes (mostly callbacks and a cyclic schedule). However, it shows the memory issue and could be a helpful demonstration/starting point if we decide to post to PyTorch's forums.
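For reference, a cyclic schedule fits into the same callback interface as just one more callback. The cosine shape and parameter names below are a sketch of the idea rather than the exact implementation I am going to share:

```python
import math


class CosineAnnealingSchedule:
    """Updates the learning rate at the start of every training batch,
    restarting the cosine curve every `cycle_len` batches."""

    def __init__(self, optimizer, eta_min=1e-4, eta_max=1e-2, cycle_len=1000):
        self.optimizer = optimizer
        self.eta_min = eta_min
        self.eta_max = eta_max
        self.cycle_len = cycle_len

    def batch_started(self, phase, total_batches):
        if not phase.grad:
            return  # only adjust the LR during the training phase
        t = (phase.batch_index % self.cycle_len) / self.cycle_len
        lr = self.eta_min + 0.5 * (self.eta_max - self.eta_min) * (1 + math.cos(math.pi * t))
        for group in self.optimizer.param_groups:
            group['lr'] = lr
```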