[SOLVED] [AWD-LSTM] How Does Backpropagation Work for the WeightDropout Module? (Not Sure if It's a Bug)

I am reviewing the AWD-LSTM model from the fastai2 module, and it raised a question about WeightDropout. From my understanding, it wraps any nn.Module and applies dropout mask(s) to the target weight(s).

On closer inspection, I found that the target weights don't get any gradient during back-propagation, while the other weights do. This leads me to the question:
In WeightDropout, how are gradients computed and propagated to the target weight(s)?


To illustrate my point, you can run the following code:

import torch
import torch.nn as nn
from fastai2.text.models.awdlstm import WeightDropout

lstm = nn.LSTM(3, 5, batch_first = True)
# target weight: weight_hh_l0
# non-target weight: weight_ih_l0
lstm_dp = WeightDropout(lstm, weight_p = 0.8, layer_names='weight_hh_l0')

test_input = torch.randn(8, 20, 3)  # (batch size, seq length, input dim)
test_h = torch.randn(1, 8, 5)
test_c = test_h.data
output, (h, c) = lstm_dp(test_input, (test_h, test_c))

loss = output.sum()
loss.backward()
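
As a quick sanity check (using only the objects defined above), you can list which tensors are actually registered as parameters after wrapping: weight_hh_l0 has been removed from the inner LSTM as an nn.Parameter, and only weight_hh_l0_raw (plus the untouched ih weights and biases) remains registered on the wrapper.

# List the registered parameters after wrapping: the inner LSTM no longer owns
# weight_hh_l0 as an nn.Parameter; only weight_hh_l0_raw, the ih weights and
# the biases are still trainable parameters.
for name, p in lstm_dp.named_parameters():
    print(name, tuple(p.shape), p.requires_grad)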

Check the non-target weights; they have gradients computed:

In [4]: lstm_dp.module.weight_ih_l0.grad                                                                                               
Out[4]: 
tensor([[  6.7713,   1.6536,  -6.0267],
        [  6.0718,   9.1356,  -2.9828],
        [ -4.6106,  -6.5487,   8.9050],
        [ -5.4979,  -2.3226,   0.9387],
        [ -3.4466,   3.1723,  -2.4880],
        [  0.5058,  -1.8952,   0.3562],
        [ -1.5476,  -0.1757,   1.5508],
        [  4.5629,   2.8708,   1.5839],
        [  0.2175,   1.9155,  -0.6714],
        [  0.1650,   0.6840,   0.1294],
        [ 19.2782,  18.3211,  -7.3397],
        [-19.1672, -11.2854,  10.4764],
        [  5.6528,  -2.5588,  -1.9340],
        [  1.1073,  10.5333,   0.9745],
        [  2.7394,   0.5985,  -1.1770],
        [  2.1402,  -0.8321,   1.0183],
        [  3.1592,   6.3710,  -3.9283],
        [ -4.2480,  -5.9663,   8.3711],
        [ -2.5984,  -0.1586,   1.3106],
        [ -1.8626,   2.3050,  -1.2497]])

Check the target weight in both lstm_dp.module.{weight} and lstm_dp.{weight}_raw; neither has a gradient computed:

In [5]: lstm_dp.module.weight_hh_l0.grad                                                                                               

In [6]: lstm_dp.weight_hh_l0_raw.grad                                                                                                  
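
For reference (my own reading of the code, not an official explanation), the dropped-out copy that the inner LSTM actually uses is not even connected to the autograd graph, which is consistent with no gradient ever reaching the raw weight:

# With the unmodified WeightDropout, the weight the LSTM actually uses is a
# plain tensor built from raw_w.data, so it has no grad_fn and backward()
# has no path back to weight_hh_l0_raw.
print(lstm_dp.module.weight_hh_l0.requires_grad)  # expected: False
print(lstm_dp.module.weight_hh_l0.grad_fn)        # expected: None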

After a brief investigation, I can get the gradients back by making a small change to the WeightDropout implementation, as shown below. (It looks a bit suspicious to me. Is this actually a bug?)

# Imports needed to run this snippet standalone (fastai2-era import paths).
import warnings
import torch.nn as nn
import torch.nn.functional as F
from fastcore.foundation import L
from fastai2.torch_core import Module

class WeightDropout(Module):
    "A module that wraps another layer in which some weights will be replaced by 0 during training."

    def __init__(self, module, weight_p, layer_names='weight_hh_l0'):
        self.module,self.weight_p,self.layer_names = module,weight_p,L(layer_names)
        for layer in self.layer_names:
            #Makes a copy of the weights of the selected layers.
            w = getattr(self.module, layer)
            delattr(self.module, layer)
            self.register_parameter(f'{layer}_raw', nn.Parameter(w.data))
            setattr(self.module, layer, F.dropout(w.data, p=self.weight_p, training=False))
            if isinstance(self.module, (nn.RNNBase, nn.modules.rnn.RNNBase)):
                self.module.flatten_parameters = self._do_nothing

    def _setweights(self):
        "Apply dropout to the raw weights."
        for layer in self.layer_names:
            raw_w = getattr(self, f'{layer}_raw')
            # CHANGE: raw_w.data --> raw_w
            setattr(self.module, layer, F.dropout(raw_w, p=self.weight_p, training=self.training))  
 
    def forward(self, *args):
        self._setweights()
        with warnings.catch_warnings():
            #To avoid the warning that comes because the weights aren't flattened.
            warnings.simplefilter("ignore")
            return self.module.forward(*args)

    def reset(self):
        for layer in self.layer_names:
            raw_w = getattr(self, f'{layer}_raw')
            setattr(self.module, layer, F.dropout(raw_w.data, p=self.weight_p, training=False))  
        if hasattr(self.module, 'reset'): self.module.reset()

    def _do_nothing(self): pass
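
For context on why that one-line change matters (my understanding): tensor.data returns a view that is detached from the autograd graph, so building the masked weight from raw_w.data cuts the link back to weight_hh_l0_raw and backward() never reaches it, whereas passing raw_w keeps the dropout multiplication in the graph. A minimal plain-PyTorch sketch of the difference, independent of fastai:

import torch
import torch.nn.functional as F

w_raw = torch.randn(4, 4, requires_grad=True)

# Building the masked weight from w_raw.data detaches it: the result has no
# grad_fn, so backward() through it can never populate w_raw.grad.
w_detached = F.dropout(w_raw.data, p=0.5, training=True)
print(w_detached.requires_grad)  # False

# Building it from w_raw keeps the dropout op in the graph: gradients flow
# back through the mask and are zero wherever the mask dropped an entry.
w_masked = F.dropout(w_raw, p=0.5, training=True)
w_masked.sum().backward()
print(w_raw.grad)  # non-None; zeros at the dropped positions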

With the above change in place, running the example again gives the gradients back on lstm_dp.{weight}_raw:

In [53]: lstm_dp.weight_hh_l0_raw.grad                                                                                                
Out[53]: 
tensor([[  0.0000, -22.5616,  -0.0000,   0.0000,   0.0000],
        [ -0.0000,   0.0000,   0.0000,  -0.0000,  -0.0000],
        [ -0.0000,   0.0000,   0.0000,  -0.0000,  -0.0000],
        [  0.0000,  -0.0000,  -0.0000,   0.0000,   0.0000],
        [ 10.8977, -14.0774,  -0.0000,   0.0000,   0.0000],
        [ 40.2919,  -0.0000,  -0.0000,   0.0000,   0.0000],
        [ -6.0390,   0.0000,   1.8585,  -0.0000,  -0.0000],
        [-11.2179,   0.0000,   8.1154,   0.0000,  -0.0000],
        [  0.0000,  -0.0000,   0.1269,   0.0000,   0.0000],
        [  0.0000,  -0.0000,  -1.4096,   0.0000,   0.0000],
        [ -0.0000,   0.0000,   0.0000,  -0.0000, -32.2400],
        [ -0.0000,  12.8236,   0.0000,  -0.0000,  -3.6243],
        [ -0.0000,   0.0000,  14.6292,  -0.0000,  -0.0000],
        [ -0.0000,  35.2747,  14.0794,  -0.0000,  -0.0000],
        [ -0.0000,  47.7815,   0.0000,  -0.0000,  -0.0000],
        [  0.0000,  -0.0000,  -0.0000,   0.0000,   0.0000],
        [ -0.0000,   0.0000,   0.0000,  -0.0000,  -0.0000],
        [ -0.0000,   0.0000,   0.0000,  -0.8990,  -0.0000],
        [  0.0000,  -0.0000,  -0.2122,   2.9238,   0.0000],
        [  0.0000, -14.2455,  -0.0000,   0.0000,   9.6305]])

In [54]: lstm_dp.module.weight_hh_l0                                                                                                  
Out[54]: 
tensor([[ 0.0000, -0.1062,  0.0000, -0.0000, -0.0000],
        [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
        [-0.0000, -0.0000,  0.0000,  0.0000,  0.0000],
        [-0.0000,  0.0000,  0.0000, -0.0000,  0.0000],
        [-1.0074,  0.5093, -0.0000, -0.0000, -0.0000],
        [-0.7589,  0.0000,  0.0000,  0.0000, -0.0000],
        [ 1.7357, -0.0000, -0.2286,  0.0000,  0.0000],
        [ 0.4990, -0.0000,  1.6599,  0.0000, -0.0000],
        [ 0.0000,  0.0000, -1.7126,  0.0000,  0.0000],
        [-0.0000, -0.0000, -1.7145,  0.0000, -0.0000],
        [-0.0000,  0.0000, -0.0000, -0.0000, -0.1598],
        [-0.0000,  0.4909,  0.0000, -0.0000, -0.2315],
        [ 0.0000, -0.0000, -1.8416,  0.0000, -0.0000],
        [ 0.0000, -0.3519,  0.1470,  0.0000, -0.0000],
        [-0.0000, -2.2106,  0.0000,  0.0000, -0.0000],
        [-0.0000, -0.0000, -0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0000, -1.3198, -0.0000],
        [-0.0000,  0.0000, -0.2853,  0.4752, -0.0000],
        [ 0.0000,  1.3968, -0.0000,  0.0000,  1.7818]], grad_fn=<MulBackward0>)
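
With the change in place, weight_hh_l0_raw behaves like any other parameter: an optimizer built from lstm_dp.parameters() updates the raw weight, and _setweights re-applies a fresh dropout mask on the next forward pass. A rough sketch of a single training step, reusing the tensors defined earlier (the optimizer choice here is arbitrary):

opt = torch.optim.SGD(lstm_dp.parameters(), lr=0.1)

lstm_dp.train()                                    # dropout active
output, _ = lstm_dp(test_input, (test_h, test_c))  # fresh mask via _setweights
loss = output.sum()

opt.zero_grad()
loss.backward()
opt.step()  # updates weight_hh_l0_raw along with the non-target weights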

As a follow-up, the issue is solved.

The target weight should indeed have gradients, and this has been fixed by the PR.
