AdaHessian

Hi all,

It finally seems like someone has managed to create a viable second order optimizer, as presented in https://arxiv.org/pdf/2006.00719.pdf, and even provided an implementation in PyTorch.

Has anyone had the opportunity to try this out yet? If the results are promising, it may be a good candidate to port to v2.

6 Likes

I have not yet, but I’ll run some ImageWoof tests in a moment :slight_smile:

Edit: looks like it’s not a 1:1, @LessW2020 would probably be the best for this :slight_smile:

1 Like

lol, I was all over this. In summary, it works extremely well, with one huge caveat: it will consume about 2x the GPU memory you normally need to run your model.

Thus, I could only run, for example, a ResNeSt50 and not a 101, because I couldn’t get the larger model to fit in memory.
I’m actually close to buying a Titan GPU with 24GB gpu memory expressly to be able to work with adahessian.

Here’s an example of the difference:

and then… adahessian:

7 Likes

Did you port it over to v2 yet? :slight_smile: (or on the TODO). Also fantastic work, I figured you were on this :wink:

1 Like

no lol, I have not ported it yet. I’ve been super busy at work atm with building multiple AI pipelines, but I am likely buying the Titan GPU with 24GB next week… exactly for AdaHessian.

I posted a ‘single import’ version of adahessian here:

And you have to update your training loop as below:

config for training loop:

        loss.backward(create_graph=True)
        _, gradsH = get_params_grad(model)
        optimizer.step(gradsH)
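
For anyone wondering why `create_graph=True` matters here: the Hessian-vector product is obtained by differentiating the gradients a second time, so the gradients themselves must stay attached to the autograd graph. Below is a minimal sketch on a toy model. Note that `get_params_grad` here is a hypothetical stand-in written for illustration (the helper in the linked implementation may differ), and the model and data are made up.

```python
import torch
from torch import nn


def get_params_grad(model):
    """Hypothetical stand-in: collect trainable params and their gradients."""
    params, grads = [], []
    for p in model.parameters():
        if not p.requires_grad or p.grad is None:
            continue  # keep the two lists the same length
        params.append(p)
        grads.append(p.grad)
    return params, grads


torch.manual_seed(0)
model = nn.Linear(3, 1)                      # toy model, for illustration only
x, y = torch.randn(5, 3), torch.randn(5, 1)

loss = nn.functional.mse_loss(model(x), y)
loss.backward(create_graph=True)             # gradients keep their own graph
params, gradsH = get_params_grad(model)

# Each gradient is differentiable, so it can be passed to torch.autograd.grad.
has_graph = all(g.grad_fn is not None for g in gradsH)
print(len(params), len(gradsH), has_graph)
```

With a plain `loss.backward()` the gradients carry no `grad_fn`, and the second differentiation inside the optimizer has nothing to work with.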
7 Likes

lol what is this wizardry?? Looking forward to digging in to this!

1 Like

Me too!

A long time ago I coded a simple-minded 2nd order optimizer that failed in several ways (trapped in local minima, unstable 2nd derivative). It looks like these authors have addressed a number of such issues and got a method to work.

Thanks so much for translating to PyTorch.

2 Likes

Wow, the plot looks unreal. Maybe we should try it on a tougher dataset.

I have a working version for fastai2 now. My code is still a bit rough around the edges and performance isn’t quite on par with AdamW yet, but I’m sharing it here as a start. I’ve only tested it on machine translation with a Transformer model, so I can’t guarantee it will work for vision without a few tweaks, though I think it should.

It is a bit of a pig when it comes to memory usage: I can only fit a 57M parameter Transformer (2 enc/2 dec) on my 2080 (without MixedPrecision for now), whereas before an 89M model (6/6) worked fine. Mixed precision will help with that, and I’m hoping the improvement in performance will make up for it.

Note that this implementation doesn’t work with MixedPrecision yet. I think there is something going on with the gradient scaling that I still need to figure out; hopefully I’ll get it sorted soon.

Weights and Biases experiment runs are here

Callback + Optimizer + Learner patch

Right now the optimizer looks more or less the same as AdamW, however I needed to add a callback to calculate the diagonal of the hessian as I couldn’t do it from within the optimizer functions (happy to hear if someone manages to figure this out tho!). The optimizer code is more or less the same as Adam, the callback is where the magic happens.
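
For readers new to the idea, the diagonal estimate in the callback rests on Hutchinson's method: for a random vector z with ±1 entries, the expectation of z * (Hz) is the diagonal of H. Here is a minimal sketch on a tiny quadratic whose Hessian is known exactly. The matrix `A`, the sample count, and the variable names are all made up for illustration, and the sketch leaves out the `abs` and the conv-kernel averaging used in the callback.

```python
import torch

torch.manual_seed(0)

# f(w) = 0.5 * w^T A w with symmetric A, so the Hessian of f is exactly A.
A = torch.tensor([[2.0, 0.5],
                  [0.5, 3.0]])
w = torch.tensor([1.0, -1.0], requires_grad=True)
loss = 0.5 * w @ A @ w

# First derivative, kept differentiable so it can be differentiated again.
(grad,) = torch.autograd.grad(loss, w, create_graph=True)

n_samples = 2000
acc = torch.zeros_like(w)
for _ in range(n_samples):
    z = torch.randint_like(w, high=2) * 2.0 - 1.0   # Rademacher vector (+1/-1)
    (hz,) = torch.autograd.grad(grad, w, grad_outputs=z, retain_graph=True)
    acc += z * hz                                   # one sample of z * (H z)

diag_estimate = acc / n_samples
print(diag_estimate, torch.diagonal(A))
```

In training only one sample of z is drawn per step; the moving average over steps plays the role of the averaging loop above.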

Shoutout to @LessW2020 for his code, helped a lot in getting this working

Optimizer

average_sqr_diag_hessian - almost identical to average_sqr_grad

def average_sqr_diag_hessian(p, sqr_mom, dampening=True, sqr_avg_diag_hessian=None, hutchinson_trace=None, **kwargs):
    if sqr_avg_diag_hessian is None: sqr_avg_diag_hessian = torch.zeros_like(p.grad.data)
    damp = 1-sqr_mom if dampening else 1.
    sqr_avg_diag_hessian.mul_(sqr_mom).addcmul_(hutchinson_trace, hutchinson_trace, value=damp)
    return {'sqr_avg_diag_hessian': sqr_avg_diag_hessian}

adahessian_step - similar to adam_step

def adahessian_step(p, lr, mom, step, sqr_mom, grad_avg, sqr_avg_diag_hessian, hessian_power, eps, **kwargs):
    "Step for AdaHessian with `lr` on `p`"
    debias1 = debias(mom,     1-mom,     step)
    debias2 = debias(sqr_mom, 1-sqr_mom, step)
    p.data.addcdiv_(grad_avg, ((sqr_avg_diag_hessian/debias2).sqrt() ** hessian_power) + eps, value = -lr / debias1)    
    return p
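
To make the arithmetic in `adahessian_step` concrete, here is a plain-PyTorch sketch of the same update for a single tensor. It assumes that fastai's `debias(mom, 1-mom, step)` reduces to `1 - mom**step`; the function name `adahessian_update` and the example numbers are hypothetical.

```python
import torch


def adahessian_update(p, grad_avg, sqr_avg_diag_hessian, lr, mom, sqr_mom,
                      step, hessian_power=1.0, eps=1e-4):
    """Out-of-place version of the update above (illustrative sketch only)."""
    debias1 = 1 - mom ** step        # assumed equal to debias(mom, 1-mom, step)
    debias2 = 1 - sqr_mom ** step
    denom = (sqr_avg_diag_hessian / debias2).sqrt() ** hessian_power + eps
    return p - (lr / debias1) * grad_avg / denom


# First step, starting from zero moving averages with dampening.
mom, sqr_mom, lr = 0.9, 0.98, 0.1
p = torch.tensor([1.0])
g = torch.tensor([1.0])              # gradient
d = torch.tensor([2.0])              # Hutchinson estimate of the Hessian diagonal
grad_avg = (1 - mom) * g
sqr_avg = (1 - sqr_mom) * d * d

new_p = adahessian_update(p, grad_avg, sqr_avg, lr, mom, sqr_mom, step=1)
print(new_p)
```

After debiasing, the first step is roughly `lr * g / |d|`, i.e. a Newton-like step scaled by the learning rate, which is the main difference from Adam's `g / |g|`.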

AdaHessian - tying it all together (DONT USE ME FOR NOW)

#@log_args(to_return=True, but_as=Optimizer.__init__)
#def AdaHessian(params, lr=0.1, hessian_power=1, hutchinson_trace=None, mom=0.9, #sqr_mom=0.98, eps=1e-4, wd=0.0, decouple_wd=True):
#    "A `Optimizer` for Adam with `lr`, `mom`, `sqr_mom`, `eps` and `params`"
#    cbs = [weight_decay] if decouple_wd else [l2_reg]
#    cbs += [partial(average_grad, dampening=True), average_sqr_diag_hessian, step_stat, adahessian_step]
#    return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, hessian_power=hessian_power, eps=eps, wd=wd)

Callback

class HutchinsonTraceCallback(Callback):
    run_before=MixedPrecision
    
    def __init__(self, block_length=1):
        self.block_length = block_length
        
    def _clip_grad_norm(self, max_norm=0., params=None):
            """
            From FairSeq optimizer - Clips gradient norm.
            """
            if max_norm > 0:
                return torch.nn.utils.clip_grad_norm_(params, max_norm)
            else:
                return math.sqrt(sum(p.grad.data.norm()**2 for p in params if p.grad is not None))  
            
    def after_backward(self):
        """
            compute the Hessian vector product with a random vector v, at the current gradient point,
            i.e., compute the gradient of <gradsH,v>.
            :param gradsH: a list of torch variables
            :return: a list of torch tensors
        """
        device = self.learn.dls.device

        params, grads = [], []
        for p in self.learn.model.parameters():
            if p.requires_grad and p.grad is not None: 
                params.append(p)
                grads.append(p.grad)
    
        grad_norm = self._clip_grad_norm(max_norm=0., params=params)   # Not sure if this is needed...
        
        zs = [torch.randint_like(p, high=2).to(device) * 2.0 - 1.0 for p in params]
       
        h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, only_inputs=True, retain_graph=False)
        
        # @LESSW CODE
        hutchinson_trace = []
        for hz, z in zip(h_zs, zs):
            param_size = hz.size()
            if len(param_size) <= 2:  # for 0/1/2D tensor
                tmp_output = torch.abs(hz * z) + 0. #.float()
                hutchinson_trace.append(tmp_output) # Hessian diagonal block size is 1 here.
            elif len(param_size) == 4:  # Conv kernel
                tmp_output = torch.abs(torch.sum(torch.abs(hz * z) + 0., dim=[2, 3], keepdim=True)).float() / z[0, 1].numel() # Hessian diagonal block size is 9 here: torch.sum() reduces the dim 2/3.
                hutchinson_trace.append(tmp_output)
                
        # Add hutchinson_trace to optimizer state
        for i, (_,_,state,_) in enumerate(self.learn.opt.all_params(with_grad=True)):
            state['hutchinson_trace'] = hutchinson_trace[i]
        
        def after_step(self): # Is this needed?            
            for h in hutchinson_trace: 
                h.zero_()
                
            for i, (_,_,state,_) in enumerate(self.learn.opt.all_params(with_grad=True)):
                state['hutchinson_trace'].zero_() 
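
The conv-kernel branch above averages the per-element estimate over each kernel's spatial positions, so every 3x3 kernel ends up sharing a single value. A small sketch of just that step, using a made-up tensor in place of `hz * z` (the shapes and values are for illustration only, and the divisor is written as the kernel height times width rather than `z[0, 1].numel()`):

```python
import torch

# Stand-in for hz * z on a conv weight of shape (out_ch, in_ch, kh, kw).
hz_times_z = torch.arange(18, dtype=torch.float32).reshape(2, 1, 3, 3) - 8.0

per_element = torch.abs(hz_times_z)
kernel_area = hz_times_z.shape[2] * hz_times_z.shape[3]   # 9 for a 3x3 kernel
block = per_element.sum(dim=[2, 3], keepdim=True) / kernel_area

# keepdim=True leaves shape (2, 1, 1, 1), which broadcasts over each kernel
# when it is accumulated into state shaped like the weight.
sqr_avg = torch.zeros(2, 1, 3, 3)
sqr_avg.addcmul_(block, block, value=0.02)
print(block.flatten(), sqr_avg[0, 0])
```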
        

Learner Patch

@patch 
def _backward(self:Learner): self.loss.backward(create_graph=True)

Just add the callback and optimizer to your Learner and you should be good to go!

9 Likes

That looks great, can’t wait to try this out, will try and put some time aside this week! Forgive my ignorance, but would the pure PyTorch implementation work with OptimWrapper, and then just patching the backward and step methods in the learner with the training loop update?

1 Like

Yep I considered it briefly, but wanted to see if I could fold it into the fastai framework (and make life difficult for myself :sweat_smile:) , might also be worth a PR after some more testing if its performance in practice can meet expectations

Hi all, I have been trying to use this with fastai2==0.0.21, but keep hitting the same error regardless of whether I’m using @morgan’s implementation or OptimWrapper.

It is the rather unhelpful RuntimeError: got 171 tensors and 9 gradients which is raised from the get_trace function.

For clarity, the OptimWrapper code I’m using is:

def adahessian(param_groups, **kwargs):
    return OptimWrapper(Adahessian([{'params': ps, **kwargs} for ps in param_groups]))

@patch 
def one_batch(self:Learner, i, b):
    self.iter = i
    try:
        self._split(b);                                  self('begin_batch')
        self.pred = self.model(*self.xb);                self('after_pred')
        if len(self.yb) == 0: return
        self.loss = self.loss_func(self.pred, *self.yb); self('after_loss')
        if not self.training: return
        self.loss.backward(create_graph=True);           self('after_backward')
        _, gradsH = get_params_grad(self.model)
        self.opt.step(gradsH)
        self('after_step')
        self.opt.zero_grad()
    except CancelBatchException:                         self('after_cancel_batch')
    finally:                                             self('after_batch')

Did anyone else hit this, or have any ideas?

Thanks!

I think that’s being raised in torch.autograd.grad, which is here in @LessW2020’s implementation

Make sure you’re only giving it params that have a gradient, the number of params should equal the number of grads

params_g = []
for p in params:
    if p.grad is None:
        continue
    else:
        params_g.append(p)

Even though that code already has something like that…strange
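
A minimal sketch of that check on a toy model with one frozen layer: the gradients, the parameters and the random vectors are all built from the same filtered list, so their lengths have to agree before `torch.autograd.grad` is called. The model and data are made up for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.Tanh(), nn.Linear(3, 1))
for p in model[0].parameters():
    p.requires_grad_(False)          # frozen layer: never receives a gradient

x = torch.randn(8, 4)
loss = model(x).pow(2).mean()
loss.backward(create_graph=True)

# Build params and grads together so they cannot get out of sync.
params_g, grads_g = [], []
for p in model.parameters():
    if p.requires_grad and p.grad is not None:
        params_g.append(p)
        grads_g.append(p.grad)

# One random +1/-1 tensor per gradient, shaped like its parameter.
v = [torch.randint_like(p, high=2) * 2.0 - 1.0 for p in params_g]
assert len(grads_g) == len(params_g) == len(v)

hvs = torch.autograd.grad(grads_g, params_g, grad_outputs=v)
print([tuple(h.shape) for h in hvs])
```

If `v` is built from a differently grouped list than the gradients (for example one entry per param group instead of one per parameter), the counts diverge and autograd complains about mismatched tensors and gradients.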

Hey, yes, you are right about it being a torch.autograd error. At the moment I am basically using the cnn_learner default settings:

learn = cnn_learner(dls, xresnet50, metrics=error_rate, opt_func=adahessian).to_fp16()
learn.unfreeze()
learn.fit(1)

so I would think that it is only passing in appropriate parameters! Strange!

Modified how the parameters were gathered in @LessW2020 code and it worked here with fastai. I point out my changes with <---xxx :


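import math
from copy import deepcopy

import torch
from torch.optim import Optimizer  # torch's base class, not fastai's `Optimizer`


class Adahessian(Optimizer):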
    """Implements Adahessian algorithm.
    It has been proposed in `ADAHESSIAN: An Adaptive Second Order Optimizer for Machine Learning`.
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 0.15)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-4)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        hessian_power (float, optional): Hessian power (default: 1)
    """

    def __init__(self, params, lr=0.15, betas=(0.9, 0.999), eps=1e-4,
                 weight_decay=0, hessian_power=1):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(
                    betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(
                    betas[1]))
        if not 0.0 <= hessian_power <= 1.0:
            raise ValueError("Invalid Hessian power value: {}".format(hessian_power))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, hessian_power=hessian_power)

        super(Adahessian, self).__init__(params, defaults)

    def get_trace(self, params, gradsH):  # <--- Passed params in here
        """
        compute the Hessian vector product with a random vector v, at the current gradient point,
        i.e., compute the gradient of <gradsH,v>.
        :param gradsH: a list of torch variables
        :return: a list of torch tensors
        """

        #params = self.param_groups[0]['params']
       
        v = [torch.randint_like(p, high=2) for p in params]  # sampled on each parameter's own device
        for v_i in v:
            v_i[v_i == 0] = -1
            
        #print(len(gradsH[0]), len(params[0]))
        #print(params) 
            
        hvs = torch.autograd.grad(
            gradsH,
            params,
            grad_outputs=v,
            only_inputs=True,
            retain_graph=True)

        hutchinson_trace = []
        for hv, vi in zip(hvs, v):
            param_size = hv.size()
            if len(param_size) <= 2:  # for 0/1/2D tensor
                tmp_output = torch.abs(hv * vi)
                hutchinson_trace.append(tmp_output) # Hessian diagonal block size is 1 here.
            elif len(param_size) == 4:  # Conv kernel
                tmp_output = torch.abs(torch.sum(torch.abs(
                    hv * vi), dim=[2, 3], keepdim=True)) / vi[0, 1].numel() # Hessian diagonal block size is 9 here: torch.sum() reduces the dim 2/3.
                hutchinson_trace.append(tmp_output)
        
        return hutchinson_trace

    def step(self, gradsH, closure=None):
        """Performs a single optimization step.
        Arguments:
            gradsH: The gradient used to compute Hessian vector product.
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        params = [ps['params'][0] for ps in self.param_groups]    # <--- grouped params here   
            
        # get the Hessian diagonal
        hut_trace = self.get_trace(params, gradsH)

        #params = [ps['params'][0] for ps in self.param_groups]
        
#         for group in self.param_groups:
#             for i, p in enumerate(group['params']):
        for i, group in enumerate(self.param_groups):   # <--- changed loop here   
            
            p = group['params'][0]
        
            if p.grad is None:
                continue

            grad = deepcopy(gradsH[i].data)
            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state['step'] = 0
                # Exponential moving average of gradient values
                state['exp_avg'] = torch.zeros_like(p.data)
                # Exponential moving average of Hessian diagonal square values
                state['exp_hessian_diag_sq'] = torch.zeros_like(p.data)

            exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']

            beta1, beta2 = group['betas']

            state['step'] += 1

            # Decay the first and second moment running average coefficient
            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            exp_hessian_diag_sq.mul_(beta2).addcmul_(
                hut_trace[i], hut_trace[i], value=1 - beta2)

            bias_correction1 = 1 - beta1 ** state['step']
            bias_correction2 = 1 - beta2 ** state['step']

            # make the square root, and the Hessian power
            k = group['hessian_power']
            denom = (
                (exp_hessian_diag_sq.sqrt() ** k) /
                math.sqrt(bias_correction2) ** k).add_(
                group['eps'])

            # make update
            p.data = p.data - \
                group['lr'] * (exp_avg / bias_correction1 / denom + group['weight_decay'] * p.data)
            
        # Zero gradsH
        for h in hut_trace:
            if h.grad is not None:
                print('h yay')
        
        for g in gradsH:
            if g.grad is not None:
                print('g yay')
                g.grad.detach_()
                g.grad.zero_()

        return loss
    
    def zero_grad(self):
        r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
        for group in self.param_groups:
            for i,p in enumerate(group['params']):
                if p.grad is not None:
                    p.grad.detach_()
                    p.grad.zero_()
#                 if self.hut_trace[i].grad is not None:
#                     print('yay')
#                     self.hut_trace[i].grad.detach_()
#                     self.hut_trace[i].grad.zero_()


3 Likes

Edit, I was wrong about the paper notation, removed it

Getting some nice results on my little 2/2 Transformer (still no fp16) with the code above.

1 Like


Full working notebook from https://github.com/davda54/ada-hessian on CoLab
https://colab.research.google.com/drive/1IcBOCgereZbjASnlVjXPGbJqp-NYjK1L?usp=sharing

@Johnyquest the notebook is private

try this https://colab.research.google.com/drive/1IcBOCgereZbjASnlVjXPGbJqp-NYjK1L?usp=sharing

1 Like

Changed the privacy setting. You should be able to access it now.