`RNNLearner.get_preds(DatasetType.Train, ordered=True)` does not work for `TextClasDataBunch`

A detailed example and a workaround (probably not a clean solution) are in this notebook:

It seems to be caused by

  1. In TextClasDataBunch.create(), drop_last is set to True for train_dl;
  2. SortishSampler generates randomly ordered indices that can’t be recovered.
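
To illustrate the two points above, here is a minimal pure-Python sketch. It is not fastai code: `batch_indices` and `random_sampler` are hypothetical helpers that only mimic how a data loader groups sampler indices into batches, and the dataset size and batch size are made up.

```python
import random

def batch_indices(indices, batch_size, drop_last):
    """Group sampler indices into batches, roughly as a data loader would (hypothetical helper)."""
    batches = [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches = batches[:-1]  # the final partial batch is discarded
    return batches

def random_sampler(n):
    """Return a fresh random permutation on every call (stand-in for a randomized sampler)."""
    idxs = list(range(n))
    random.shuffle(idxs)
    return idxs

n_items, batch_size = 10, 4  # assumed toy sizes

# Order in which items are visited while computing predictions.
order_used_for_preds = random_sampler(n_items)
batches = batch_indices(order_used_for_preds, batch_size, drop_last=True)
n_preds = sum(len(b) for b in batches)

print(n_preds, "predictions for", n_items, "items")  # 8 predictions for 10 items

# Asking the sampler for its order again yields a new permutation, which in general
# differs from the one used above, so it cannot be used to un-shuffle the predictions.
order_seen_later = random_sampler(n_items)
```

With `drop_last=True` two items never get a prediction at all, and without a fixed seed the order used for the remaining eight cannot be reproduced afterwards.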

The workaround in the above notebook sets drop_last=False and fixes the random seed to make it work. But I have some questions:

  1. Why is drop_last set to True for train_dl in TextClasDataBunch.create()? Is it necessary?
  2. Is there possibly a proper way to make this work?

Thanks.

Yes, we changed the default drop_last to True for the training set because small batches (especially batches of size 1) break BatchNorm layers (that is the recommendation from PyTorch). I'm guessing that is what breaks ordered=True, yes.
A workaround is to use fix_dl to get the predictions on the training set (it is the same as train_dl minus the transforms (in vision), with shuffle=False and drop_last=False).
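
As a rough illustration of why a loader with a reproducible order and drop_last=False makes ordered predictions possible, here is a small pure-Python sketch. It is not the fastai implementation: the data, the `sampler_order` list, and the `argsort` helper are all made up for the example, and it assumes predictions come back in the order the sampler visits the items.

```python
def argsort(seq):
    """Return the positions that would sort seq (pure-Python stand-in for numpy.argsort)."""
    return sorted(range(len(seq)), key=seq.__getitem__)

# Hypothetical toy data: item i has "prediction" i * 10.
dataset = [0, 10, 20, 30, 40]

# A deterministic, reproducible visiting order (for example, sorted by text length).
sampler_order = [3, 0, 4, 1, 2]

# With drop_last=False every index is visited, so there is exactly one prediction
# per item, emitted in sampler order.
preds_in_sampler_order = [dataset[i] for i in sampler_order]

# Invert the permutation to put the predictions back in dataset order.
reverse = argsort(sampler_order)
preds_in_dataset_order = [preds_in_sampler_order[j] for j in reverse]

print(preds_in_dataset_order)  # [0, 10, 20, 30, 40]
```

The inversion only works because the visiting order is known and covers every item, which is what the train_dl settings described above prevent.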

Thank you. Using get_preds(DatasetType.Fix, ordered=True) solved the problem.
