Seq2Seq notebook 7 fastai nlp

Is there a good explanation of what is going on in the following code? It is taken from https://github.com/fastai/course-nlp/blob/master/7-seq2seq-translation.ipynb:

def seq2seq_collate(samples, pad_idx=1, pad_first=True, backwards=False):
    "Function that collect samples and adds padding. Flips token order if needed"
    samples = to_data(samples)
    max_len_x,max_len_y = max([len(s[0]) for s in samples]),max([len(s[1]) for s in samples])
    res_x = torch.zeros(len(samples), max_len_x).long() + pad_idx
    res_y = torch.zeros(len(samples), max_len_y).long() + pad_idx

    if backwards: pad_first = not pad_first
    for i,s in enumerate(samples):
        if pad_first:
            res_x[i,-len(s[0]):],res_y[i,-len(s[1]):] = LongTensor(s[0]),LongTensor(s[1])
        else:
            res_x[i,:len(s[0]):],res_y[i,:len(s[1]):] = LongTensor(s[0]),LongTensor(s[1])
    if backwards: res_x,res_y = res_x.flip(1),res_y.flip(1)
    return res_x,res_y

and here:

class Seq2SeqDataBunch(TextDataBunch):
    "Create a `TextDataBunch` suitable for training an RNN classifier."
    @classmethod
    def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', bs:int=32, val_bs:int=None, pad_idx=1,
               dl_tfms=None, pad_first=False, device:torch.device=None, no_check:bool=False, backwards:bool=False, **dl_kwargs) -> DataBunch:
        "Function that transform the `datasets` in a `DataBunch` for classification. Passes `**dl_kwargs` on to `DataLoader()`"
        datasets = cls._init_ds(train_ds, valid_ds, test_ds)
        val_bs = ifnone(val_bs, bs)
        collate_fn = partial(seq2seq_collate, pad_idx=pad_idx, pad_first=pad_first, backwards=backwards)
        train_sampler = SortishSampler(datasets[0].x, key=lambda t: len(datasets[0][t][0].data), bs=bs//2)
        train_dl = DataLoader(datasets[0], batch_size=bs, sampler=train_sampler, drop_last=True, **dl_kwargs)
        dataloaders = [train_dl]
        for ds in datasets[1:]:
            lengths = [len(t) for t in ds.x.items]
            sampler = SortSampler(ds.x, key=lengths.__getitem__)
            dataloaders.append(DataLoader(ds, batch_size=val_bs, sampler=sampler, **dl_kwargs))
        return cls(*dataloaders, path=path, device=device, collate_fn=collate_fn, no_check=no_check)

One thing that particularly confuses me is that SortishSampler is being used after the padding has already been done in seq2seq_collate. I thought that if we sort first and do the padding afterwards, the amount of required padding would be minimal. If that is not the purpose of sorting, can someone please explain why we are sorting?

Also, what does partial do here: collate_fn = partial(seq2seq_collate, pad_idx=pad_idx, pad_first=pad_first, backwards=backwards)?

And in seq2seq_collate, why would we want the padding to be at the front?

Any other input that clarifies the above pieces of code would also be helpful. Thank you!

The partial function doesn’t call seq2seq_collate; it creates a new function from seq2seq_collate with the arguments pad_idx, pad_first, and backwards already filled in.
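
A minimal, standalone sketch of how functools.partial behaves (the describe function here is made up just to show the mechanics):

from functools import partial

def describe(samples, pad_idx=1, pad_first=True, backwards=False):
    return f"{len(samples)} samples, pad_idx={pad_idx}, pad_first={pad_first}, backwards={backwards}"

# Freeze the keyword arguments; the result is a new callable that only
# needs the remaining `samples` argument.
collate_fn = partial(describe, pad_idx=0, pad_first=False, backwards=True)

print(collate_fn(["a", "b"]))
# -> 2 samples, pad_idx=0, pad_first=False, backwards=True

This matters because a DataLoader only ever passes a list of samples to its collate_fn, so partial is the usual way to bake the extra keyword arguments into a one-argument callable.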

So the sequence is the following:

  1. Create your separate dataloaders (where SortishSampler is applied to train_dl)
  2. Return a DataBunch with all your DataLoaders

seq2seq_collate is only handed to the DataBunch at step 2, and it actually runs each time a batch is drawn from a DataLoader. So your examples are first sorted by length (with some degree of randomness) by the sampler, and only then padded, one batch at a time. Your intuition is therefore right: sorting is used to avoid unnecessary padding by batching up sequences of similar length.
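
Here is a toy illustration of why sorting by length before batching reduces padding (the lengths and batch size are made up for the example):

# Four sequences with token lengths 2, 9, 3, 10, batched two at a time.
def padding_needed(batches):
    # each sequence is padded up to the longest sequence in its own batch
    return sum(max(batch) - length for batch in batches for length in batch)

unsorted_batches = [[2, 9], [3, 10]]
sorted_batches   = [[2, 3], [9, 10]]

print(padding_needed(unsorted_batches))  # 14 pad tokens
print(padding_needed(sorted_batches))    # 2 pad tokens

SortishSampler does this roughly for the training set (shuffling within length buckets so the batches differ between epochs), SortSampler does it exactly for the validation set, and the collate function then only has to pad each batch up to its own longest sequence rather than the longest sequence in the whole dataset.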

Not sure why exactly padding is applied at the front in this case; my guess is that it wouldn’t make much of a difference to pad after each sequence instead. Note that Seq2SeqDataBunch.create above actually passes pad_first=False by default, so padding ends up at the back unless you override it.
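
For what it’s worth, here is a small sketch, independent of fastai, of what the two padding modes produce for a toy batch (the token ids and pad_idx are made up):

import torch

pad_idx = 1
seqs = [[5, 6, 7], [8, 9]]  # toy token ids

def pad_batch(seqs, pad_first):
    max_len = max(len(s) for s in seqs)
    out = torch.zeros(len(seqs), max_len).long() + pad_idx
    for i, s in enumerate(seqs):
        if pad_first:
            out[i, -len(s):] = torch.tensor(s)  # padding in front, tokens at the end
        else:
            out[i, :len(s)] = torch.tensor(s)   # tokens first, padding at the end
    return out

print(pad_batch(seqs, pad_first=True))
# tensor([[5, 6, 7],
#         [1, 8, 9]])
print(pad_batch(seqs, pad_first=False))
# tensor([[5, 6, 7],
#         [8, 9, 1]])

One commonly cited reason for front padding is that an RNN reading left to right then finishes on real tokens rather than on padding, so its final hidden state is computed from actual content.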