Something Similar to Dropout but for Weights?

Hello all,

Dropout works by randomly excluding some neurons during the training of a neural net, leading to a more robust network with better generalization.
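For reference, here's a minimal PyTorch sketch of standard dropout (`nn.Dropout` zeroes random activations, i.e. neuron outputs, on every forward pass while training):

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)   # each activation is zeroed with probability 0.5
drop.train()               # dropout is only active in training mode
x = torch.randn(4, 8)
print(drop(x))             # random elements zeroed; survivors scaled by 1/(1-p)
```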

Well, couldn’t we try something similar, but instead of dropping neurons, we’d zero out some of the weights, basically dropping the connections between some of the neurons?

Something like this:

[screenshot: a network diagram with some of the connections between neurons removed]

In other words, a random set of elements of the weight matrix would be set to zero. Unlike dropout, though, this would happen only once, before the first epoch; the rest of training & validation would remain the same.

From there, we’d have two options:

  1. Train the network as usual, effectively allowing it to change the weights initially set to zero if there is performance to be gained from doing so.
  2. Don’t update those weights (by zeroing out their gradients as well), deeming them non-trainable parameters. (A sketch of both options follows this list.)
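Here’s a minimal PyTorch sketch of both options (the helper name `sparsify_linear` is just something I made up for illustration):

```python
import torch
import torch.nn as nn

def sparsify_linear(layer: nn.Linear, p: float = 0.5, freeze: bool = False):
    """Zero out a random fraction p of `layer`'s weights, once, before training.

    freeze=False -> option 1: the zeroed weights can still be updated later.
    freeze=True  -> option 2: their gradients are masked on every backward
                    pass, so they stay at zero (non-trainable).
    """
    mask = (torch.rand_like(layer.weight) > p).float()
    with torch.no_grad():
        layer.weight.mul_(mask)        # one-time masking, before the first epoch
    if freeze:
        # The hook runs on the weight's gradient during backward.
        layer.weight.register_hook(lambda grad: grad * mask)
    return layer

layer = sparsify_linear(nn.Linear(20, 10), p=0.5, freeze=True)
```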

Tabular datasets are where I believe this method would excel, for a number of reasons:

  1. When dealing with structured datasets, it is often the interaction between only a few of the features that matters. In that regard, this method could be very helpful in that it would aid the neural net in exploring the relationship between the dependent variable and only a subset of the features.
  2. There might be different ways to get to the dependent variable, with each one giving a slightly different answer (high variance). For instance, maybe the interaction between features 1 and 2 is very predictive of the y-value and so is the interaction between features 3, 4, and 5. Setting the weights to zero initially would enable the model to make different (accurate) predictions using various features, essentially behaving like an ensemble model (I believe this is what dropout tries to achieve as well).
  3. Weights that are zero are most probably going to remain close to zero (especially if there’s L1 or L2 regularization), resulting in a less complex model.

Like any other algorithm, there are obviously drawbacks as well:

  1. If it’s a very complex dataset, the model might not be able to realize its full potential (the flip side of the first benefit).
  2. Setting a big chunk of the weights to the same value interferes with proper weight initialization, skewing the mean, standard deviation, etc. of the weight distribution (a quick demonstration follows this list).
  3. Probably a lot more I haven’t thought of yet.
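To illustrate the second drawback, a quick demonstration of how the one-time masking shrinks the initialization statistics (the numbers are approximate):

```python
import torch

w = torch.empty(100, 100)
torch.nn.init.kaiming_normal_(w)          # std ~ sqrt(2 / fan_in) ~ 0.14 here
mask = (torch.rand_like(w) > 0.5).float()
print(w.std(), (w * mask).std())          # masked std shrinks by ~sqrt(1 - p)
```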

I tried implementing this in Colab (there’s a lot of cleaning and refactoring needed) and it actually worked surprisingly well: on the Blue Book for Bulldozers dataset, I was able to achieve an RMSE of less than 0.220, better than what I could get with an ensemble of a random forest, Extra Trees, XGBoost, and fastai’s tabular model. I’ll try this on other datasets as well and hopefully I’ll see the same improvement!

P.S.: I’d very much appreciate it if some of the more knowledgeable users would share their ideas, thoughts, or opinions about my code or the technique I proposed.

Sincerely,
Borna

This is known as DropConnect: https://paperswithcode.com/method/dropconnect
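For reference, a rough sketch of the difference (the class name `DropConnectLinear` is made up): DropConnect samples a fresh random weight mask on every training forward pass, rather than masking once before training as you proposed:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DropConnectLinear(nn.Linear):
    """Linear layer whose weights are randomly masked on each training forward pass."""
    def __init__(self, in_features, out_features, p=0.5):
        super().__init__(in_features, out_features)
        self.p = p

    def forward(self, x):
        if self.training:
            mask = (torch.rand_like(self.weight) > self.p).float()
            # Scale by 1/(1-p) so the expected pre-activation is unchanged.
            return F.linear(x, self.weight * mask / (1 - self.p), self.bias)
        return F.linear(x, self.weight, self.bias)
```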

Thank you very much! I tried searching various terms but none of them yielded DropConnect. Is there a reason it’s nowhere near as widely used as dropout, even though one gets the impression from the literature that it usually performs better (and it’s also a generalized form of dropout)?

Thank you in advance!

Not sure. I’ve mostly used pytorch-image-models for training, and they did use DropConnect, although they’ve now replaced it with DropPath: https://paperswithcode.com/method/droppath

Oh, I see, thank you! I did find a few examples of DropConnect with CNNs but, again, I think structured datasets are where it has the potential to be really useful. I will experiment with other (tabular) datasets and keep you posted.

Have a good day (or evening)!

AWD-LSTM, the architecture used in ULMFiT as part of fastai, uses DropConnect; you can check the model out:

https://github.com/fastai/fastai/blob/master/nbs/32_text.models.awdlstm.ipynb

class WeightDropout(Module) is what you are looking for.
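A quick usage sketch (assuming fastai v2’s import path; by default it masks `weight_hh_l0`, the LSTM’s hidden-to-hidden weights, with a fresh mask on each training forward pass):

```python
import torch
from fastai.text.models.awdlstm import WeightDropout

lstm = torch.nn.LSTM(5, 7, batch_first=True)
wd_lstm = WeightDropout(lstm, weight_p=0.5)  # masks weight_hh_l0 while training

x = torch.randn(2, 10, 5)
out, (h, c) = wd_lstm(x)
```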

NLP is something I don’t really follow (solely because, after almost a year, I still can’t wrap my head around it), so I’m not very familiar with its architectures and the various algorithms they use. Thank you for the link!