CUDA error: invalid configuration argument during weight dropout in AWD-LSTM

I’m trying to build a Siamese sentence encoder, and I’m getting a strange error that I’m not sure how to debug. The Siamese sentence encoder I have built simply takes the ULMFiT `SentenceEncoder` and the `masked_concat_pool` function and combines them into a single module. I feed a pair of sentences (s1, s2) through the ULMFiT model one after the other and pool the outputs to get embeddings (u, v) that can then be used for natural language inference or semantic similarity.

class SiameseSentenceEncoder(Module):
    """Create an encoder using `module` that can process a pair of input sentences and then
       produce output embeddings for each sentence via masked concatenated pooling.

       Each sentence batch is fed through `module` in chunks of `bptt` tokens, and the chunk
       outputs are pooled into `[last_hidden, max_pool, avg_pool]`."""
    # Stores `bptt`, `module`, `pad_idx` and `max_len` as attributes on `self`.
    def __init__(self, bptt, module, pad_idx=1, max_len=None): store_attr(self, 'bptt,module,pad_idx,max_len')

    def reset(self):
        "Call `self.module.reset()` if the wrapped module defines it (falls back to `noop`)."
        getattr(self.module, 'reset', noop)()

    def sentence_encoder(self, input):
        """Run a batch of token ids through `self.module` in `bptt`-sized chunks.

        Returns `(outs, mask)`. `outs` concatenates (along dim 1) the kept chunk outputs,
        each padded back to the full batch size with `_pad_tensor`. `mask` is True wherever
        `input == pad_idx`, for the same columns. When `max_len` is set, only chunks that
        start within the last `max_len` positions are kept (earlier chunks are still fed
        to the module).

        Chunks whose first column is padding in every row are skipped entirely (left out of
        both `outs` and `mask`) instead of being sent to the module as a zero-row batch.
        If no chunk is kept, `torch.cat` below raises on the empty list.
        """
        bs, sl = input.size()
        self.reset()
        mask = input == self.pad_idx
        outs, masks = [], []
        for i in range(0, sl, self.bptt):
            end = min(i + self.bptt, sl)
            # NOTE: this expects that each sequence really begins on a round multiple of
            # `bptt`, and that rows with a real token at column `i` come first in the batch,
            # because only the first `real_bs` rows are fed to the module.
            real_bs = int((input[:, i] != self.pad_idx).long().sum())
            if real_bs == 0:
                # A zero-row batch through the RNN can fail on CUDA ("invalid configuration
                # argument"). Under the layout assumed above, this chunk holds no tokens to
                # encode, so skip it and leave the module's state untouched.
                continue
            o = self.module(input[:real_bs, i:end])
            if self.max_len is None or sl - i <= self.max_len:
                outs.append(o)
                masks.append(mask[:, i:end])
        outs = torch.cat([_pad_tensor(o, bs) for o in outs], dim=1)
        mask = torch.cat(masks, dim=1)
        return outs, mask

    def masked_concat_pool(self, output, mask):
        """Pool encoded sentences into one vector `[last_hidden, max_pool, avg_pool]`.

        `output` is the encoder output and `mask` is True at padding positions. `mask[:, :, None]`
        broadcasts over `output`'s last dim, so `output` is expected to be (bs, seq_len, n_hid).
        Returns the three pooled tensors concatenated along dim 1.
        """
        # Number of non-padding positions per row.
        lens = output.shape[1] - mask.long().sum(dim=1)
        # Padding positions inside the last `bptt` columns. When that padding is trailing,
        # `-last_lens-1` below indexes the last real token.
        last_lens = mask[:,-self.bptt:].long().sum(dim=1)
        avg_pool = output.masked_fill(mask[:, :, None], 0).sum(dim=1)
        # NOTE(review): a row that is entirely padding has lens == 0 (NaN average here) and
        # gets -inf from max_pool below.
        avg_pool.div_(lens.type(avg_pool.dtype)[:,None])
        max_pool = output.masked_fill(mask[:,:,None], -float('inf')).max(dim=1)[0]
        x = torch.cat([output[torch.arange(0, output.size(0)),-last_lens-1], max_pool, avg_pool], 1) #Concat pooling.
        return x

    def forward(self, s1, s2):
        "Encode and pool `s1` and `s2` separately with the shared module; return `(u, v, out1, out2)`."
        out1, mask1 = self.sentence_encoder(s1)
        u = self.masked_concat_pool(out1, mask1)

        # `sentence_encoder` calls `self.reset()` first, so (assuming the module's `reset`
        # clears its hidden state) s2 does not continue from s1's final state.
        out2, mask2 = self.sentence_encoder(s2)
        v = self.masked_concat_pool(out2, mask2)

        return u, v, out1, out2
        
class SiameseLinearClassifier(Module):
    """Classification head over a pair of pooled sentence embeddings.

    `forward` concatenates `u` and `v` along dim 1, adding `|u - v|` when `include_diff` is
    set and `u * v` when `include_prod` is set. It runs the result through a stack of
    `LinBnDrop` layers, optionally followed by `SigmoidRange(*y_range)`. `out1` and `out2`
    are returned unchanged alongside the prediction.

    Raises `ValueError` when `len(ps) != len(dims) - 1`.
    """
    def __init__(self, dims, ps, bptt, include_diff, include_prod, y_range=None):
        n_lin = len(dims) - 1
        if len(ps) != n_lin:
            raise ValueError("Number of layers and dropout values do not match.")
        # One ReLU instance is shared by every hidden layer; the final layer has no activation.
        relu = nn.ReLU(inplace=True)
        activations = [relu] * (n_lin - 1) + [None]
        blocks = []
        for n_in, n_out, p, act in zip(dims[:-1], dims[1:], ps, activations):
            blocks.append(LinBnDrop(n_in, n_out, p=p, act=act))
        if y_range is not None:
            blocks.append(SigmoidRange(*y_range))
        self.layers = nn.Sequential(*blocks)
        self.bptt = bptt  # stored on the instance; not read anywhere in this class
        self.include_diff = include_diff
        self.include_prod = include_prod

    def forward(self, u, v, out1, out2):
        pieces = [u, v]
        pieces += [torch.abs(u - v)] if self.include_diff else []
        pieces += [u * v] if self.include_prod else []
        preds = self.layers(torch.cat(pieces, dim=1))
        return preds, out1, out2

Training starts and seems to go fine up to a point, but then I get `CUDA error: invalid configuration argument`, raised during weight dropout in the middle layer of the LSTM. The error always seems to occur while the second sentence is being passed through the encoder.

epoch	train_loss	valid_loss	accuracy	time
0	1.371777	00:09
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-23-bce3cbb1ba4b> in <module>()
      1 lr = 1e-2
----> 2 learn.fit_one_cycle(1, lr)

17 frames
/usr/local/lib/python3.6/dist-packages/fastcore/utils.py in _f(*args, **kwargs)
    429         init_args.update(log)
    430         setattr(inst, 'init_args', init_args)
--> 431         return inst if to_return else f(*args, **kwargs)
    432     return _f
    433 

/usr/local/lib/python3.6/dist-packages/fastai2/callback/schedule.py in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt)
    111     scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
    112               'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 113     self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd)
    114 
    115 # Cell

/usr/local/lib/python3.6/dist-packages/fastcore/utils.py in _f(*args, **kwargs)
    429         init_args.update(log)
    430         setattr(inst, 'init_args', init_args)
--> 431         return inst if to_return else f(*args, **kwargs)
    432     return _f
    433 

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    198                     try:
    199                         self.epoch=epoch;          self('begin_epoch')
--> 200                         self._do_epoch_train()
    201                         self._do_epoch_validate()
    202                     except CancelEpochException:   self('after_cancel_epoch')

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in _do_epoch_train(self)
    173         try:
    174             self.dl = self.dls.train;                        self('begin_train')
--> 175             self.all_batches()
    176         except CancelTrainException:                         self('after_cancel_train')
    177         finally:                                             self('after_train')

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in all_batches(self)
    151     def all_batches(self):
    152         self.n_iter = len(self.dl)
--> 153         for o in enumerate(self.dl): self.one_batch(*o)
    154 
    155     def one_batch(self, i, b):

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in one_batch(self, i, b)
    157         try:
    158             self._split(b);                                  self('begin_batch')
--> 159             self.pred = self.model(*self.xb);                self('after_pred')
    160             if len(self.yb) == 0: return
    161             self.loss = self.loss_func(self.pred, *self.yb); self('after_loss')

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

<ipython-input-12-6dddd1191cb9> in forward(self, *inputs)
      7         for module in self._modules.values():
      8             if type(inputs) == tuple:
----> 9                 inputs = module(*inputs)
     10             else:
     11                 inputs = module(inputs)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

<ipython-input-11-6bb61207768e> in forward(self, s1, s2)
     39         u = self.masked_concat_pool(out1, mask1)
     40 
---> 41         out2, mask2 = self.sentence_encoder(s2)
     42         v = self.masked_concat_pool(out2, mask2)
     43 

<ipython-input-11-6bb61207768e> in sentence_encoder(self, input)
     17             #Note: this expects that sequence really begins on a round multiple of bptt
     18             real_bs = (input[:,i] != self.pad_idx).long().sum()
---> 19             o = self.module(input[:real_bs,i: min(i+self.bptt, sl)])
     20             if self.max_len is None or sl-i <= self.max_len:
     21                 outs.append(o)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/fastai2/text/models/awdlstm.py in forward(self, inp, from_embeds)
    105         new_hidden = []
    106         for l, (rnn,hid_dp) in enumerate(zip(self.rnns, self.hidden_dps)):
--> 107             output, new_h = rnn(output, self.hidden[l])
    108             new_hidden.append(new_h)
    109             if l != self.n_layers - 1: output = hid_dp(output)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/fastai2/text/models/awdlstm.py in forward(self, *args)
     48 
     49     def forward(self, *args):
---> 50         self._setweights()
     51         with warnings.catch_warnings():
     52             #To avoid the warning that comes because the weights aren't flattened.

/usr/local/lib/python3.6/dist-packages/fastai2/text/models/awdlstm.py in _setweights(self)
     45         for layer in self.layer_names:
     46             raw_w = getattr(self, f'{layer}_raw')
---> 47             setattr(self.module, layer, F.dropout(raw_w.data, p=self.weight_p, training=self.training))
     48 
     49     def forward(self, *args):

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in dropout(input, p, training, inplace)
    934     return (_VF.dropout_(input, p, training)
    935             if inplace
--> 936             else _VF.dropout(input, p, training))
    937 
    938 

RuntimeError: CUDA error: invalid configuration argument

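In the debugger, the variables all appear to be correct, and if I run `_VF.dropout_(…)` and `_VF.dropout(…)` directly, they run without issue, so there doesn’t seem to be a problem with the data. (Note that CUDA kernel errors are reported asynchronously, so the dropout call shown in the traceback is not necessarily the operation that failed; rerunning with `CUDA_LAUNCH_BLOCKING=1` makes the traceback point at the failing call.)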

> /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py(936)dropout()
    934     return (_VF.dropout_(input, p, training)
    935             if inplace
--> 936             else _VF.dropout(input, p, training))
    937 
    938 

ipdb> input.size()
torch.Size([4608, 1152])
ipdb> training
True
ipdb> _VF.dropout(input, p, training)
tensor([[-1.1071e-04,  7.4137e-02,  1.4966e-01,  ..., -1.5828e-02,
         -1.3241e-01, -1.7480e-01],
        [-3.0579e-02,  0.0000e+00, -1.2044e-01,  ...,  5.3263e-02,
          2.3682e-02, -1.9352e-01],
        [ 1.4331e-01,  2.0492e-01,  9.1797e-02,  ...,  9.8145e-02,
          3.7903e-02, -1.1857e-01],
        ...,
        [-1.0059e-01,  0.0000e+00, -4.7445e-02,  ...,  2.0133e-01,
         -2.9134e-02,  0.0000e+00],
        [-6.6040e-02, -2.9215e-02, -3.2926e-01,  ..., -2.6774e-01,
         -3.7793e-01, -1.6748e-01],
        [ 3.0884e-02,  0.0000e+00, -1.3949e-03,  ...,  1.1182e-01,
         -1.0221e-01, -2.2786e-01]], device='cuda:0')

The Colab Notebook is here: https://colab.research.google.com/drive/1uK-Wna2Uo9JRx7RqCIrl_W_yVt2tDiBD#offline=true&sandboxMode=true

Does anyone have suggestions for how to debug this, or has anyone come across this sort of issue before? I’ve seen some discussions on the PyTorch forums suggesting this could be related to threading, but the dimensions are no different from a single-pass AWD-LSTM, which runs without any issues. I’ve checked the memory usage, and the model uses less than 10% of the available memory.

Any suggestions would be greatly appreciated.