Gradient accumulation with fp16 training

Hi, I ran into an issue when trying to use a gradient accumulation callback together with fp16 training: the loss is extremely high, around 1,000,000.

I’m doing image segmentation with unet_learner, using xresnet50 as the base architecture.

Here’s the code for the callback:

from fastai.vision.all import *

class GradAccumCallback(Callback):
    "Accumulate gradients over `num_batches` batches before taking an optimizer step"
    def __init__(self, num_batches):
        self.num_batches = num_batches
        self.count = 0
    def after_backward(self):
        self.count += 1
        # Cancel the rest of the batch (optimizer step and zero_grad)
        # unless this is the num_batches-th batch since the last step
        if (self.count % self.num_batches) != 0:
            raise CancelBatchException()

It is supposed to skip the optimizer step and the zeroing of the gradients in the training loop unless the current batch count is a multiple of num_batches, and it seems to work fine in single-precision training.
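
For reference, this is roughly how it’s attached (a sketch; dls stands in for my segmentation DataLoaders, and num_batches=4 is just an example value):

from fastai.vision.all import *

# dls = my segmentation DataLoaders (details omitted)
learn = unet_learner(dls, xresnet50,
                     cbs=GradAccumCallback(num_batches=4)).to_fp16()
learn.fit_one_cycle(1)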

I looked at the MixedPrecision callback and I think it might have something to do with the line where the gradients are zeroed, but I’m not sure.
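
For anyone following along: the usual fp16 recipe keeps fp32 “master” copies of the weights, scales the loss, and shuttles gradients between the two copies every batch, and somewhere in that dance the fp16 gradients get cleared. Below is a generic toy illustration of that pattern, not fastai’s actual MixedPrecision source:

import torch

# Toy sketch of fp16 training with fp32 master weights and loss scaling
# (generic pattern, NOT fastai's MixedPrecision code; needs a GPU)
model = torch.nn.Linear(4, 1).half().cuda()
master = [p.detach().float().clone().requires_grad_(True)
          for p in model.parameters()]
opt = torch.optim.SGD(master, lr=1e-2)
loss_scale = 512.0

x = torch.randn(8, 4, device='cuda', dtype=torch.float16)
loss = model(x).float().pow(2).mean()
(loss * loss_scale).backward()             # grads land on the fp16 params

for mp, p in zip(master, model.parameters()):
    mp.grad = p.grad.float() / loss_scale  # unscale into the fp32 copies
    p.grad = None                          # fp16 grads cleared here -- this is
                                           # what clashes with naive accumulation
opt.step()                                 # update the fp32 master weights
for mp, p in zip(master, model.parameters()):
    p.data.copy_(mp.data)                  # sync back to the fp16 model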

Also, I tested the model in fp16 without the callback and it worked fine; the only problem is that I can only fit ~4 images on the GPU.

Any ideas? A similar callback worked just fine with fp16 in fastai v1.


Okay, it seems like I solved it.

I needed to set run_after=MixedPrecision in the GradAccumCallback, because otherwise the CancelBatchException stops the after_backward method of the MixedPrecision callback from running. learn.show_training_loop was really helpful!

Here’s the fixed callback if anyone is interested:

from fastai.vision.all import *

class GradAccumCallback(Callback):
    "Accumulate gradients over `num_batches` batches; must run after MixedPrecision"
    run_after = MixedPrecision
    def __init__(self, num_batches):
        self.num_batches = num_batches
        self.count = 0
    def after_backward(self):
        self.count += 1
        # Cancel the rest of the batch (optimizer step and zero_grad)
        # unless this is the num_batches-th batch since the last step
        if (self.count % self.num_batches) != 0:
            raise CancelBatchException()
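
And here’s a minimal usage sketch (dls stands in for your own DataLoaders, and num_batches=4 is just an example value):

from fastai.vision.all import *

# dls = your segmentation DataLoaders
learn = unet_learner(dls, xresnet50,
                     cbs=GradAccumCallback(num_batches=4)).to_fp16()
learn.show_training_loop()  # check that GradAccumCallback runs after MixedPrecision
learn.fit_one_cycle(1)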

I’m trying your GradAccumCallback now. If I’m not mistaken, the num_batches you pass into the callback increases the effective batch size to num_batches × batch_size?

Have you had any issues with using this callback? I’m benchmarking a few models with bs=64 versus bs=2 with num_batches=32 (i.e., the same effective batch size of 64), and I’m seeing some performance degradation relative to plain bs=64, though it’s still better overall than plain bs=2.