Implementing NFNets (Normalizer-Free Nets) in fastai


Has anyone tried to implement NFNets in fastai? (NFNets come from a recent DeepMind paper on normalizer-free networks.)

Their contributions toward the SOTA results include:

  1. Adaptive Gradient Clipping (AGC), where a gradient is clipped whenever the ratio ||gradient||/||weights|| exceeds some threshold lambda, and
  2. A normalizer-free network (NFNet) architecture, a result of architecture search.
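For a single parameter tensor, the AGC rule can be sketched in a few lines of plain PyTorch (agc_scale is a hypothetical helper name for illustration; note the paper actually measures norms unit-wise, e.g. per output channel, rather than per whole tensor as here):

```python
import torch

def agc_scale(weight: torch.Tensor, grad: torch.Tensor,
              lam: float = 0.01, eps: float = 1e-3) -> torch.Tensor:
    """Rescale grad so that ||grad|| / ||weight|| never exceeds lam."""
    w_norm = weight.norm().clamp_min(eps)  # guard zero-initialized weights
    g_norm = grad.norm()
    max_norm = lam * w_norm
    if g_norm > max_norm:                  # only clip when the ratio exceeds lam
        return grad * (max_norm / g_norm)
    return grad
```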

Their model does not seem to improve inference speed over EfficientNet (at equivalent accuracy), but it improves training speed by a large margin (5-8x) by removing batch norm from the model. I believe this is significant for fastai, as it offers a great opportunity for faster training with transfer learning.

A summary and critique by Yannic Kilcher is available here. As Yannic points out, it is not clear whether the improvements behind the SOTA results come from AGC or from architecture search. To me, it looks more like they trained a bigger model and got higher accuracy, so I'm not so interested in the SOTA results themselves as in the removal of BN and its effects on transfer learning on medical datasets.

Furthermore, it would be nice to apply AGC to the individual samples of the minibatch before the mean() or sum() reduction of the loss, if at all possible, rather than clipping after the reduction as the paper suggests. Yannic suggests this at 23:30.
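A naive sketch of that idea (per_sample_agc_grad is a hypothetical helper, not from the paper): compute each sample's gradient separately, clip it with the AGC rule, then average. One backward pass per sample is slow, so something like functorch's vmap would be the practical route, but the logic is the same:

```python
import torch

def per_sample_agc_grad(model, loss_fn, xb, yb, lam=0.01, eps=1e-3):
    # Clip each sample's gradient *before* averaging, instead of clipping
    # the gradient of the already-reduced (mean/sum) loss.
    params = [p for p in model.parameters() if p.requires_grad]
    avg = [torch.zeros_like(p) for p in params]
    n = xb.shape[0]
    for i in range(n):                                  # one backward per sample (slow)
        loss = loss_fn(model(xb[i:i+1]), yb[i:i+1])
        grads = torch.autograd.grad(loss, params)
        for a, p, g in zip(avg, params, grads):
            max_norm = lam * p.detach().norm().clamp_min(eps)
            g_norm = g.norm()
            if g_norm > max_norm:
                g = g * (max_norm / g_norm)             # per-sample clip
            a += g / n
    return avg
```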

I’m new to the fastai forums, and I’m happy to delete or move this topic to a subtopic if requested. I have used fastai since v1. One of my works, which cites fastai, is Explaining the Rationale of Deep Learning Glaucoma Decisions with Adversarial Examples - PubMed.

Thanks in advance! :slight_smile:


Ross Wightman is in the process of implementing this in his amazing timm library, and thanks to @muellerzr’s timm_learner implementation, it should be trivial to bring NFNets to fastai.

Regarding adaptive gradient clipping, it’s best done in a callback (as is the case for most training changes in fastai). As a starting point, you can look at what’s already done for regular gradient clipping (although that uses the built-in PyTorch function): code
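For reference, regular (non-adaptive) clipping is essentially a one-liner in PyTorch; a callback’s before_step would just run clip_grad_norm_ over the model’s parameters. A minimal plain-PyTorch illustration (the toy model and data here are made up):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
xb, yb = torch.randn(8, 3), torch.randn(8, 1) * 1000.0  # huge targets -> huge grads
loss = nn.functional.mse_loss(model(xb), yb)
loss.backward()
# Scale all grads so their *global* norm is at most 1.0 -- one threshold for
# the whole model, unlike AGC's per-parameter, weight-relative thresholds.
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
total = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
```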


I am personally mainly interested in the Adaptive Gradient Clipping and how it compares to more conventional neural nets with batchnorm.

I wonder how it interacts with gradient accumulation. In my experiments I got worse performance with batchnorm when using gradient accumulation, possibly because of the noisier gradients at small batch sizes (normalizing by a less consistent std). Without the need for batch statistics, gradient accumulation at low batch sizes might perform much closer to genuine large-batch training.
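That intuition is easy to check for the BN-free case: without batch statistics the loss decomposes over samples, so accumulating gradients over micro-batches reproduces the full-batch gradient exactly (a toy sketch; with BatchNorm in the model this equality breaks, because each micro-batch normalizes with different statistics):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)                     # no BatchNorm anywhere
xb, yb = torch.randn(8, 4), torch.randn(8, 2)

def accumulated_grads(batches, total_n):
    model.zero_grad()
    for x, y in batches:                    # .backward() accumulates into .grad
        # sum-reduce, then divide by the *full* batch size, so the
        # micro-batch losses add up to exactly the full-batch loss
        loss = nn.functional.mse_loss(model(x), y, reduction='sum') / total_n
        loss.backward()
    return [p.grad.clone() for p in model.parameters()]

full  = accumulated_grads([(xb, yb)], len(xb))
accum = accumulated_grads([(xb[:4], yb[:4]), (xb[4:], yb[4:])], len(xb))
```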

If anyone gets an implementation going, I’d love to see it :slight_smile:
I tried it on the Cassava dataset and it was incredibly fast, but the accuracy was poor and never increased. So I’d need to look at your recommendations.


(I came upon this thread after searching the web for metrics on AGC vs. BatchNorm performance outside of the NFNet architecture.)

Thanks @ilovescience. Looks like Wightman’s timm now implements NFNets as of March 7, and Adaptive Gradient Clipping (AGC) has been in there since mid-February.

Ditto on @marii’s question. I’m really curious how well AGC will work vs. BatchNorm for my own (custom) models. Given all the architecture optimizations they did in the NFNets paper (via Neural Architecture Search) in order to “beat” EfficientNets, it’s unclear whether “batchnorm is dead” or not for the rest of us. :wink:


Coming back to this, I tried implementing AGC using the link @ilovescience shared and code from GitHub - vballoli/nfnets-pytorch: NFNets and Adaptive Gradient Clipping for SGD implemented in PyTorch.

Hope this helps. It’s working (training runs), but it may be prone to bugs. If you find any, please point them out; I’m happy to receive any constructive criticism.


import torch
from fastai.vision.all import Callback, store_attr

class AdaptiveGradientClip(Callback):
    """Adaptive Gradient Clipping from the NFNet paper.
    Code adapted from vballoli/nfnets-pytorch.
    Implemented for fastai v2.
    clipping: clipping factor, defaults to 0.01 (same as the paper)
    eps: small constant guarding against zero-norm (e.g. zero-initialized) parameters
    ignore_agc: list of str names of parameters to skip when clipping
    """
    def __init__(self, clipping: float = 1e-2, eps: float = 1e-3, ignore_agc=["fc"]):
        store_attr()

    def before_step(self):
        for name, p in self.model.named_parameters():
            if any(n in name for n in self.ignore_agc): continue  # skip ignored layers
            if p.grad is None: continue
            param_norm = torch.max(self.unitwise_norm(p.detach()),
                                   torch.tensor(self.eps).to(p.device))
            grad_norm = self.unitwise_norm(p.grad.detach())
            max_norm = param_norm * self.clipping
            trigger = grad_norm > max_norm
            # rescale only where the unit-wise ratio exceeds the clipping factor
            clipped_grad = p.grad * (max_norm / torch.max(grad_norm,
                                     torch.tensor(1e-6).to(p.device)))
            p.grad.detach().data.copy_(torch.where(trigger, clipped_grad, p.grad))

    def unitwise_norm(self, x: torch.Tensor):
        if x.ndim <= 1:
            dim, keepdim = 0, False
        elif x.ndim in [2, 3]:
            dim, keepdim = 0, True
        elif x.ndim == 4:
            dim, keepdim = [1, 2, 3], True  # conv filters: one norm per output channel
        else:
            raise ValueError('Wrong input dimensions')
        return torch.sum(x**2, dim=dim, keepdim=keepdim) ** 0.5