Hi, it seems we have an easy way to speed up training: roughly 2x for small batches on fp32. DDPM_v2 went down from 10 min (on a 2080 Ti) to less than 5 min without changing the batch size.
Apparently, for small batches MetricsCB and ProgressBarCB are the bottlenecks: they synchronise the GPU with the CPU by calling to_cpu on every batch. If we make them lazy and read the metrics off the GPU only after a full epoch, the data batches can be prepared in parallel with model execution. A rough sketch of the idea is below.
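To make this concrete, here's a minimal sketch of a lazy metrics callback. The class name, hook names, and the `learn.loss` attribute are illustrative, not the exact miniai API; the point is that per-batch losses stay on the GPU as detached tensors, and the single sync happens at epoch end:

```python
import torch

class LazyMetricsCB:
    """Accumulate per-batch values on the GPU; synchronise with the
    CPU only once per epoch, so the DataLoader workers can prepare
    the next batches while the GPU is still busy."""

    def before_epoch(self, learn):
        self.losses = []  # list of GPU tensors, no sync yet

    def after_batch(self, learn):
        # detach() keeps the value on the GPU; calling item() or
        # to_cpu here would force a GPU/CPU sync on every batch
        self.losses.append(learn.loss.detach())

    def after_epoch(self, learn):
        # one sync per epoch: reduce on the GPU, then move to CPU
        epoch_loss = torch.stack(self.losses).mean().item()
        print(f"mean loss: {epoch_loss:.4f}")
```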
The speed-up is not as large for fp16, where the batch is big enough to keep the GPU busy, but training time still goes down from 160 s to 111 s.
I've also tried to push the training further with CUDA graphs, but unfortunately they are a bit unstable and give almost no improvement; a sketch of the pattern is below in case anyone wants to experiment with it.
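For reference, this is roughly the standard capture/replay pattern from the PyTorch CUDA-graphs docs, shown on a toy model (a sketch, not the code from the notebook):

```python
import torch
import torch.nn as nn

model = nn.Linear(64, 64).cuda()
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

# static buffers: the captured graph replays on fixed memory addresses
static_x = torch.randn(32, 64, device="cuda")
static_y = torch.randn(32, 64, device="cuda")

# warm-up on a side stream before capture (required by the API)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        opt.zero_grad(set_to_none=True)
        loss_fn(model(static_x), static_y).backward()
        opt.step()
torch.cuda.current_stream().wait_stream(s)

# capture one full training step into a graph
g = torch.cuda.CUDAGraph()
opt.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_loss = loss_fn(model(static_x), static_y)
    static_loss.backward()
    opt.step()

# training loop: copy each real batch into the static buffers, replay
for xb, yb in [(torch.rand_like(static_x), torch.rand_like(static_y))]:
    static_x.copy_(xb)
    static_y.copy_(yb)
    g.replay()
```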
Here is DDPM_v2, which shows both the original code with execution timing added and the improved lazy callbacks.