Using F.dropout to copy parameters

(Sachin) #1

I was looking at the code for AWD LSTMs (shown below) and was confused about the use of F.dropout with training=False.

import torch.nn as nn
import torch.nn.functional as F

WEIGHT_HH = 'weight_hh_l0'

class WeightDropout(nn.Module):
    def __init__(self, module, weight_p=0., layer_names=[WEIGHT_HH]):
        super().__init__()
        self.module,self.weight_p,self.layer_names = module,weight_p,layer_names
        for layer in self.layer_names:
            #Makes a copy of the weights of the selected layers.
            w = getattr(self.module, layer)
            self.register_parameter(f'{layer}_raw', nn.Parameter(w.data))
            self.module._parameters[layer] = F.dropout(w, p=self.weight_p, training=False)

    def _setweights(self):
        for layer in self.layer_names:
            raw_w = getattr(self, f'{layer}_raw')
            self.module._parameters[layer] = F.dropout(raw_w, p=self.weight_p, training=self.training)

I asked the question on GitHub, and the answer was that it initially copies the weights to their ‘_raw’ version.
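As a side note, here is a quick sketch (not the fastai code itself) of what that registration step does. One thing worth noticing is that nn.Parameter(w.data) shares storage with the original weight rather than deep-copying it:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=2, hidden_size=3)
w = getattr(lstm, 'weight_hh_l0')

# Mirror the __init__ line: wrap w.data in a fresh Parameter. This creates
# a distinct Parameter object, but it still points at the same storage.
raw = nn.Parameter(w.data)

print(raw is w)                        # False: two separate Parameter objects
print(raw.data_ptr() == w.data_ptr())  # True: they share the underlying storage
```

So the ‘_raw’ registration is less a deep copy and more a second handle on the same weights; the two only diverge once `_parameters[layer]` is overwritten with a freshly dropped-out tensor in `_setweights`.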

So my questions are:

  1. Why use F.dropout with training=False? Can’t you simply do self.module._parameters[layer] = w.clone()?
  2. Do you even need to copy the weights across like this, given that getattr(self.module, layer) is exactly the same as self.module._parameters[layer]? Doesn’t that render the last line in __init__ unnecessary?
  3. What is the purpose of having a weight_hh_l0_raw version? Is it because the line self.module._parameters[layer] = F.dropout(raw_w, p=self.weight_p, training=self.training) overwrites the weights, and you want to preserve the originals?
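For anyone following along, here is a small sketch of what F.dropout actually does in each mode (standard PyTorch behaviour, nothing fastai-specific):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
w = torch.randn(3, 3)

# training=False: dropout is a no-op, so the values come back untouched.
eval_out = F.dropout(w, p=0.5, training=False)
print(torch.equal(eval_out, w))  # True

# training=True: each element is zeroed with probability p and the
# survivors are rescaled by 1 / (1 - p) so the expectation is unchanged.
train_out = F.dropout(w, p=0.5, training=True)
kept = train_out != 0
print(torch.allclose(train_out[kept], w[kept] * 2))  # True
```

So with training=False the call leaves the values untouched, which is presumably why it can stand in for a clone-like “copy across” in __init__, even though a plain assignment reads more directly.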

(Karl) #2

The purpose of weight drop is to apply dropout to the model weights as opposed to the activations. When we do this, we want a single dropout mask to be applied at all time steps. The WeightDropout class saves a copy of the true weights, then applies a single dropout mask to the weights on every forward pass.
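To make that concrete, here is a toy recurrence (not the actual AWD-LSTM code) showing how dropping the weight matrix once gives a single mask that is in effect at every time step:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
w_hh = torch.randn(8, 8)  # stand-in for the recurrent weight matrix

# One dropout call per forward pass: the zeroed entries of dropped_w form
# a single mask that every time step below implicitly reuses.
dropped_w = F.dropout(w_hh, p=0.5, training=True)

h = torch.zeros(8)
for _ in range(4):                     # same dropped_w at every step
    x = torch.randn(8)
    h = torch.tanh(dropped_w @ h + x)

# Ordinary activation dropout would instead sample a fresh mask inside
# the loop, i.e. a different mask at each time step.
print(h.shape)  # torch.Size([8])
```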


(Sachin) #3

Hey Karl. Thanks for answering, but I’m afraid this doesn’t answer my question. I know what WeightDropout does; I’m just not sure what F.dropout is doing in this case. (See the 3 questions listed above.)


(Mike Tian-Jian Jiang) #4

I did some digging that may be related to this question; please check this post: Multilingual ULMFiT

It was something about QRNN, but I’m not sure it is still required now.
