`SortishSampler` pitfalls

Wondering if there’s an obvious explanation for this – I have a simple model:

from torch import nn

model = nn.Sequential(
  nn.Embedding(50000, 128, padding_idx=0),  # index 0 is the padding token
  nn.ReLU(),
  nn.Linear(128, 10000),                    # one logit per output class
)

Think of this like predicting a hashtag from a tweet – your inputs are BOW featurizations of the tweets, and your output classes are the (large) number of hashtags.

If I train w/ the fastai.text.SortishSampler, I get significantly worse results than if I train w/ a random ordering:

# SortishSampler
{"epoch": 0, "p_at_01": 0.17496875, "p_at_05": 0.13944375, "p_at_10": 0.11604374999999999, "elapsed": 12.517449378967285}
{"epoch": 1, "p_at_01": 0.1878125, "p_at_05": 0.14675, "p_at_10": 0.11996250000000001, "elapsed": 25.40839433670044}
{"epoch": 2, "p_at_01": 0.10959375, "p_at_05": 0.070675, "p_at_10": 0.05009374999999999, "elapsed": 38.185012102127075}
{"epoch": 3, "p_at_01": 0.25040625, "p_at_05": 0.19470625000000005, "p_at_10": 0.16205937499999998, "elapsed": 50.31228280067444}
{"epoch": 4, "p_at_01": 0.2653125, "p_at_05": 0.20280000000000004, "p_at_10": 0.168853125, "elapsed": 61.01092886924744}
{"epoch": 5, "p_at_01": 0.29490625, "p_at_05": 0.22359375, "p_at_10": 0.18587500000000004, "elapsed": 72.7346978187561}
{"epoch": 6, "p_at_01": 0.27859375, "p_at_05": 0.21149375, "p_at_10": 0.17360625000000005, "elapsed": 84.20354294776917}
{"epoch": 7, "p_at_01": 0.31578125, "p_at_05": 0.23923125, "p_at_10": 0.196478125, "elapsed": 96.55597829818726}
{"epoch": 8, "p_at_01": 0.3333125, "p_at_05": 0.2518625, "p_at_10": 0.20785312500000003, "elapsed": 108.56037592887878}
{"epoch": 9, "p_at_01": 0.32671875, "p_at_05": 0.24431875, "p_at_10": 0.19895, "elapsed": 120.693279504776}

# random order
{"epoch": 0, "p_at_01": 0.22846875, "p_at_05": 0.18122500000000002, "p_at_10": 0.151646875, "elapsed": 13.475401163101196}
{"epoch": 1, "p_at_01": 0.28275, "p_at_05": 0.2161375, "p_at_10": 0.17963125, "elapsed": 25.572567224502563}
{"epoch": 2, "p_at_01": 0.32896875, "p_at_05": 0.25131875, "p_at_10": 0.20738750000000006, "elapsed": 37.94573616981506}
{"epoch": 3, "p_at_01": 0.34846875, "p_at_05": 0.26486875, "p_at_10": 0.215721875, "elapsed": 50.27561402320862}
{"epoch": 4, "p_at_01": 0.3824375, "p_at_05": 0.28755625, "p_at_10": 0.23284375, "elapsed": 63.51650404930115}
{"epoch": 5, "p_at_01": 0.39121875, "p_at_05": 0.29014375000000003, "p_at_10": 0.2353, "elapsed": 77.00169587135315}
{"epoch": 6, "p_at_01": 0.41134375, "p_at_05": 0.30732500000000007, "p_at_10": 0.247771875, "elapsed": 89.96660947799683}
{"epoch": 7, "p_at_01": 0.4099375, "p_at_05": 0.30201874999999995, "p_at_10": 0.24224062500000001, "elapsed": 102.85762786865234}
{"epoch": 8, "p_at_01": 0.4213125, "p_at_05": 0.31561875, "p_at_10": 0.25602500000000006, "elapsed": 115.78191900253296}
{"epoch": 9, "p_at_01": 0.43459375, "p_at_05": 0.32563125, "p_at_10": 0.263815625, "elapsed": 128.54980063438416}

where p_at_k is the precision of the top k predictions. (Top-1 accuracy isn’t super useful because the number of classes is so large). The convergence of the random sampler is

a) much faster (p_at_01 passes 0.3 at epoch 2 vs epoch 7 for SortishSampler)
b) much smoother (almost monotonically increasing, with one small dip at epoch 7, whereas SortishSampler bounces around, e.g. dropping to 0.11 at epoch 2)
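For anyone unsure what p_at_k means here, this is a minimal sketch of precision at k for a single example. It assumes each example can have several true labels and that the reported number is this value averaged over the validation set; the post doesn't spell out the exact definition, and the function name and toy data are made up for illustration.

```python
def precision_at_k(scores, true_labels, k):
    """Fraction of the k highest-scoring classes that are true labels."""
    # Rank class indices by score, highest first, and keep the top k.
    top_k = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)[:k]
    hits = sum(1 for c in top_k if c in true_labels)
    return hits / k

# Toy example (hypothetical data): 5 classes, classes 3 and 4 are correct.
scores = [0.1, 0.7, 0.05, 0.9, 0.3]
true_labels = {3, 4}
p1 = precision_at_k(scores, true_labels, 1)
p3 = precision_at_k(scores, true_labels, 3)
```

With 10000 classes you would want a vectorized top-k rather than a full sort, but the quantity being measured is the same.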

Anyone have any thoughts on why this would be? SortishSampler is “less random”, but I’m surprised it actually makes this big of a difference.

EDIT: Also… I wonder how much of a difference this makes for input into RNNs (e.g. in ULMFiT fine-tuning)

This is really interesting, and has big implications if correct. Have you tried the same on the language modelling example from the course? Also curious what the runtime differences are between the two models?

I’ve been looking into it as well because I’m not convinced that it’s random enough in terms of the batches it creates. The way it groups examples allows only a limited number of permutations. I wonder if there’s a way to probabilistically apply sortish some of the time and randomize the rest to get some speedup.

Very interesting findings though.
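One way the "sortish some of the time" idea above could look, as a rough sketch only. sortish_order below is a simplified approximation written for illustration (shuffle, sort large chunks by length, shuffle the resulting batches); it is not the fastai implementation. The names mixed_order, p_sortish and chunk_mult are all made up here.

```python
import random

def sortish_order(lengths, bs, chunk_mult=50, rng=random):
    """Simplified sortish ordering (an approximation, not fastai's code)."""
    idxs = list(range(len(lengths)))
    rng.shuffle(idxs)
    chunk = bs * chunk_mult
    batches = []
    for start in range(0, len(idxs), chunk):
        # Sort each large chunk by length so neighbouring items are similar.
        part = sorted(idxs[start:start + chunk],
                      key=lambda i: lengths[i], reverse=True)
        batches += [part[j:j + bs] for j in range(0, len(part), bs)]
    # Shuffle full batches; keep a short final batch last so that
    # re-batching the flat order by bs reproduces the same batches.
    full = [b for b in batches if len(b) == bs]
    short = [b for b in batches if len(b) < bs]
    rng.shuffle(full)
    return [i for b in full + short for i in b]

def mixed_order(lengths, bs, p_sortish=0.5, rng=random):
    """Per epoch: sortish order with probability p_sortish, else fully random."""
    if rng.random() < p_sortish:
        return sortish_order(lengths, bs, rng=rng)
    idxs = list(range(len(lengths)))
    rng.shuffle(idxs)
    return idxs

# Hypothetical usage: draw a fresh index order at the start of each epoch.
rng = random.Random(0)
lengths = [rng.randint(1, 100) for _ in range(1000)]
order = mixed_order(lengths, bs=32, p_sortish=0.5, rng=rng)
```

Whether mixing per epoch (as here) or per batch works better is an open question; this only shows that the switch itself is cheap to try.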

What was the idea behind implementing a separate SortishSampler anyway?

@aayushy SortishSampler speeds up training by putting inputs of similar length in the same batch, which avoids having lots of zero padding in the input.
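To get a feel for how much padding is at stake, here is a small sketch comparing random batches with fully length-sorted batches on synthetic lengths. The length distribution and batch size are arbitrary assumptions, and a full sort is the extreme case; SortishSampler sits somewhere between the two.

```python
import random

def padding_fraction(lengths, order, bs):
    """Share of tokens that are padding when each batch is padded to its longest item."""
    padded = real = 0
    for start in range(0, len(order), bs):
        batch = [lengths[i] for i in order[start:start + bs]]
        padded += max(batch) * len(batch)  # every item padded to the batch max
        real += sum(batch)
    return 1 - real / padded

rng = random.Random(0)
# Synthetic sequence lengths (assumption: uniform between 5 and 200 tokens).
lengths = [rng.randint(5, 200) for _ in range(10_000)]

random_order = list(range(len(lengths)))
rng.shuffle(random_order)
sorted_order = sorted(range(len(lengths)), key=lambda i: lengths[i])

frac_random = padding_fraction(lengths, random_order, 64)
frac_sorted = padding_fraction(lengths, sorted_order, 64)
```

The sorted order wastes far less compute on padding, which is the speedup; the cost is that batch composition becomes much less random.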
