AdaHessian

Quick experiment update

I’ve been testing AdaHessian with a 4-layer Transformer (2 encoder and 2 decoder layers) on a translation task. Translation performance improved from 0.4723 chrF (AdamW baseline) to 0.4756 chrF.

Weights and Biases experiment runs

All tests were done with @LessW202’s implementation using OptimWrapper. I will try to get a more native fastai version working in the next day or two.

Hyperparameters

After playing around with the hyperparameters I found the following to work well with fit_one_cycle:

lr: 1e-2 (paper used 15e-2, 1e-2 was the sweet spot for me, higher or lower degraded performance)
wd: 1e-4 (same as paper, my experiments with 0. and 1e-3 degraded performance)
Betas: (0.9, 0.999) (paper used (0.9, 0.98) for their Transformer experiment, but for me setting Beta2 to 0.98 really degraded performance)

  • Note that fit_one_cycle cycles Beta1 from 0.95->0.85->0.95

Block size: 32 (same as paper, in my testing block size of 1 just matched AdamW performance)
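
For anyone wondering what the block size controls: judging by the 2D-tensor branch of the implementation posted later in this thread, the absolute per-element Hessian diagonal estimates are averaged over consecutive chunks of `block_length` elements, and every element in a chunk then gets that chunk's average. A minimal pure-Python sketch of that idea, where `block_average` is a hypothetical helper written for illustration and not part of fastai or the paper's code:

```python
def block_average(diag_estimates, block_length):
    """Average absolute values over consecutive blocks, then broadcast back."""
    if len(diag_estimates) % block_length != 0:
        # The tensor version reshapes to (-1, block_length), which has the same requirement.
        raise ValueError("number of elements must be divisible by block_length")
    out = []
    for start in range(0, len(diag_estimates), block_length):
        block = diag_estimates[start:start + block_length]
        mean_abs = sum(abs(x) for x in block) / block_length
        out.extend([mean_abs] * block_length)  # every element in the block shares one value
    return out

raw = [1.0, -3.0, 2.0, 6.0, -4.0, 0.0]
averaged = block_average(raw, block_length=2)   # neighbours are smoothed together
unchanged = block_average(raw, block_length=1)  # block size 1: just the absolute values
```

With a block size of 1 nothing is smoothed, so each parameter keeps its own noisy estimate, which may be part of why that setting only matched AdamW in my runs.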

Also this was all trained without mixed precision for now. Batch size of 16 was used as it was the largest I could do without fp16.
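
The Beta1 cycling mentioned in the hyperparameter notes above can be sketched as follows. This assumes a two-phase cosine schedule with the first 25% of training used for the 0.95 -> 0.85 phase, which is my understanding of the fit_one_cycle defaults. The function names are made up for illustration and are not fastai API:

```python
import math

def cos_anneal(start, end, pos):
    """Cosine interpolation from `start` (pos=0) to `end` (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def one_cycle_mom(pct, moms=(0.95, 0.85, 0.95), pct_start=0.25):
    """Beta1 at training progress `pct` in [0, 1] (pct_start=0.25 is an assumption)."""
    if pct < pct_start:
        return cos_anneal(moms[0], moms[1], pct / pct_start)
    return cos_anneal(moms[1], moms[2], (pct - pct_start) / (1 - pct_start))

# Beta1 sampled at every 10% of training
schedule = [round(one_cycle_mom(p / 10), 4) for p in range(11)]
```

So the Beta1 of 0.9 listed above is only the nominal value. During a fit_one_cycle run the effective Beta1 moves between 0.85 and 0.95.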

Thoughts

One concern I have is the speed/performance trade-off. Because it is approx 2x slower than other optimizers it really slows down your iteration speed when you are testing a new model/technique. Right now I’m thinking I might use it just in the final stages of experimentation, when everything else is working nicely and you’re trying to squeeze the last little bit of performance from the model.

3 Likes

Paper Presentation from the Authors!

I have an offer from the authors of the AdaHessian optimizer paper to present their work and answer any questions the community might have.

Right now I’m thinking a presentation and Q&A on Thursday (27th) at 9am Pacific Time would be a good time to get as many folks from around the world together at once. I can also look into recording it.

Question I’d love a response to: What format/medium would people prefer? A zoom call? Or a presentation via the Discord server? Other?

3 Likes

Paper Presentation Invite

Come join the AdaHessian authors for an explanation of the AdaHessian paper, learn about second-order methods and see a couple of demo notebooks to experiment with the optimizer

Time: Thursday Aug 27, 2020 09:00 AM Pacific Time (US and Canada)

Zoom Link:

Meeting ID: 861 1033 1308
Passcode: 492448

One tap mobile
+35316533895,86110331308#,0#,492448# Ireland
+35316533897,86110331308#,0#,492448# Ireland

Dial by your location
+353 1 653 3895 Ireland
+353 1 653 3897 Ireland
+353 1 653 3898 Ireland
+353 6 163 9031 Ireland
+353 1 240 8941 Ireland
+353 1 536 9320 Ireland
+1 312 626 6799 US (Chicago)
+1 346 248 7799 US (Houston)
+1 646 558 8656 US (New York)
+1 669 900 9128 US (San Jose)
+1 253 215 8782 US (Tacoma)
+1 301 715 8592 US (Germantown)

Find your local number: https://us02web.zoom.us/u/kgQCAgaVs

11 Likes

Fastai-native AdaHessian implementation with Imagenette here. I still need to think a little more about whether the callback can be removed.

5 Likes

Very excited about this! I had no idea the diagonal was the most important part. Found out about this late last night, so wasn’t prepared at all.

The robustness to the learning rate property looks exciting for rapid prototyping. Will be going through the tutorials soon.

Overhead of any kind is a problem, but I think an optimizer that uses second order methods opens up the possibility of further improvements.

1 Like

@marii thanks for joining!!

YouTube Recording

For anyone who couldn’t join the talk with the authors the recording is available! I would love to hear if you’d like to see more of these for some of the other techniques/papers used across fastai!

Fastai-ready Code:

AdaHessian.py

I managed to implement AdaHessian natively, without the callback. To use it, just do: from AdaHessian import adahessian

Fastai Native AdaHessian code - ImageNette demo notebook

AdaHessian Code

Optimizer wrapper (similar to LookAhead):

Not sure how elegant the @patch of _backward is within the __init__, but it works.

@log_args(but='opt')
class AdaHessianWrapper(Optimizer, GetAttr):
    "Wrap `opt` in a AdaHessian optimizer"
    _default='opt'
    def __init__(self, opt, block_length=32, n_acc=1, fp16=False):
        store_attr(self, 'opt,block_length,n_acc')
        self.acc_count=0
        @patch
        def _backward(self:Learner): self.loss.backward(create_graph=True)
        
    def step(self):
        self._accumulate_grads()
        params, gradsH = self._get_params_grad()
        hvs, v = self._get_hessian(params, gradsH)
        hutchinson_trace = self._get_trace(hvs, v)
        for i, (p,pg,state,hyper) in enumerate(self.opt.all_params(with_grad=True)):
            state['hutchinson_trace'] = hutchinson_trace[i]
            for cb in self.opt.cbs: state = self._update(state, cb(p, **{**state, **hyper}))
            self.opt.state[p] = state
            
    def zero_grad(self):
        self.opt.zero_grad()
            
    def clear_state(self):
        self.opt.clear_state()

    def state_dict(self):
        return self.opt.state_dict()
    
    def load_state_dict(self, sd):
        self.opt.load_state_dict(sd)
    
    def _accumulate_grads(self):
        self.acc_count += 1
        if self.acc_count < self.n_acc: 
            raise CancelBatchException() #skip weight update
        else: self.acc_count=0
    
    def _get_params_grad(self):
        params, gradsH = [], []
        for p,*_ in self.opt.all_params(with_grad=True):
            params.append(p)
            gradsH.append(0. if p.grad is None else p.grad + 0.)
        return params, gradsH
            
    def _get_hessian(self, params, gradsH):
        device = params[0].device
        v = [torch.randint_like(p, high=2, device=device) for p in params]
        for v_i in v: v_i[v_i == 0] = -1
        hvs = torch.autograd.grad(gradsH, params, grad_outputs=v, only_inputs=True, retain_graph=False)
        return hvs, v
    
    def _get_trace(self, hvs, v):
        hutchinson_trace = []
        for hv, vi in zip(hvs, v):
            param_size = hv.size()

            if len(param_size) <= 1:  
                # For 1D tensors, e.g., bias, BatchNorm, LayerNorm etc.
                # Usually, you do not need spatial averaging here, i.e., the Hessian diagonal block size is 1.
                tmp_output = torch.abs(hv * vi)
                hutchinson_trace.append(tmp_output)

                # Of course, you can also average a 1D tensor the same way as a 2D tensor.
                # tmp_output1 = torch.abs((hv * vi + 0.)).view(-1, self.block_length) # flatten to N x self.block_length
                # tmp_output2 = torch.abs(torch.sum(tmp_output1, dim=[1])).view(-1) / float(self.block_length)
                # tmp_output3 = tmp_output2.repeat_interleave(self.block_length).view(param_size)
                # hutchinson_trace.append(tmp_output3)

            elif len(param_size) == 2: 
                # For 2D tensors, e.g., the weight matrix of a fully-connected layer.
                # This is the normal case for MLP and Transformer models.
                # Usually, spatial averaging is needed here to get the best result.
                # If you are not looking for the absolute best config, you may set the block size to 1.
                # In all of our experiments, we still get pretty good performance.
                tmp_output1 = torch.abs((hv * vi + 0.)).view(-1, self.block_length) # flatten to N x self.block_length
                tmp_output2 = torch.abs(torch.sum(tmp_output1, dim=[1])).view(-1) / float(self.block_length)
                tmp_output3 = tmp_output2.repeat_interleave(self.block_length).view(param_size)
                hutchinson_trace.append(tmp_output3)
            elif len(param_size) == 3:
                # For 3D tensor, e.g., the 1D Conv layer.
                # This layer is usually used for Char-LM.

                # First Way:
                # Usually, you can set it to the conv kernel size. For instance, if your input/output channels are 20 and your kernel size is 5,
                # then the 1D Conv kernel has size 20x20x5 and you can average along the final dim, i.e., block_length = 5.
                tmp_output = torch.abs(torch.sum(torch.abs(
                    hv * vi), dim=[2], keepdim=True)) / vi[0, 1].numel() # torch.sum() reduces dim 2, i.e. the size 5
                hutchinson_trace.append(tmp_output) # needed so the trace list stays aligned with the params

                # Second way:
                # Of course, you can also use the same self.block_length to average the spatial Hessian of a 3D kernel.
                # tmp_output1 = torch.abs((hv * vi + 0.)).view(-1, self.block_length) # flatten to N x self.block_length
                # tmp_output2 = torch.abs(torch.sum(tmp_output1, dim=[1])).view(-1) / float(self.block_length)
                # tmp_output3 = tmp_output2.repeat_interleave(self.block_length).view(param_size)
                # hutchinson_trace.append(tmp_output3)

            elif len(param_size) == 4:  
                # For 4D tensors, e.g., the 2D Conv layer.
                # This layer is usually used for CV tasks.

                # First Way:
                # Usually, you can set it to the conv kernel size. For instance, if your input/output channels are 256 and your kernel size is 3x3,
                # then the 2D Conv kernel has size 256x256x3x3 and you can average along the last two dims, i.e., block_length = 9.
                tmp_output = torch.abs(torch.sum(torch.abs(
                    hv * vi), dim=[2, 3], keepdim=True)) / vi[0, 1].numel() # torch.sum() reduces dims 2 and 3.
                hutchinson_trace.append(tmp_output)

                # Second way:
                # Of course, you can also use the same self.block_length to average the spatial Hessian of a 4D kernel.
                # tmp_output1 = torch.abs((hv * vi + 0.)).view(-1, self.block_length) # flatten to N x self.block_length
                # tmp_output2 = torch.abs(torch.sum(tmp_output1, dim=[1])).view(-1) / float(self.block_length)
                # tmp_output3 = tmp_output2.repeat_interleave(self.block_length).view(param_size)
                # hutchinson_trace.append(tmp_output3)
        return hutchinson_trace
    
    def _update(self, state, new=None):
        if new is None: return state
        if isinstance(new, dict): state.update(new)
        return state
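
To make the `_get_hessian` / `_get_trace` pair a bit less magical, here is a rough pure-Python sketch of the Hutchinson idea they rely on: for random +/-1 (Rademacher) vectors v, the elementwise product v * (Hv) equals the Hessian diagonal on average. The toy matrix `H` and the helper names are made up for illustration. The wrapper above never builds H explicitly; it gets Hv from `torch.autograd.grad` and uses a single probe per step.

```python
from itertools import product

def matvec(H, v):
    """Plain matrix-vector product for a list-of-lists matrix."""
    return [sum(h_ij * v_j for h_ij, v_j in zip(row, v)) for row in H]

def hutchinson_diag(H, probes):
    """Average v * (H v) elementwise over the given +/-1 probe vectors."""
    n = len(H)
    acc = [0.0] * n
    for v in probes:
        Hv = matvec(H, v)
        for i in range(n):
            acc[i] += v[i] * Hv[i]  # H[i][i] plus sign-weighted off-diagonal terms
    return [a / len(probes) for a in acc]

# Toy symmetric "Hessian" with diagonal (4, 3, 5)
H = [[4.0, 1.0, -2.0],
     [1.0, 3.0, 0.5],
     [-2.0, 0.5, 5.0]]

# Averaging over every possible sign vector cancels the off-diagonal terms exactly
all_probes = list(product([-1.0, 1.0], repeat=3))
exact = hutchinson_diag(H, all_probes)

# A single probe, as used per optimizer step, is only a noisy estimate
single = hutchinson_diag(H, [(1.0, -1.0, 1.0)])
```

The single-probe estimate is off by the off-diagonal terms, which is the noise that the block averaging and the moving average over steps are there to smooth out.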

State update bits:

def average_sqr_diag_hessian(p, sqr_mom, dampening=True, sqr_avg_diag_hessian=None, hutchinson_trace=None, **kwargs):
    if sqr_avg_diag_hessian is None: sqr_avg_diag_hessian = torch.zeros_like(p.grad.data)
    damp = 1-sqr_mom if dampening else 1.
    sqr_avg_diag_hessian.mul_(sqr_mom).addcmul_(hutchinson_trace, hutchinson_trace, value=damp)
    return {'sqr_avg_diag_hessian': sqr_avg_diag_hessian}

AdaHessian Step:

def adahessian_step(p, lr, mom, step, sqr_mom, grad_avg, sqr_avg_diag_hessian, hessian_power, eps, **kwargs):
    "Step for AdaHessian with `lr` on `p`"
    debias1 = debias(mom,     1-mom,     step)
    debias2 = debias(sqr_mom, 1-sqr_mom, step)
    if hessian_power < 1:
        p.data.addcdiv_(grad_avg, ((sqr_avg_diag_hessian/debias2).sqrt() ** hessian_power) + eps, value = -lr / debias1)  
    else:
        p.data.addcdiv_(grad_avg, (sqr_avg_diag_hessian/debias2).sqrt() + eps, value = -lr / debias1)    
    return p

@log_args(to_return=True, but_as=Optimizer.__init__)
def AdaHessian(params, lr=0.15, hessian_power=1., hutchinson_trace=None, mom=0.9, sqr_mom=0.999, eps=1e-4, wd=1e-4, decouple_wd=True):
    "A `Optimizer` for AdaHessian"
    cbs = [weight_decay] if decouple_wd else [l2_reg]
    cbs += [partial(average_grad, dampening=True), average_sqr_diag_hessian, step_stat, adahessian_step]
    return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, hessian_power=hessian_power, eps=eps, wd=wd)
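
For intuition, here is how I read the combined effect of `average_grad`, `average_sqr_diag_hessian`, `step_stat` and `adahessian_step` on a single scalar parameter. This is a simplified sketch under my own assumptions: weight decay is left out, the Hessian diagonal is passed in directly rather than estimated, and `adahessian_scalar_step` is a hypothetical name.

```python
import math

def adahessian_scalar_step(p, grad, hess_diag, state, lr=0.15, mom=0.9,
                           sqr_mom=0.999, hessian_power=1.0, eps=1e-4):
    """One AdaHessian-style update for a scalar parameter (no weight decay)."""
    state["step"] += 1
    # Dampened moving averages of the gradient and of the squared Hessian diagonal
    state["grad_avg"] = mom * state["grad_avg"] + (1 - mom) * grad
    state["sqr_avg"] = sqr_mom * state["sqr_avg"] + (1 - sqr_mom) * hess_diag ** 2
    # Bias correction for the zero-initialised averages
    debias1 = 1 - mom ** state["step"]
    debias2 = 1 - sqr_mom ** state["step"]
    denom = math.sqrt(state["sqr_avg"] / debias2) ** hessian_power + eps
    return p - lr * (state["grad_avg"] / debias1) / denom

def first_step_on_quadratic(a, p0=3.0):
    """First update on f(p) = 0.5 * a * p**2, where grad = a * p and Hessian = a."""
    state = {"step": 0, "grad_avg": 0.0, "sqr_avg": 0.0}
    return adahessian_scalar_step(p0, grad=a * p0, hess_diag=a, state=state, lr=1.0)

p_low_curvature = first_step_on_quadratic(a=2.0)
p_high_curvature = first_step_on_quadratic(a=200.0)
```

On this toy quadratic, with lr=1 the first step lands close to the minimum at 0 for both curvatures, because the gradient is divided by the curvature. This is only an illustration of the update rule and says nothing about behaviour on real, noisy losses.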

And finally pass the optimizer to its wrapper

@delegates(AdaHessian)
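def adahessian(p, lr=0.15, n_acc=1, block_length=32, hessian_power=1., mom=0.9, sqr_mom=0.999, eps=1e-4, wd=1e-4, **kwargs):
    "Convenience method for `AdaHessianWrapper` with `AdaHessian`"
    # Forward the named hyperparameters explicitly, otherwise they are captured here and silently dropped
    opt = AdaHessian(p, lr=lr, hessian_power=hessian_power, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, **kwargs)
    return AdaHessianWrapper(opt, n_acc=n_acc, block_length=block_length)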

TODO

  • Get MixedPrecision working (important!)
  • Speed/performance opportunities
  • Refactor opportunities
10 Likes

@morgan Any reason mixed precision isn’t working? Have you tried native mixed precision?

Yep I tried both, I think it’s because they both override ‘_backward’

I got it to train at one point but the loss was going to nan so something was off…

Looking at the slide at around 35:10: a ~7% change in accuracy across a 10x scaling in learning rate from 50% of optimal to 500%… okay this is in NLP, but if I understand this right and it’s applicable to vision etc., combine it w/ Smith & Conovaloff’s semi-supervised learning research, and automated finetuning for high-performance models seems very nearby.

More so, fast automated finetuning, since you can afford to use higher learning rates (on top of already high LRs enabled by Smith’s 1-Cycle Policy).

There’re already automated solutions out there… but it sounds like this makes it significantly easier by finding out how to get the computer to do smarter work – so it’s not just enterprise tools. That has some very real implications for robotics, and likely many other fields.

I look forward to testing this out.

1 Like

Thanks a lot for recording this! I was so disappointed I couldn’t make it

3 Likes

Thank you Morgan. The video was very informative. It also gave me a nice peep into Randomized Linear Algebra.

3 Likes

Quick update

Still working on getting native Mixed Precision working, was stuck for a while, but I think I know how to do it now.

I’ve posted a question on the pytorch forums related to GradScaler.unscale_, but regardless of the answer I should be able to make a workaround happen. Will share here when FP16 is working!

3 Likes