AdaHessian

That looks great, can’t wait to try this out, will try and put some time aside this week! Forgive my ignorance, but would the pure PyTorch implementation work with OptimWrapper, and then just patching the backward and step methods in the learner with the training loop update?


Yep, I considered it briefly, but wanted to see if I could fold it into the fastai framework (and make life difficult for myself :sweat_smile:). It might also be worth a PR after some more testing, if its performance in practice meets expectations.

Hi all, I have been trying to use this with fastai2==0.0.21, but keep hitting the same error regardless of whether I’m using @morgan’s implementation or OptimWrapper.

It is the rather unhelpful RuntimeError: got 171 tensors and 9 gradients which is raised from the get_trace function.

For clarity, the OptimWrapper code I’m using is:

def adahessian(param_groups, **kwargs):
    return OptimWrapper(Adahessian([{'params': ps, **kwargs} for ps in param_groups]))

@patch 
def one_batch(self:Learner, i, b):
    self.iter = i
    try:
        self._split(b);                                  self('begin_batch')
        self.pred = self.model(*self.xb);                self('after_pred')
        if len(self.yb) == 0: return
        self.loss = self.loss_func(self.pred, *self.yb); self('after_loss')
        if not self.training: return
        self.loss.backward(create_graph=True);           self('after_backward')
        _, gradsH = get_params_grad(self.model)
        self.opt.step(gradsH)
        self('after_step')
        self.opt.zero_grad()
    except CancelBatchException:                         self('after_cancel_batch')
    finally:                                             self('after_batch')

Did anyone else hit this, or have any ideas?

Thanks!

I think that’s being raised in torch.autograd.grad, which is called here in @LessW2020’s implementation.

Make sure you’re only giving it params that have a gradient; the number of params should equal the number of grads:

params_g = []
for p in params:
    # keep only the parameters that actually received a gradient
    if p.grad is not None:
        params_g.append(p)

Even though that code already has something like that…strange

Hey, yes, you are right about it being a torch.autograd error. At the moment I am basically using the cnn_learner default settings:

learn = cnn_learner(dls, xresnet50, metrics=error_rate, opt_func=adahessian).to_fp16()
learn.unfreeze()
learn.fit(1)

so I would think that it is only passing in appropriate parameters! Strange!

I modified how the parameters are gathered in @LessW2020’s code and it worked here with fastai. My changes are marked with <--- :


import math
from copy import deepcopy

import torch
from torch.optim.optimizer import Optimizer


class Adahessian(Optimizer):
    """Implements Adahessian algorithm.
    It has been proposed in `ADAHESSIAN: An Adaptive Second Order Optimizer for Machine Learning`.
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 0.15)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-4)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        hessian_power (float, optional): Hessian power (default: 1)
    """

    def __init__(self, params, lr=0.15, betas=(0.9, 0.999), eps=1e-4,
                 weight_decay=0, hessian_power=1):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(
                    betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(
                    betas[1]))
        if not 0.0 <= hessian_power <= 1.0:
            raise ValueError("Invalid Hessian power value: {}".format(hessian_power))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, hessian_power=hessian_power)

        super(Adahessian, self).__init__(params, defaults)

    def get_trace(self, params, gradsH):  # <--- Passed params in here
        """
        compute the Hessian vector product with a random vector v, at the current gradient point,
        i.e., compute the gradient of <gradsH,v>.
        :param gradsH: a list of torch variables
        :return: a list of torch tensors
        """

        #params = self.param_groups[0]['params']
       
        # randint_like creates each v on the same device as p, so no hard-coded 'cuda' is needed
        v = [torch.randint_like(p, high=2) for p in params]
        for v_i in v:
            v_i[v_i == 0] = -1
            
        #print(len(gradsH[0]), len(params[0]))
        #print(params) 
            
        hvs = torch.autograd.grad(
            gradsH,
            params,
            grad_outputs=v,
            only_inputs=True,
            retain_graph=True)

        hutchinson_trace = []
        for hv, vi in zip(hvs, v):
            param_size = hv.size()
            if len(param_size) <= 2:  # for 0/1/2D tensor
                tmp_output = torch.abs(hv * vi)
                hutchinson_trace.append(tmp_output) # Hessian diagonal block size is 1 here.
            elif len(param_size) == 4:  # Conv kernel
                tmp_output = torch.abs(torch.sum(torch.abs(
                    hv * vi), dim=[2, 3], keepdim=True)) / vi[0, 1].numel() # Hessian diagonal block size is 9 here: torch.sum() reduces the dim 2/3.
                hutchinson_trace.append(tmp_output)
        
        return hutchinson_trace

    def step(self, gradsH, closure=None):
        """Performs a single optimization step.
        Arguments:
            gradsH: The gradient used to compute Hessian vector product.
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        params = [ps['params'][0] for ps in self.param_groups]    # <--- grouped params here   
            
        # get the Hessian diagonal
        hut_trace = self.get_trace(params, gradsH)

        #params = [ps['params'][0] for ps in self.param_groups]
        
#         for group in self.param_groups:
#             for i, p in enumerate(group['params']):
        for i, group in enumerate(self.param_groups):   # <--- changed loop here   
            
            p = group['params'][0]
        
            if p.grad is None:
                continue

            grad = deepcopy(gradsH[i].data)
            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state['step'] = 0
                # Exponential moving average of gradient values
                state['exp_avg'] = torch.zeros_like(p.data)
                # Exponential moving average of Hessian diagonal square values
                state['exp_hessian_diag_sq'] = torch.zeros_like(p.data)

            exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']

            beta1, beta2 = group['betas']

            state['step'] += 1

            # Decay the first and second moment running average coefficient
            # keyword forms replace the deprecated positional-scalar signatures
            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            exp_hessian_diag_sq.mul_(beta2).addcmul_(
                hut_trace[i], hut_trace[i], value=1 - beta2)

            bias_correction1 = 1 - beta1 ** state['step']
            bias_correction2 = 1 - beta2 ** state['step']

            # make the square root, and the Hessian power
            k = group['hessian_power']
            denom = (
                (exp_hessian_diag_sq.sqrt() ** k) /
                math.sqrt(bias_correction2) ** k).add_(
                group['eps'])

            # make update
            p.data = p.data - \
                group['lr'] * (exp_avg / bias_correction1 / denom + group['weight_decay'] * p.data)
            
        # Zero gradsH
        for h in hut_trace:
            if h.grad is not None:
                print('h yay')
        
        for g in gradsH:
            if g.grad is not None:
                print('g yay')
                g.grad.detach_()
                g.grad.zero_()

        return loss
    
    def zero_grad(self):
        r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
        for group in self.param_groups:
            for i,p in enumerate(group['params']):
                if p.grad is not None:
                    p.grad.detach_()
                    p.grad.zero_()
#                 if self.hut_trace[i].grad is not None:
#                     print('yay')
#                     self.hut_trace[i].grad.detach_()
#                     self.hut_trace[i].grad.zero_()
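
For anyone new to the trick inside get_trace: the Hessian diagonal is estimated from products of the form v * (H v), where v is a random vector of +1/-1 entries. Below is a minimal plain-Python sketch of that idea on a small hand-written symmetric matrix standing in for the Hessian. The names `matvec`, `hutchinson_diag` and the matrix `H` are made up for illustration and are not part of the code above, which gets H v from torch.autograd.grad rather than from an explicit matrix.

```python
import random

def matvec(H, v):
    """Multiply a dense matrix (list of rows) by a vector."""
    return [sum(h_ij * v_j for h_ij, v_j in zip(row, v)) for row in H]

def hutchinson_diag(H, n_samples=2000, seed=0):
    """Estimate diag(H) as the mean of v * (H v) over random +1/-1 vectors v."""
    rng = random.Random(seed)
    n = len(H)
    est = [0.0] * n
    for _ in range(n_samples):
        v = [rng.choice((-1.0, 1.0)) for _ in range(n)]
        hv = matvec(H, v)
        for i in range(n):
            # off-diagonal terms H[i][j] * v[i] * v[j] average out to zero
            est[i] += v[i] * hv[i] / n_samples
    return est

# Toy symmetric matrix playing the role of the Hessian (illustrative values only)
H = [[4.0, 1.0, 0.5],
     [1.0, 3.0, 0.2],
     [0.5, 0.2, 2.0]]
estimate = hutchinson_diag(H)
```

Note that the optimizer above draws only a single v per step and takes the absolute value of v * (H v); the averaging over many samples in this sketch is replaced there by the exponential moving average across steps.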



Edit, I was wrong about the paper notation, removed it

Getting some nice results on my little 2/2 Transformer (still no fp16) with the code above.



Full working notebook from https://github.com/davda54/ada-hessian on CoLab
https://colab.research.google.com/drive/1IcBOCgereZbjASnlVjXPGbJqp-NYjK1L?usp=sharing

@Johnyquest the notebook is private

try this https://colab.research.google.com/drive/1IcBOCgereZbjASnlVjXPGbJqp-NYjK1L?usp=sharing


Changed the privacy setting. You should be able to access it now.

Quick experiment update

I’ve been testing AdaHessian with a 4-layer Transformer (2 encoder / 2 decoder layers) on a translation task. Translation performance improved from 0.4723 chrF (AdamW baseline) to 0.4756 chrF.

Weights and Biases experiment runs

All tests were done with @LessW2020’s implementation using OptimWrapper. I will try to get a more native fastai version working in the next day or two.

Hyperparameters

After playing around with the hyperparameters I found the following to work well with fit_one_cycle:

lr: 1e-2 (paper used 15e-2, 1e-2 was the sweet spot for me, higher or lower degraded performance)
wd: 1e-4 (same as paper, my experiments with 0. and 1e-3 degraded performance)
Betas: (0.9, 0.999) (paper used (0.9, 0.98) for their Transformer experiment, but for me setting Beta2 to 0.98 really degraded performance)

  • Note that fit_one_cycle cycles Beta1 from 0.95->0.85->0.95

Block size: 32 (same as paper, in my testing block size of 1 just matched AdamW performance)
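
To make the note about Beta1 concrete, here is a small plain-Python sketch of a two-phase cosine momentum schedule of the kind fit_one_cycle applies. It assumes the fastai defaults as I understand them (moms of 0.95, 0.85, 0.95 and a warm-up fraction pct_start of 0.25); the function names `cos_anneal` and `one_cycle_mom` are hypothetical, so please check against your installed version before relying on the exact numbers.

```python
import math

def cos_anneal(start, end, pos):
    """Cosine interpolation from start (pos=0) to end (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def one_cycle_mom(pct, moms=(0.95, 0.85, 0.95), pct_start=0.25):
    """Momentum (Beta1) at training progress pct in [0, 1].

    Assumed defaults: moms and pct_start mirror what fit_one_cycle is
    believed to use; they are not taken from the posts above.
    """
    if pct < pct_start:
        # first phase: momentum falls while the learning rate warms up
        return cos_anneal(moms[0], moms[1], pct / pct_start)
    # second phase: momentum rises back while the learning rate anneals
    return cos_anneal(moms[1], moms[2], (pct - pct_start) / (1 - pct_start))

# Momentum sampled at nine evenly spaced points of training
schedule = [round(one_cycle_mom(i / 8), 4) for i in range(9)]
```

So the Betas value of 0.9 listed above only applies to Beta2-style settings directly; Beta1 is overwritten by the schedule at every batch.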

Also this was all trained without mixed precision for now. Batch size of 16 was used as it was the largest I could do without fp16.

Thoughts

One concern I have is the speed/performance trade-off. Because it is approximately 2x slower than other optimizers, it really slows down iteration speed when you are testing a new model or technique. Right now I’m thinking I might use it only in the final stages of experimentation, when everything else is working nicely and I’m trying to squeeze the last bit of performance from the model.


Paper Presentation from the Authors!

I have an offer from the authors of the AdaHessian optimizer paper to present their work and answer any questions the community might have.

Right now I’m thinking a presentation and Q&A on Thursday (27th) at 9am Pacific Time would be a good slot to get as many folks globally as possible at once. I can also look into recording it.

Question I’d love a response to: What format/medium would people prefer? A zoom call? Or a presentation via the Discord server? Other?


Paper Presentation Invite

Come join the AdaHessian authors for an explanation of the AdaHessian paper, learn about second-order methods and see a couple of demo notebooks to experiment with the optimizer

Time: Thursday Aug 27, 2020 09:00 AM Pacific Time (US and Canada)

Zoom Link:

Meeting ID: 861 1033 1308
Passcode: 492448

One tap mobile
+35316533895,86110331308#,0#,492448# Ireland
+35316533897,86110331308#,0#,492448# Ireland

Dial by your location
+353 1 653 3895 Ireland
+353 1 653 3897 Ireland
+353 1 653 3898 Ireland
+353 6 163 9031 Ireland
+353 1 240 8941 Ireland
+353 1 536 9320 Ireland
+1 312 626 6799 US (Chicago)
+1 346 248 7799 US (Houston)
+1 646 558 8656 US (New York)
+1 669 900 9128 US (San Jose)
+1 253 215 8782 US (Tacoma)
+1 301 715 8592 US (Germantown)

Find your local number: https://us02web.zoom.us/u/kgQCAgaVs


Fastai-native AdaHessian implementation with Imagenette here. Still need to think a little more about whether the callback can be removed


Very excited about this! I had no idea the diagonal was the most important part. Found out about this late last night, so wasn’t prepared at all.

The robustness to the learning rate property looks exciting for rapid prototyping. Will be going through the tutorials soon.

Overhead of any kind is a problem, but I think an optimizer that uses second order methods opens up the possibility of further improvements.


@marii thanks for joining!!

YouTube Recording

For anyone who couldn’t join the talk with the authors the recording is available! I would love to hear if you’d like to see more of these for some of the other techniques/papers used across fastai!

Fastai-ready Code:

AdaHessian.py

I managed to implement AdaHessian natively, without the callback :smiley: To use adahessian just do: from AdaHessian import adahessian

Fastai Native AdaHessian code - ImageNette demo notebook

AdaHessian Code

Optimizer wrapper (similar to LookAhead):

Not sure how elegant the @patch of _backward is within the init, but it works.

@log_args(but='opt')
class AdaHessianWrapper(Optimizer, GetAttr):
    "Wrap `opt` in a AdaHessian optimizer"
    _default='opt'
    def __init__(self, opt, block_length=32, n_acc=1, fp16=False):
        store_attr(self, 'opt,block_length,n_acc')
        self.acc_count=0
        @patch
        def _backward(self:Learner): self.loss.backward(create_graph=True)
        
    def step(self):
        self._accumulate_grads()
        params, gradsH = self._get_params_grad()
        hvs, v = self._get_hessian(params, gradsH)
        hutchinson_trace = self._get_trace(hvs, v)
        for i, (p,pg,state,hyper) in enumerate(self.opt.all_params(with_grad=True)):
            state['hutchinson_trace'] = hutchinson_trace[i]
            for cb in self.opt.cbs: state = self._update(state, cb(p, **{**state, **hyper}))
            self.opt.state[p] = state
            
    def zero_grad(self):
        self.opt.zero_grad()
            
    def clear_state(self):
        self.opt.clear_state()

    def state_dict(self):
        return self.opt.state_dict()
    
    def load_state_dict(self, sd):
        self.opt.load_state_dict(sd)
    
    def _accumulate_grads(self):
        self.acc_count += 1
        if self.acc_count < self.n_acc: 
            raise CancelBatchException() #skip weight update
        else: self.acc_count=0
    
    def _get_params_grad(self):
        params, gradsH = [], []
        for p,*_ in self.opt.all_params(with_grad=True):
            params.append(p)
            gradsH.append(0. if p.grad is None else p.grad + 0.)
        return params, gradsH
            
    def _get_hessian(self, params, gradsH):
        device = params[0].device
        v = [torch.randint_like(p, high=2, device=device) for p in params]
        for v_i in v: v_i[v_i == 0] = -1
        hvs = torch.autograd.grad(gradsH, params, grad_outputs=v, only_inputs=True, retain_graph=False)
        return hvs, v
    
    def _get_trace(self, hvs, v):
        hutchinson_trace = []
        for hv, vi in zip(hvs, v):
            param_size = hv.size()

            if len(param_size) <= 1:  
                # For 1D tensor, e.g., bias, BatchNorm, LayerNorm etc.
                # Usually, you do not need to set spatial averaging for it, i.e., Hessian diagonal block size is 1 here.
                tmp_output = torch.abs(hv * vi)
                hutchinson_trace.append(tmp_output)

                # Of course, you can also use the same way as 2D tensor does to average your 1D tensor. 
                # tmp_output1 = torch.abs((hv * vi + 0.)).view(-1, self.block_length) # flatten to N x self.block_length
                # tmp_output2 = torch.abs(torch.sum(tmp_output1, dim=[1])).view(-1) / float(self.block_length)
                # tmp_output3 = tmp_output2.repeat_interleave(self.block_length).view(param_size)
                # hutchinson_trace.append(tmp_output3)

            elif len(param_size) == 2: 
                # For 2D tensor, e.g., the matrix in the fully-connected layer.
                # This is a normal case for MLP, Transformer models. 
                # Usually, a spatial averaging needs to be used here to get the best result.
                # If you are not looking for the absolute best config, you may set it to be 1.
                # In all of our experiments, we still get pretty good performance.
                tmp_output1 = torch.abs((hv * vi + 0.)).view(-1, self.block_length) # flatten to N x self.block_length
                tmp_output2 = torch.abs(torch.sum(tmp_output1, dim=[1])).view(-1) / float(self.block_length)
                tmp_output3 = tmp_output2.repeat_interleave(self.block_length).view(param_size)
                hutchinson_trace.append(tmp_output3)
            elif len(param_size) == 3:
                # For 3D tensor, e.g., the 1D Conv layer.
                # This layer is usually used for Char-LM.

                # First Way:
                # Usually, you can set it to be the conv kernel size: in more details, for instance, your input/output channels are 20 and your kernel size is 5, 
                # then the 1D Conv kernel is in size 20x20x5, you can average along the final dim, i.e., the block_length = 5
                tmp_output = torch.abs(torch.sum(torch.abs(
                    hv * vi), dim=[2], keepdim=True)) / vi[0, 1].numel() # torch.sum() reduces the dim 2, i.e. the size 5
                hutchinson_trace.append(tmp_output)  # store the estimate so indices stay aligned with params

                # Second way:
                # Of course, you can also use the same self.block_length to average the spatial Hessian of 3D kernel.
                # tmp_output1 = torch.abs((hv * vi + 0.)).view(-1, self.block_length) # flatten to N x self.block_length
                # tmp_output2 = torch.abs(torch.sum(tmp_output1, dim=[1])).view(-1) / float(self.block_length)
                # tmp_output3 = tmp_output2.repeat_interleave(self.block_length).view(param_size)
                # hutchinson_trace.append(tmp_output3)

            elif len(param_size) == 4:  
                # For 4D tensor, e.g, the 2D Conv layer
                # This layer is usually used for CV tasks.

                # First Way:
                # Usually, you can set it to be the conv kernel size: in more details, for instance, your input/output channels are 256 and your kernel size is 3x3, 
                # then the 2D Conv kernel is in size 256x256x3x3, you can average along the last two dims, i.e., the block_length = 9
                tmp_output = torch.abs(torch.sum(torch.abs(
                    hv * vi), dim=[2, 3], keepdim=True)) / vi[0, 1].numel() # torch.sum() reduces the dim 2/3.
                hutchinson_trace.append(tmp_output)

                # Second way:
                # Of course, you can also use the same self.block_length to average the spatial Hessian of 4D kernel.
                # tmp_output1 = torch.abs((hv * vi + 0.)).view(-1, self.block_length) # flatten to N x self.block_length
                # tmp_output2 = torch.abs(torch.sum(tmp_output1, dim=[1])).view(-1) / float(self.block_length)
                # tmp_output3 = tmp_output2.repeat_interleave(self.block_length).view(param_size)
                # hutchinson_trace.append(tmp_output3)
        return hutchinson_trace
    
    def _update(self, state, new=None):
        if new is None: return state
        if isinstance(new, dict): state.update(new)
        return state
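
The 2D branch of _get_trace is the part that does the "spatial averaging": it flattens |hv * vi| into rows of block_length, averages each row, and repeats the average back out to the original shape. A plain-Python sketch of that reshaping on a flat list, in case it helps to see it without the tensor calls. The helper name `block_average` is mine, and it assumes the number of elements is divisible by the block length, which the .view(-1, self.block_length) call requires as well.

```python
def block_average(values, block_length):
    """Replace each consecutive block of |values| with that block's mean."""
    if len(values) % block_length != 0:
        raise ValueError("number of elements must be divisible by block_length")
    out = []
    for start in range(0, len(values), block_length):
        # absolute values first, as in torch.abs(hv * vi)
        block = [abs(x) for x in values[start:start + block_length]]
        mean = sum(block) / block_length
        # repeat the mean so the output keeps the input length
        out.extend([mean] * block_length)
    return out

# Six toy values averaged in blocks of two
smoothed = block_average([1.0, -3.0, 5.0, 1.0, 0.0, -4.0], block_length=2)
```

With block_length=1 this reduces to a plain absolute value, which matches what the 1D branch does.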

State update bits:

def average_sqr_diag_hessian(p, sqr_mom, dampening=True, sqr_avg_diag_hessian=None, hutchinson_trace=None, **kwargs):
    if sqr_avg_diag_hessian is None: sqr_avg_diag_hessian = torch.zeros_like(p.grad.data)
    damp = 1-sqr_mom if dampening else 1.
    sqr_avg_diag_hessian.mul_(sqr_mom).addcmul_(hutchinson_trace, hutchinson_trace, value=damp)
    return {'sqr_avg_diag_hessian': sqr_avg_diag_hessian}

AdaHessian Step:

def adahessian_step(p, lr, mom, step, sqr_mom, grad_avg, sqr_avg_diag_hessian, hessian_power, eps, **kwargs):
    "Step for AdaHessian with `lr` on `p`"
    debias1 = debias(mom,     1-mom,     step)
    debias2 = debias(sqr_mom, 1-sqr_mom, step)
    if hessian_power < 1:
        p.data.addcdiv_(grad_avg, ((sqr_avg_diag_hessian/debias2).sqrt() ** hessian_power) + eps, value = -lr / debias1)  
    else:
        p.data.addcdiv_(grad_avg, (sqr_avg_diag_hessian/debias2).sqrt() + eps, value = -lr / debias1)    
    return p

@log_args(to_return=True, but_as=Optimizer.__init__)
def AdaHessian(params, lr=0.15, hessian_power=1., hutchinson_trace=None, mom=0.9, sqr_mom=0.999, eps=1e-4, wd=1e-4, decouple_wd=True):
    "A `Optimizer` for AdaHessian"
    cbs = [weight_decay] if decouple_wd else [l2_reg]
    cbs += [partial(average_grad, dampening=True), average_sqr_diag_hessian, step_stat, adahessian_step]
    return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, hessian_power=hessian_power, eps=eps, wd=wd)
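
Putting the state update and the step together for a single scalar parameter can help when checking the math by hand. This is a plain-Python sketch, under the assumption that fastai's debias(mom, 1-mom, step) reduces to 1 - mom**step. `adahessian_scalar_step` is a hypothetical helper written for this post, weight decay is left out, and the Hessian value is passed in directly instead of being estimated.

```python
import math

def adahessian_scalar_step(p, grad, hess_diag, state, lr=0.15, mom=0.9,
                           sqr_mom=0.999, hessian_power=1.0, eps=1e-4):
    """One AdaHessian-style update for a single scalar parameter (no weight decay)."""
    state["step"] = state.get("step", 0) + 1
    # dampened exponential moving averages of the gradient and the squared Hessian diagonal
    state["grad_avg"] = mom * state.get("grad_avg", 0.0) + (1 - mom) * grad
    state["sqr_avg"] = sqr_mom * state.get("sqr_avg", 0.0) + (1 - sqr_mom) * hess_diag ** 2
    # bias corrections (assumed equivalent to debias(mom, 1-mom, step))
    debias1 = 1 - mom ** state["step"]
    debias2 = 1 - sqr_mom ** state["step"]
    denom = math.sqrt(state["sqr_avg"] / debias2) ** hessian_power + eps
    return p - lr * state["grad_avg"] / debias1 / denom

# Toy problem: f(p) = 0.5 * h * p**2 has gradient h * p and Hessian h
h, p, state = 4.0, 1.0, {}
for _ in range(50):
    p = adahessian_scalar_step(p, grad=h * p, hess_diag=h, state=state)
```

On this toy quadratic the first step is roughly lr times a Newton step, so p should move from 1.0 toward the minimum at 0 over the loop.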

And finally pass the optimizer to its wrapper

@delegates(AdaHessian)
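def adahessian(p, lr=0.15, n_acc=1, block_length=32, hessian_power=1., mom=0.9, sqr_mom=0.999, eps=1e-4, wd=1e-4, **kwargs):
    "Convenience method for `AdaHessianWrapper` with `AdaHessian`"
    # forward every hyperparameter; otherwise AdaHessian silently falls back to its own defaults
    opt = AdaHessian(p, lr=lr, hessian_power=hessian_power, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, **kwargs)
    return AdaHessianWrapper(opt, n_acc=n_acc, block_length=block_length)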

TODO

  • Get MixedPrecision working (important!)
  • Speed/performance opportunities
  • Refactor opportunities

@morgan Any reason mixed precision isn’t working? Have you tried native mixed precision?

Yep I tried both, I think it’s because they both override ‘_backward’

I got it to train at one point but the loss was going to nan so something was off…

Looking at the slide at around 35:10: a ~7% change in accuracy across a 10x scaling in learning rate from 50% of optimal to 500%… okay this is in NLP, but if I understand this right and it’s applicable to vision etc., combine it w/ Smith & Conovaloff’s semi-supervised learning research, and automated finetuning for high-performance models seems very nearby.

More so, fast automated finetuning, since you can afford to use higher learning rates (on top of the already high LRs enabled by Smith’s 1-Cycle Policy).

There are already automated solutions out there, but it sounds like this makes it significantly easier by getting the computer to do smarter work, so it’s not just for enterprise tools. That has some very real implications for robotics, and likely many other fields.

I look forward to testing this out.
