Model Pruning in fast.ai

Hello, I’m not sure if this is the right forum for this post, but is there any ongoing work on adding model pruning support to fastai?

I found it really easy to implement in PyTorch. I was able to train an MNIST model and prune it to 90% sparsity with very little degradation in accuracy, which was really surprising. If you would like to see how easy it is to prune a model in PyTorch, please refer to my Colab notebook here.
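For anyone who can’t open the notebook, the core of fine-grained (unstructured) magnitude pruning is small: zero out the weights with the smallest absolute values until a target sparsity is reached. The sketch below is framework-free and uses plain Python lists for illustration only; it is not the code from the notebook, and the names `magnitude_mask` and `apply_mask` are hypothetical. In PyTorch, one common approach is to keep a 0/1 mask tensor per layer and multiply it into the weights after each optimizer step.

```python
def magnitude_mask(weights, sparsity):
    """Return a 0/1 mask that zeroes the `sparsity` fraction of weights
    with the smallest absolute value (ties keep their original order)."""
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be in [0, 1]")
    n_prune = int(round(sparsity * len(weights)))
    # Indices sorted by |w|, smallest first; the first n_prune get pruned.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    mask = [1] * len(weights)
    for i in order[:n_prune]:
        mask[i] = 0
    return mask


def apply_mask(weights, mask):
    """Keep weights where mask == 1; pruned positions become exactly 0.0."""
    return [w if m else 0.0 for w, m in zip(weights, mask)]


if __name__ == "__main__":
    w = [0.5, -0.01, 0.2, -0.8, 0.03, 0.0, 1.2, -0.05, 0.4, 0.07]
    mask = magnitude_mask(w, 0.9)  # 90% sparsity: 9 of 10 weights zeroed
    print(apply_mask(w, mask))
```

In a real training loop the mask is recomputed only at the pruning steps, and the pruned weights are kept at zero in between.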

I followed the algorithm from the popular “To Prune or Not to Prune” paper by the Stanford folks, which is incidentally also what TensorFlow uses in its Model Optimization Toolkit.
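For reference, the gradual pruning schedule in that paper (its equation (1)) raises the sparsity from an initial value $s_i$ to a final value $s_f$ over $n$ pruning steps spaced $\Delta t$ training steps apart, starting at step $t_0$:

$$
s_t = s_f + \left(s_i - s_f\right)\left(1 - \frac{t - t_0}{n\,\Delta t}\right)^{3},
\qquad t \in \{t_0,\ t_0 + \Delta t,\ \dots,\ t_0 + n\,\Delta t\}
$$

At $t = t_0$ this gives $s_t = s_i$, and at $t = t_0 + n\,\Delta t$ it gives $s_t = s_f$.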

Note that pruning as done here with this algorithm (fine-grained pruning) won’t give performance benefits right now, because PyTorch lacks sparse kernel support, but I imagine it will become popular very soon given its large potential for savings in model size and inference speed.
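To make the size argument concrete: a 90%-sparse weight matrix only gets smaller if it is stored in a sparse format. The sketch below converts a dense matrix (nested lists) to compressed sparse row (CSR) form and counts how many numbers each format stores. CSR is just one common format, chosen here for illustration, and the helper names `to_csr` and `storage_counts` are hypothetical.

```python
def to_csr(matrix):
    """Convert a dense row-major matrix (list of lists) to CSR arrays."""
    values, col_idx, row_ptr = [], [], [0]
    for row in matrix:
        for j, x in enumerate(row):
            if x != 0:
                values.append(x)
                col_idx.append(j)
        # row_ptr[r + 1] marks where row r's entries end in `values`.
        row_ptr.append(len(values))
    return values, col_idx, row_ptr


def storage_counts(matrix):
    """Return (dense_entries, csr_entries): how many numbers each format stores."""
    rows, cols = len(matrix), len(matrix[0])
    values, col_idx, row_ptr = to_csr(matrix)
    return rows * cols, len(values) + len(col_idx) + len(row_ptr)


if __name__ == "__main__":
    # 10 x 10 matrix with one nonzero per row, i.e. 90% sparsity.
    m = [[1.0 if j == i else 0.0 for j in range(10)] for i in range(10)]
    dense, csr = storage_counts(m)
    print(f"dense: {dense} numbers, CSR: {csr} numbers")
```

The saving only appears at high sparsity, since every nonzero also costs an index. Turning it into a speedup additionally requires kernels that compute directly on the sparse format.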

I would be happy to discuss taking this further and adding support for it in fastai if the community thinks it would be useful. I have some ideas about combining pruning with quantisation to get some really interesting speedups.
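One simple way the two could be combined, sketched here as an assumption and not as the specific idea mentioned above: after pruning, store only the surviving weights, quantised to 8-bit integers with a single symmetric per-tensor scale. The function names `quantize_int8` and `dequantize` are hypothetical.

```python
def quantize_int8(weights):
    """Symmetric per-tensor int8 quantisation: q = round(w / scale)."""
    max_abs = max((abs(w) for w in weights), default=0.0)
    # Map the largest magnitude to 127; guard against an all-zero tensor.
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q, scale):
    """Approximate reconstruction of the original float weights."""
    return [x * scale for x in q]


if __name__ == "__main__":
    pruned = [0.0, 0.0, 0.6, 0.0, -1.27, 0.0, 0.0, 0.05, 0.0, 0.0]
    # Keep only the nonzero entries; their positions would be stored
    # separately (for example as CSR column indices).
    nonzero = [w for w in pruned if w != 0.0]
    q, scale = quantize_int8(nonzero)
    print(q, scale)
    print(dequantize(q, scale))
```

Each surviving weight then costs one byte plus its index, instead of four bytes for every dense float32 entry.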

Hi!

The only discussion about pruning I can think of is this one. I don’t see any reason why pruning couldn’t be added to fastai if it is operational.

I am also interested in model pruning/compression. I am currently running experiments in which I combine the technique you describe here (which I prefer to call sparsifying) with kernel-level pruning to get benefits in terms of parameter count and speed. So far this works well; I hope we’ll see more developments in that direction.
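For contrast with element-wise sparsifying, a common form of kernel-level (structured) pruning ranks whole convolution filters by the L1 norm of their weights and removes the weakest ones, so the layer actually shrinks. The sketch below assumes that L1-norm criterion, which is not necessarily what is used in the experiments described here. Weights are nested lists shaped `[out_channels][in_channels][kh][kw]`, and the function name is hypothetical.

```python
def prune_filters_l1(conv_weight, keep_ratio):
    """Keep the output filters with the largest L1 norm.

    conv_weight: nested lists shaped [out_channels][in_channels][kh][kw].
    Returns (kept_indices, pruned_weight), filters in their original order.
    """
    def l1(filt):
        return sum(abs(x) for ch in filt for row in ch for x in row)

    # Always keep at least one filter so the layer stays valid.
    n_keep = max(1, int(round(keep_ratio * len(conv_weight))))
    ranked = sorted(range(len(conv_weight)),
                    key=lambda i: l1(conv_weight[i]), reverse=True)
    kept = sorted(ranked[:n_keep])
    return kept, [conv_weight[i] for i in kept]


if __name__ == "__main__":
    filters = [
        [[[1.0, -1.0], [1.0, 1.0]]],   # L1 = 4
        [[[0.1, 0.0], [0.0, 0.1]]],    # L1 = 0.2
        [[[2.0, 1.0], [0.0, -3.0]]],   # L1 = 6
        [[[0.5, 0.5], [0.0, 0.0]]],    # L1 = 1
    ]
    kept, _ = prune_filters_l1(filters, 0.5)
    print("kept filters:", kept)
```

In a real network, the next layer’s input channels at the same kept indices must be sliced to match, which is what turns the removed filters into actual parameter and compute savings.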

Also, I think there is a small mistake in the code when you define n:

    n = (current_step - starting_step)/span

Shouldn’t it be:

    n = (end_step - starting_step)/span

Because if I print the sparsity of your model with your code, it is already at 89% after 1 epoch. Still, this shows that even with drastic pruning like this, it works really well 🙂
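The two definitions of `n` can be compared directly. In the sketch below, `span` is assumed to be the number of steps between pruning updates (the paper’s Δt); that is an assumption about the notebook, whose full code isn’t shown here, and the function names and step values are illustrative. With `n = (current_step - starting_step)/span`, the ratio `(t - t0)/(n*span)` is always 1, so the target jumps straight to the final sparsity, which is consistent with the ~89% observed after one epoch.

```python
def target_sparsity(current_step, starting_step, end_step, span,
                    initial_sparsity=0.0, final_sparsity=0.9):
    """Cubic schedule from "To Prune or Not to Prune", equation (1).

    `span` is the number of steps between pruning updates (Δt), so
    n = (end_step - starting_step) / span is the number of updates.
    """
    if current_step <= starting_step:
        return initial_sparsity
    if current_step >= end_step:
        return final_sparsity
    n = (end_step - starting_step) / span           # corrected definition
    progress = (current_step - starting_step) / (n * span)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1 - progress) ** 3


def target_sparsity_buggy(current_step, starting_step, span,
                          initial_sparsity=0.0, final_sparsity=0.9):
    """Same formula with n = (current_step - starting_step) / span."""
    if current_step <= starting_step:
        return initial_sparsity
    n = (current_step - starting_step) / span       # original definition
    progress = (current_step - starting_step) / (n * span)  # always ~1.0
    return final_sparsity + (initial_sparsity - final_sparsity) * (1 - progress) ** 3


if __name__ == "__main__":
    for step in (100, 500, 1000):
        print(step,
              round(target_sparsity(step, 0, 1000, 100), 4),
              round(target_sparsity_buggy(step, 0, 100), 4))
```

With the corrected `n`, the target rises quickly early on (about 24% after the first 10% of the schedule) and flattens out toward 90%, while the buggy version returns 90% at every step after the start.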

Hello! Glad to meet another pruning/compression enthusiast; I was personally quite surprised to see that it worked so well! About the formula: I took it from the “To Prune or Not to Prune” paper, and I think it is meant to be written the way it is. You might notice that the sparsity goes up really quickly in the first epoch, but fewer and fewer weights are pruned in later epochs.

The authors say this about the formula: “The intuition behind this sparsity function in equation (1) is to prune the network rapidly in the initial phase when the redundant connections are abundant and gradually reduce the number of weights being pruned each time as there are fewer and fewer weights remaining in the network, as illustrated in Figure 1.”
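Differentiating the paper’s equation (1), $s_t = s_f + (s_i - s_f)\left(1 - \frac{t - t_0}{n\,\Delta t}\right)^3$, with respect to $t$ makes that intuition explicit ($s_i$ and $s_f$ are the initial and final sparsity, $t_0$ the starting step, $n$ the number of pruning steps and $\Delta t$ the spacing between them):

$$
\frac{d s_t}{d t} = \frac{3\,(s_f - s_i)}{n\,\Delta t}\left(1 - \frac{t - t_0}{n\,\Delta t}\right)^{2}
$$

Since $s_f > s_i$, the pruning rate is largest at $t = t_0$, where it equals $3(s_f - s_i)/(n\,\Delta t)$, and it falls quadratically to zero at $t = t_0 + n\,\Delta t$.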

Kernel-level pruning is fascinating; however, I read somewhere that sparsification results in smaller accuracy drops than kernel-level pruning. A paper with a good set of experiments on this is linked here.

I was thinking of using sparsification, modifying the block-sparse kernels that OpenAI open-sourced so they work with PyTorch, and then measuring the speedups; of course, right now the speedups are only theoretical. If someone manages to do that, this method could be useful today. Sadly, PyTorch doesn’t currently support quantisation (though work is ongoing!).

On a more careful reading of the paper, you’re right! It should indeed be what you stated, my bad! 🙂

Thank you!