Ok, I’ve written a custom training loop that looks like this:
import torch

def train(model, opt, phases, callbacks, epochs, device, loss_fn):
    model.to(device)
    cb = callbacks
    cb.training_started(phases=phases, optimizer=opt)
    for epoch in range(1, epochs + 1):
        cb.epoch_started(epoch=epoch)
        for phase in phases:
            n = len(phase.loader)
            cb.phase_started(phase=phase, total_batches=n)
            # phase.grad tells whether this phase updates the weights
            is_training = phase.grad
            model.train(is_training)
            for batch in phase.loader:
                phase.batch_index += 1
                cb.batch_started(phase=phase, total_batches=n)
                # place_and_unwrap is a helper defined elsewhere (not shown)
                x, y = place_and_unwrap(batch, device)
                with torch.set_grad_enabled(is_training):
                    cb.before_forward_pass()
                    out = model(x)
                    cb.after_forward_pass()
                    loss = loss_fn(out, y)
                if is_training:
                    opt.zero_grad()
                    cb.before_backward_pass()
                    loss.backward()
                    cb.after_backward_pass()
                    opt.step()
                phase.batch_loss = loss.item()
                cb.batch_ended(phase=phase, output=out, target=y)
            cb.phase_ended(phase=phase)
        cb.epoch_ended(phases=phases, epoch=epoch)
    cb.training_ended(phases=phases)
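The loop treats `callbacks` as a single object that exposes every hook. The post does not show that object, so here is only a minimal sketch of how such a dispatcher could look. `CallbacksGroup`, `LossHistory`, and the `phase.name` attribute are hypothetical names for illustration; the hook names are the ones called inside `train()`.

```python
from types import SimpleNamespace

# Hook names copied from the calls made inside train().
HOOKS = (
    "training_started", "training_ended",
    "epoch_started", "epoch_ended",
    "phase_started", "phase_ended",
    "batch_started", "batch_ended",
    "before_forward_pass", "after_forward_pass",
    "before_backward_pass", "after_backward_pass",
)


class CallbacksGroup:
    """Hypothetical dispatcher: forwards a hook to every callback that has it."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails, i.e. for hooks.
        if name not in HOOKS:
            raise AttributeError(name)

        def dispatch(**kwargs):
            for cb in self.callbacks:
                hook = getattr(cb, name, None)
                if hook is not None:
                    hook(**kwargs)

        return dispatch


class LossHistory:
    """Hypothetical callback: collects the loss of every batch per phase."""

    def __init__(self):
        self.history = {}

    def batch_ended(self, phase, **kwargs):
        # phase.name is an assumption; the post only shows loader, grad,
        # batch_index and batch_loss on the phase object.
        self.history.setdefault(phase.name, []).append(phase.batch_loss)


# Usage with a stand-in phase object instead of a real training run.
history = LossHistory()
group = CallbacksGroup([history])
phase = SimpleNamespace(name="train", batch_loss=0.25)
group.epoch_started(epoch=1)  # no callback implements this hook: no-op
group.batch_ended(phase=phase, output=None, target=None)
```

Restricting dispatch to a fixed list of hook names means a misspelled hook raises `AttributeError` instead of being silently ignored.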
A couple of points about the implementation:
- Quick Draw Dataset
- No fastai dependencies
- No Path objects stored in memory (reading data directly from pd.DataFrame and rendering images on the fly)
- Direct usage of torch.utils.data.DataLoader classes, the transformations are taken from torchvision
- num_workers=12
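To make "rendering images on the fly" concrete, here is a rough, dependency-free sketch of a map-style dataset that keeps only the raw stroke strings in memory and rasterises them on access. It is not the code from the post: the names `render_strokes` and `DoodleDataset` are hypothetical, rows are plain `(drawing_json, label)` pairs standing in for DataFrame columns, and the stroke format (a list of `[xs, ys]` pairs with coordinates in 0..255) is an assumption based on the simplified Quick Draw data.

```python
import json


def render_strokes(strokes, size=32):
    """Rasterise strokes into a size x size grid of 0/1 values.

    Each stroke is [xs, ys]; coordinates are assumed to lie in 0..255.
    """
    scale = (size - 1) / 255
    grid = [[0] * size for _ in range(size)]
    for xs, ys in strokes:
        points = [(round(x * scale), round(y * scale)) for x, y in zip(xs, ys)]
        # Connect consecutive points with a simple interpolated line.
        for (x0, y0), (x1, y1) in zip(points, points[1:]):
            steps = max(abs(x1 - x0), abs(y1 - y0), 1)
            for i in range(steps + 1):
                x = round(x0 + (x1 - x0) * i / steps)
                y = round(y0 + (y1 - y0) * i / steps)
                grid[y][x] = 1
    return grid


class DoodleDataset:
    """Hypothetical map-style dataset: stores JSON strings, renders on access."""

    def __init__(self, rows, size=32):
        # rows: sequence of (drawing_json, label) pairs, e.g. built from
        # two DataFrame columns (the column layout is an assumption).
        self.rows = list(rows)
        self.size = size

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, index):
        drawing, label = self.rows[index]
        return render_strokes(json.loads(drawing), self.size), label


# One drawing with a single diagonal stroke, rendered at 8x8.
rows = [("[[[0, 255], [0, 255]]]", 3)]
ds = DoodleDataset(rows, size=8)
image, label = ds[0]
```

A real version would more likely draw with PIL and return a tensor after the torchvision transforms; the point here is only that nothing rendered is cached between accesses.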
Here are the memory usage plots:
The process is killed in the middle of the training epoch, so we can suppose that the problem is somewhere inside the torch package (unless my training loop contains the exact same bug as the one in the fastai library, which sounds like a very unusual coincidence).
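For anyone who wants numbers rather than plots, memory can also be sampled from inside the loop through the `batch_ended` hook. The following is only a sketch under a few assumptions: `MemoryUsage` and `peak_rss_mb` are hypothetical names, the callbacks object is assumed to forward `batch_ended` to this class, and the default reader uses the Unix-only `resource` module. It reports the peak RSS of the main process only, so memory held by DataLoader worker processes is not included.

```python
import resource
import sys


def peak_rss_mb():
    """Peak resident set size of the current process, in megabytes.

    ru_maxrss is reported in kilobytes on Linux and in bytes on macOS.
    """
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    divisor = 1024 ** 2 if sys.platform == "darwin" else 1024
    return peak / divisor


class MemoryUsage:
    """Hypothetical callback: samples memory after every batch."""

    def __init__(self, read_memory=peak_rss_mb):
        # read_memory is injectable so a different probe can be swapped in.
        self.read_memory = read_memory
        self.samples = []

    def batch_ended(self, **kwargs):
        self.samples.append(self.read_memory())

    def growth(self):
        """Difference between the last and the first sample."""
        if len(self.samples) < 2:
            return 0.0
        return self.samples[-1] - self.samples[0]


# Usage with a fake reader so the behaviour is easy to follow.
fake_readings = iter([100.0, 120.0, 150.0])
mem = MemoryUsage(read_memory=lambda: next(fake_readings))
for _ in range(3):
    mem.batch_ended(phase=None, output=None, target=None)
print(mem.growth())  # 50.0
```

To cover worker processes as well, `read_memory` could be replaced with a function that sums RSS over the process and its children, for example with psutil (not shown here).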
I am going to roll back to 0.4.1 and see whether the problem was introduced in the recent master or exists in the stable version as well.
I can’t say for sure, but it seems that pytorch-nightly, at least, has a problem with data loaders leaking memory when num_workers > 0.
I am going to share the implementation of the training loop I have. It is a bit involved because it tries to mirror at least a couple of the features that fastai includes (mostly callbacks and the cyclic schedule). However, it shows the memory issue and could probably be a helpful demonstration/starting point if we decide to post to PyTorch’s forums.
