Hi everyone. Would someone tell me whether this is possible with a fastai2 Learner, and point me towards the right approach? I imagine it’s done with callbacks. I’m willing to study.
The issue is processing audio samples of widely varying lengths, spanning a 50:1 length ratio. If I can group these into minibatches of similar length, then PyTorch can roughly approximate the calculations using rectangular tensors. That would also be faster, but the actual problem is how to apply PyTorch pooling, which is intended for fixed lengths, to data of varying lengths.
My idea is to sort the samples by length, send minibatches of similar length to the model while accumulating gradients, and do the weight update only after all minibatches have been processed.
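The grouping step in that idea can be sketched in plain Python. This is only an illustration under assumptions: `length_buckets` and the sample lengths are hypothetical, and the batch order is shuffled so that training does not always run from the shortest clips to the longest.

```python
import random

def length_buckets(lengths, batch_size, shuffle=True, seed=0):
    """Group sample indices into minibatches of similar length.

    `lengths[i]` is the length of sample i. Returns a list of index lists.
    """
    # Sort indices by sample length, then cut the sorted order into chunks.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    if shuffle:
        # Shuffle whole batches only, so each batch keeps similar lengths.
        random.Random(seed).shuffle(batches)
    return batches

# Hypothetical audio lengths (in samples) covering roughly a 50:1 range.
lengths = [1000, 50000, 1200, 48000, 900, 25000, 26000, 1100]
batches = length_buckets(lengths, batch_size=2)
for b in batches:
    print(b, [lengths[i] for i in b])
```

Index lists like these could presumably be handed to a PyTorch `DataLoader` through its `batch_sampler` argument, with a collate function that pads or crops within each batch.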
P.S. It’s not an RNN, so packing and padding seem not to apply.
There’s a GradientAccumulation callback available (not 100% sure on the spelling)
Hi Zachary @muellerzr. I have managed to trace my version of GradientAccumulation in PyCharm, and I have read the Callbacks documentation. But I’d like some advice on how to proceed.
GradientAccumulation skips the weight update and zero grad steps by throwing CancelBatchException. Every time a fixed number of minibatches have been processed, it allows a single minibatch training loop to finish.
What I want to do is skip weight update for all minibatches in an epoch, and at the end do one weight update for the whole epoch. Can you advise on the best way to handle this?
Is there a way to detect that the callback is being called during the last minibatch of an epoch? In that case it would not throw CancelBatchException, and the training loop would continue normally. (This method would be my preference.)
Or should it always skip the weight updates and at the end of the epoch manually call
As an aside, it seems that GradientAccumulation will throw away the gradient accumulation of the last minibatches in a set of epochs. I don’t know if this will be a problem in actual use.
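To make the "update only on the last minibatch" option concrete, here is a framework-free sketch of the control flow. It is not fastai code: the one-parameter model `y = w * x`, the data, and the function names are all made up for illustration. Inside a real callback, the last-batch check would presumably compare the current batch index with the number of batches in the epoch.

```python
def batch_grad(w, batch):
    """Gradient of the summed squared error sum((w*x - y)**2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch)

def train_one_epoch(w, batches, lr):
    """Accumulate gradients over every minibatch, then update once."""
    n_samples = sum(len(b) for b in batches)
    grad = 0.0  # accumulated gradient, zeroed only after the update
    for i, batch in enumerate(batches):
        grad += batch_grad(w, batch)  # "backward" only
        is_last = (i == len(batches) - 1)
        if not is_last:
            continue  # plays the role of cancelling the step and zero_grad
        # Divide by the sample count so unequal batch sizes are weighted fairly.
        w -= lr * grad / n_samples
        grad = 0.0
    return w

# Toy data split into minibatches of unequal size.
data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2), (5.0, 9.8)]
batches = [data[0:2], data[2:4], data[4:5]]

w_accum = train_one_epoch(0.0, batches, lr=0.01)
# Reference: one gradient step computed on the whole dataset at once.
w_full = 0.0 - 0.01 * batch_grad(0.0, data) / len(data)
print(w_accum, w_full)
```

One design point worth noting: if each minibatch loss is a mean rather than a sum, batches of different sizes contribute unequally to the accumulated gradient, so summing the loss and dividing once by the total sample count is the safer choice here.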
Hi @Pomo, you just need to set the number of samples needed before GradientAccumulation is allowed to update weights.
In your case, if you just set it to the length of your dataset, it should directly work.
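A small simulation shows why that setting gives one update per epoch. This is a guess at the counting rule, not the library's source: it assumes the callback adds each batch's size to a running count, lets the update through once the count reaches the threshold, and then resets the count. `update_points` and `n_acc` are names used here for illustration.

```python
def update_points(batch_sizes, n_acc, n_epochs=1):
    """Simulate a sample-counting accumulation rule (assumed, see above).

    Returns (updates, leftover): `updates` lists the (epoch, batch_index)
    pairs where a weight update would run, and `leftover` is the number of
    samples still accumulated but never applied when training stops.
    """
    count, updates = 0, []
    for epoch in range(n_epochs):
        for i, bs in enumerate(batch_sizes):
            count += bs
            if count >= n_acc:
                updates.append((epoch, i))  # step and zero_grad happen here
                count = 0
    return updates, count

# 15 samples seen per epoch, in minibatches of 4, 4, 4 and 3.
batch_sizes = [4, 4, 4, 3]

# Threshold equal to the dataset length: one update, on the last minibatch.
print(update_points(batch_sizes, n_acc=15))

# Smaller threshold: the final 7 samples are accumulated but never applied.
print(update_points(batch_sizes, n_acc=8))
```

Under this rule the threshold has to match the number of samples the loader actually yields per epoch. If the loader drops an incomplete final batch, the count will not reach the full dataset length within one epoch, and the update slips into the next one.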
The trouble is, that’s just too simple.