Usage of the dropout_mask

Really getting into the nitty-gritty details here, as I plan on rewriting most of the stack to replicate the performance reported in the paper.

The paper describes a variational dropout mask as follows:

"A variant of this, variational dropout (Gal & Ghahramani, 2016), samples a binary dropout mask only once upon the first call and then to repeatedly use that locked dropout mask for all repeated connections within the forward and backward pass."

But in the code, LockedDropout.forward(…) evaluates a new dropout_mask on every forward pass, so the mask differs from one input tensor to the next. See a small example here:

So how are we staying true to the referenced variational dropout, when the dropout mask seems to vary from one forward pass to the next?


It is constant across time steps within the RNN, not across forward passes. Take a look at the fastai code to see how it’s implemented.
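To make the distinction concrete, here is a minimal PyTorch sketch of locked dropout. It is modelled on the LockedDropout in Smerity's AWD-LSTM but written from scratch here (not copied from fastai), takes the dropout rate in the constructor for simplicity, and assumes inputs shaped (seq_len, batch, features). The mask is sampled with a time dimension of size 1 and broadcast over all time steps, so each call draws a fresh mask but reuses it at every step within that call.

```python
import torch
import torch.nn as nn


class LockedDropout(nn.Module):
    """Locked (variational) dropout for inputs shaped (seq_len, batch, features)."""

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.p == 0:
            return x
        # One mask per forward call, with a time dimension of size 1 ...
        mask = x.new_empty((1, x.size(1), x.size(2))).bernoulli_(1 - self.p) / (1 - self.p)
        # ... broadcast over dim 0, so every time step reuses the same mask.
        return x * mask


torch.manual_seed(0)
drop = LockedDropout(p=0.5).train()
x = torch.ones(5, 4, 16)  # 5 time steps, batch of 4, 16 features

out1 = drop(x)
out2 = drop(x)

# Within one call, every time step sees an identical mask.
same_across_time = all(torch.equal(out1[0], out1[t]) for t in range(x.size(0)))
# A second call samples a new mask (identical masks are astronomically unlikely here).
differs_across_calls = not torch.equal(out1, out2)
print(same_across_time, differs_across_calls)
```

In other words, "only once" refers to a single sequence (one forward call), not to the whole training run.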


Okay, thanks for the pointer. I've convinced myself that the dropout mask is constant across the time steps. I've also convinced myself that the recurrent connection in the LSTM is subject to the WeightDrop layer, which uses the same dropout mask across time steps.
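The recurrent-weight side can be illustrated the same way. The sketch below is a deliberately simplified stand-in for WeightDrop: it uses a plain tanh RNN rather than an LSTM, and `rnn_with_dropconnect` is a hypothetical helper, not part of fastai or AWD-LSTM. What it shows is that DropConnect is applied to the hidden-to-hidden matrix once per call, and that same dropped matrix is then used at every time step.

```python
import torch
import torch.nn.functional as F


def rnn_with_dropconnect(x, w_ih, w_hh, p=0.5, training=True):
    """Hypothetical tanh RNN over x shaped (seq_len, batch, n_in).

    DropConnect is applied to the hidden-to-hidden weights once per call,
    so every time step multiplies by the same dropped w_hh.
    """
    # Zero individual weights with prob. p and rescale survivors by 1/(1-p).
    w_hh_dropped = F.dropout(w_hh, p=p, training=training)
    h = x.new_zeros((x.size(1), w_hh.size(0)))
    outputs = []
    for x_t in x:  # iterate over the time dimension
        h = torch.tanh(x_t @ w_ih.t() + h @ w_hh_dropped.t())
        outputs.append(h)
    # Return the dropped matrix too, purely so it can be inspected.
    return torch.stack(outputs), w_hh_dropped


torch.manual_seed(0)
x = torch.randn(6, 3, 4)  # 6 time steps, batch of 3, 4 input features
w_ih = torch.randn(8, 4)  # input-to-hidden weights (hidden size 8)
w_hh = torch.randn(8, 8)  # hidden-to-hidden weights
out, w_used = rnn_with_dropconnect(x, w_ih, w_hh, p=0.5)
print(out.shape, int((w_used == 0).sum()), "recurrent weights dropped")
```

Because the mask lives on the weights rather than on the activations, it is automatically shared by every time step of the loop.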

One thing that occurred to me last night while writing the stack: would it make sense to redefine embedded_output as an nn.Module of its own? Right now it's just a method that is called to condition the weights of the embedding matrix. Semantically, I feel it would be better to have it as a layer of its own, just as is the case with WeightDrop. Having it as a layer would also mean it shows up when doing model.summary().
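As a rough sketch of what such a layer could look like, assuming embedding dropout zeroes whole rows of the embedding matrix (entire words) during training and rescales the surviving rows, as in AWD-LSTM. `EmbeddingDropout` is a hypothetical name here, and only `padding_idx` is forwarded to `F.embedding` (`max_norm` and the other `nn.Embedding` options are assumed to be left at their defaults).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EmbeddingDropout(nn.Module):
    """Hypothetical module wrapping an nn.Embedding and applying embedding dropout.

    During training, whole rows of the embedding matrix (i.e. entire words) are
    zeroed with probability p, and the surviving rows are rescaled by 1/(1-p).
    """

    def __init__(self, embed: nn.Embedding, p: float = 0.1):
        super().__init__()
        self.embed = embed  # registered as a submodule, so it appears in print(model)
        self.p = p

    def forward(self, words: torch.Tensor) -> torch.Tensor:
        weight = self.embed.weight
        if self.training and self.p > 0:
            # One Bernoulli draw per vocabulary row, broadcast across the embedding dim.
            mask = weight.new_empty((weight.size(0), 1)).bernoulli_(1 - self.p) / (1 - self.p)
            weight = weight * mask
        # Assumes default max_norm etc.; only padding_idx is passed through.
        return F.embedding(words, weight, padding_idx=self.embed.padding_idx)


torch.manual_seed(0)
emb = nn.Embedding(100, 8)
layer = EmbeddingDropout(emb, p=0.5)
words = torch.tensor([[1, 2, 3], [4, 1, 5]])
out = layer(words)  # shape (2, 3, 8)
print(layer)
```

In plain PyTorch, printing the model lists registered submodules, so the wrapped Embedding and its dropout wrapper become visible in the model description.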

Completely a matter of semantics though, so probably not important.

Yes, if you can get that to work, that would be much better! I basically just stole the code from Smerity's AWD-LSTM.

If you do have a go at this, be sure to create a few test cases under a range of both training and test situations to confirm it really is identical. It’s hard to get these kinds of things just right in my experience, and good tests are critical.
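One way such tests might be set up, sketched under assumptions: `embedded_dropout` below is a simplified functional reference in the AWD-LSTM style (called with p=0 outside training), and `EmbeddingDropout` is a hypothetical module rewrite. Resetting the RNG seed before each call makes the random masks comparable, so the outputs should match exactly in both train and eval mode across several dropout rates and seeds.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def embedded_dropout(embed, words, p=0.1):
    """Reference (functional) embedding dropout: drop whole rows of the weight matrix."""
    weight = embed.weight
    if p > 0:
        mask = weight.new_empty((weight.size(0), 1)).bernoulli_(1 - p) / (1 - p)
        weight = weight * mask
    return F.embedding(words, weight, padding_idx=embed.padding_idx)


class EmbeddingDropout(nn.Module):
    """Hypothetical module rewrite under test."""

    def __init__(self, embed, p=0.1):
        super().__init__()
        self.embed, self.p = embed, p

    def forward(self, words):
        if not self.training or self.p == 0:
            return self.embed(words)  # eval: plain lookup, no dropout
        weight = self.embed.weight
        mask = weight.new_empty((weight.size(0), 1)).bernoulli_(1 - self.p) / (1 - self.p)
        return F.embedding(words, weight * mask, padding_idx=self.embed.padding_idx)


def outputs_match(module, reference_fn, words, seed):
    """Run both implementations from the same RNG state and compare exactly."""
    torch.manual_seed(seed)
    got = module(words)
    torch.manual_seed(seed)
    expected = reference_fn(words)
    return torch.equal(got, expected)


torch.manual_seed(0)
embed = nn.Embedding(50, 6)
words = torch.randint(0, 50, (7, 3))  # (seq_len, batch) of token ids

results = {}
for p in (0.1, 0.5):
    layer = EmbeddingDropout(embed, p)
    for seed in range(5):
        layer.train()
        results[("train", p, seed)] = outputs_match(
            layer, lambda w: embedded_dropout(embed, w, p), words, seed)
        layer.eval()
        # Outside training, the reference is called with p=0.
        results[("eval", p, seed)] = outputs_match(
            layer, lambda w: embedded_dropout(embed, w, 0.0), words, seed)

print(all(results.values()))
```

The module draws its mask with the same call, in the same order, as the function. Computing the rescaling differently (for example, dividing after multiplying) can change float32 rounding and break an exact-equality check, which is exactly the kind of subtle difference these tests are meant to surface.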