EarlyStoppingCallback with to_distributed

When using the EarlyStoppingCallback for distributed training on two GPUs, training stops on one GPU but keeps running on the other. Does this callback support distributed training?

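The pseudocode looks like this:

import torch
from fastai.vision import *       # unet_learner, models, ...
from fastai.distributed import *  # adds Learner.to_distributed (fastai v1)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
learner = unet_learner(data, models.resnet34, ...).to_distributed(local_rank)  # local_rank: this process's GPU index (assumed, e.g. from torch.distributed.launch)
learner.fit(200, lr)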

Any help appreciated.
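Whatever the root cause turns out to be, one way to guarantee that every process stops on the same epoch is to let a single rank make the early-stopping decision and broadcast it to the others. The following is a minimal sketch using plain torch.distributed calls; `sync_stop_flag` is a hypothetical helper, not part of fastai, and hooking it into a callback is not shown.

```python
import torch
import torch.distributed as dist


def sync_stop_flag(local_stop: bool, src: int = 0) -> bool:
    """Make every rank follow the early-stopping decision taken on rank `src`.

    Hypothetical helper: in a single-process run (no process group
    initialised) it simply returns the local decision.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return bool(local_stop)
    # NCCL only handles CUDA tensors; other backends (e.g. gloo) accept CPU ones.
    if dist.get_backend() == "nccl":
        device = torch.device("cuda", torch.cuda.current_device())
    else:
        device = torch.device("cpu")
    flag = torch.tensor([1 if local_stop else 0], dtype=torch.int32, device=device)
    # Overwrite every rank's flag with the value held on rank `src`.
    dist.broadcast(flag, src=src)
    return bool(flag.item())
```

An alternative design is `dist.all_reduce(flag, op=dist.ReduceOp.MAX)`, which stops all processes as soon as any one of them wants to stop; broadcasting from rank 0 instead makes rank 0's metric history the single source of truth.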

That is weird, since all metrics are computed and all-reduced across GPUs. Are you using a metric or the validation loss to stop training?
The validation loss is not all-reduced (this is fixed in v2), so maybe that’s the reason.
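To illustrate what "all-reduced" means here: each process holds its own copy of a value such as the validation loss, and an all-reduce combines those copies so that every rank ends up with the same result. A minimal sketch, assuming the default process group has already been initialised; `all_reduce_mean` is a hypothetical helper name, and with the NCCL backend the tensor passed in must live on the current GPU.

```python
import torch
import torch.distributed as dist


def all_reduce_mean(value: torch.Tensor) -> torch.Tensor:
    """Average `value` over all processes in the default process group.

    Hypothetical helper: returns a float copy of the input unchanged when no
    process group is initialised (e.g. single-process runs).
    """
    # Work on a detached float copy so the caller's tensor is left untouched.
    value = value.detach().clone()
    if not value.is_floating_point():
        value = value.float()
    if dist.is_available() and dist.is_initialized():
        # Sum the per-process copies, then divide by the number of processes.
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        value /= dist.get_world_size()
    return value
```

Without a step like this, each rank compares its own local loss against its own best value, so the ranks can reach different early-stopping decisions. Note that a plain mean of per-process values is exact only when every rank evaluates the same number of samples.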

Thanks for your reply. I was using the validation loss but have now switched to one of the metrics and am seeing the same thing. In addition, it does not appear to stop at the right time on the first GPU: patience is set to 30, yet it stops even though the metric reached its maximum fewer than 30 epochs earlier.
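For reference, the patience rule expected here can be written out explicitly: stop only once `patience` epochs have passed since the best value of the monitored quantity. This is a generic sketch (`should_stop` is hypothetical), not fastai's EarlyStoppingCallback code; the library's exact comparison, for example whether it stops after `patience` or `patience + 1` non-improving epochs and how `min_delta` is applied, may differ.

```python
from typing import Sequence


def should_stop(history: Sequence[float], patience: int, mode: str = "max",
                min_delta: float = 0.0) -> bool:
    """Return True once `patience` epochs have passed since the best value.

    `history` holds one monitored value per completed epoch, oldest first.
    """
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")
    if not history:
        return False
    best_epoch = 0
    for epoch in range(1, len(history)):
        if mode == "max":
            improved = history[epoch] - history[best_epoch] > min_delta
        else:
            improved = history[best_epoch] - history[epoch] > min_delta
        if improved:
            best_epoch = epoch
    # Number of epochs completed since the best value was recorded.
    return (len(history) - 1) - best_epoch >= patience
```

Under this rule, with `patience=30` and the best metric recorded fewer than 30 epochs earlier, training would not stop, which matches the expectation described in the post.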

I am also using the ReduceLROnPlateau callback, so perhaps it’s some interaction with that. I can try a small synthetic example to debug further.

Would you suggest moving to v2 yet?

v2 is still in pre-release mode, so I wouldn’t recommend switching just yet; maybe next month.
Any help investigating this bug is welcome in the meantime 😉

I think I have found the source of the issue. I was using a custom metric class, which does not reduce across processes the way the AverageMetric class does for metric functions. I expected the distributed trainer to average the metrics, since it writes them to per-process .npy files and has a function (read_metrics) that reads those files and averages them. However, this function is never called. I got the EarlyStoppingCallback working as expected with my metric by adding custom reduction code to the callback class.
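A reduction of this kind can be sketched as follows: accumulate the metric's sum and sample count locally, then all-reduce both before dividing. This is in the same spirit as the cross-process reduction attributed to AverageMetric above, but `DistributedMean` and `reduce_weighted_mean` are hypothetical and not the poster's actual code. The sketch assumes the default process group is initialised and, for NCCL, that `torch.cuda.set_device(local_rank)` has been called.

```python
import torch
import torch.distributed as dist


def reduce_weighted_mean(local_sum: float, local_count: int) -> float:
    """Combine per-process (sum, count) pairs into one global mean."""
    totals = torch.tensor([float(local_sum), float(local_count)], dtype=torch.float64)
    if dist.is_available() and dist.is_initialized():
        if dist.get_backend() == "nccl":
            # NCCL only works on CUDA tensors (uses the current device).
            totals = totals.cuda()
        dist.all_reduce(totals, op=dist.ReduceOp.SUM)
    total_sum, total_count = totals.tolist()
    return total_sum / total_count if total_count > 0 else float("nan")


class DistributedMean:
    """Hypothetical accumulator for a per-sample metric (not a fastai class)."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Call at the start of each validation epoch.
        self.total, self.count = 0.0, 0

    def update(self, batch_values: torch.Tensor) -> None:
        # Accumulate locally; no communication per batch.
        self.total += batch_values.sum().item()
        self.count += batch_values.numel()

    def compute(self) -> float:
        # Every rank must call this, since it runs a collective operation.
        return reduce_weighted_mean(self.total, self.count)
```

Reducing sums and counts separately, rather than averaging per-process means, keeps the result correct when ranks see different numbers of validation samples.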

Thanks again for your suggestions.


In fastai v2, the tensors passed to the metrics will already have been gathered across processes, so this will require less boilerplate.
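To illustrate the general idea: gather the predictions and targets from every process first, then compute the metric once on the full tensors, so the metric function itself needs no distributed code. This is a sketch with plain torch.distributed, not fastai v2's implementation; `gather_cat` and `global_accuracy` are hypothetical helpers, and `all_gather` as used here assumes every rank holds tensors of the same shape.

```python
import torch
import torch.distributed as dist


def gather_cat(t: torch.Tensor) -> torch.Tensor:
    """Concatenate `t` from every process along dim 0 (hypothetical helper).

    all_gather needs every rank to contribute a tensor of the same shape;
    uneven final batches would need padding first.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return t
    t = t.contiguous()
    chunks = [torch.empty_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(chunks, t)
    return torch.cat(chunks, dim=0)


def accuracy(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Plain single-process metric: no distributed code needed."""
    return (preds.argmax(dim=-1) == targets).float().mean()


def global_accuracy(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Gather first, then compute the metric once on the full tensors.
    return accuracy(gather_cat(preds), gather_cat(targets))
```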