How to interpret IMDB sentiment predictions?

Hey, I didn’t change the model structure - I just modified this line when defining TextData.from_splits from

trn_iter,val_iter = torchtext.data.BucketIterator.splits(splits, batch_size=bs)

to

trn_iter, val_iter, test_iter = torchtext.data.BucketIterator.splits(splits, batch_sizes=(bs, bs, 1))

You can see from the source code here that torchtext.data.BucketIterator.splits actually takes in a batch_sizes tuple argument that defines batch sizes for different datasets.

Yes, if no shuffling is involved, torchtext sorts the data by the token length of the text field (because we define sort_key as def sort_key(ex): return len(ex.text), as in cell 115 of this notebook), and in the case of ties it preserves the original order. So I would sort my data by those two keys too. Torchtext does this so it can group texts of similar length together in a batch to feed into the model.
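
If it helps, here's a tiny standalone sketch (not the actual torchtext/fastai code, just an illustration) of why a stable sort by token length gives exactly that ordering:

from collections import namedtuple

# Toy illustration only -- `Example` stands in for torchtext's Example objects.
Example = namedtuple('Example', ['text'])

examples = [Example(text=['a', 'b', 'c']),
            Example(text=['just', 'one']),
            Example(text=['x', 'y'])]

def sort_key(ex):
    return len(ex.text)

# sorted() is stable, so examples of equal length keep their original
# relative order -- the same tie-breaking behaviour described above.
ordered = sorted(examples, key=sort_key)  # the two length-2 examples stay in their original order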

Oh sweet! Yeah I had also modified the source to return an additional iterator for test, otherwise it would break.

We should create a PR for these changes, you think?

Yup. I was just looking at this.

I was trying to set it to 1 but it keeps throwing an exception.

Absolutely!

Okay, do you want to do it? It’s 4am here for me …

At least one other person is also getting this error.

I don’t think I’ll be able to get to it today though (have a bunch of errands to run). Some time next week, realistically.

Okay, I’ll try to work on it later too. Or maybe someone else will get to it.

I still didn’t get the ordering to work. I trust that your method works as you say. Do you see anything wrong with my code?

full gist



nlp.py

    trn_iter,val_iter,test_iter = torchtext.data.BucketIterator.splits(splits, batch_sizes=(bs, bs, 1))

in my notebook:

md2 = TextData.from_splits(PATH, splits, 1)  # setting all bs to 1, just for making predictions
m3 = md2.get_model(opt_fn, 1500, bptt, emb_sz=em_sz, n_hid=nh, n_layers=nl,
                   dropouti=dropouti, dropout=dropout, wdrop=wdrop,
                   dropoute=dropoute, dropouth=dropouth)
m3.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)
m3.load_encoder(f'h2_adam1_enc')   # load the fine-tuned language-model encoder weights
m3.load_cycle('h1', 3)             # load the classifier weights saved at cycle 3
val_preds, y = m3.predict_with_targs()   # predictions and targets

res = np.argmax(val_preds, axis=1)       # predicted class per example

Did you try to sort val_df by text length as discussed above? I might have missed it in your notebook but this is what I did (my dataset is called txt_test and my text column is text):

# Sort by len
# Because that's how torchtext would sort it,
# Hence need to do the same in order to match its results
txt_test['text_toks'] = txt_test['text'].apply(spacy_tok)
txt_test['text_len'] = txt_test['text_toks'].str.len()
txt_test['index'] = txt_test.index  # Note this is assuming that the data is already sorted by index; if that's not true, use `.iloc` instead

txt_test.sort_values(by=['text_len', 'index'], inplace=True)
txt_test.reset_index(drop=True, inplace=True)
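
Once txt_test is in that order, something like this (just a sketch; the 'pred' column name is mine, and it assumes the ordering really does line up) lets you attach the predictions from predict_with_targs back onto the rows:

import numpy as np

# Sketch: assuming the sorted txt_test rows now correspond 1:1 with the
# model's outputs, attach the predicted class to each row.
# `val_preds` is the array returned by predict_with_targs() above;
# the 'pred' column name is just a placeholder.
txt_test['pred'] = np.argmax(val_preds, axis=1)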

Btw, here is my full notebook, which I hope is right.


FYI as I’m sure you’ve noticed, I haven’t used a test set with this class before - sorry about the shuffling thing! I’m working on tomorrow’s class at the moment so won’t be able to debug right away, but if you want to do so, try looking at how torchtext is handling this. I’m not sure if the issue is in torchtext, or just how I’m calling it.

Both torchtext and fastai are pretty simple code to read - hopefully it’ll be reasonably clear what’s going on. Let me know if I can help clarify anything!

@KevinB was able to get a submission into the Happiness competition, so I think he has something that works, and I don’t think it’s as complicated as we are making it. Perhaps he can enlighten us when he has a chance.

I think the issue is in torchtext.

If you look at the source code for BucketIterator here, you can see that it always sorts, even if you set sort=False; that simply shifts the sorting to happen within each batch.

So all I did was use what Jeremy did in his lesson 4 notebook to predict the label for each sentence. I set my batch size to 1 and pulled the text from the CSV file directly. Then I looped through those one at a time, took the top prediction for each, converted it from the index to the actual word, and wrote the results to a file. Is there any specific code or question you are wondering about?

Can you share the code for how you loop through examples to do prediction one by one?

m = m3.model            # underlying PyTorch model inside the fastai learner
m[0].bs = 1             # set the RNN encoder's batch size to 1
for i in range(len(tst)):
    ss = tst["Description"][i]   # actual text review
    s = [spacy_tok(ss)]          # tokenize with spacy
    t = TEXT.numericalize(s)     # map tokens to ids via the TEXT field's vocab

    m.eval()
    m.reset()
    res, *_ = m(t)
    prediction = PH_LABEL.vocab.itos[to_np(torch.topk(res[-1], 1)[1])[0]]  # top class index -> label string
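
If it's useful, here's a hedged sketch of how you could collect those per-example predictions and write them out for a submission (the output column names and the preds.csv filename are placeholders, not anything from the competition):

import pandas as pd

# Sketch building on the loop above: gather each prediction into a list,
# then write a simple CSV. `tst`, `m`, `TEXT`, `PH_LABEL`, `spacy_tok`
# and `to_np` are the same objects used above; the column names and
# filename are placeholders.
preds = []
for ss in tst["Description"]:
    t = TEXT.numericalize([spacy_tok(ss)])
    m.eval()
    m.reset()
    res, *_ = m(t)
    preds.append(PH_LABEL.vocab.itos[to_np(torch.topk(res[-1], 1)[1])[0]])

pd.DataFrame({'Description': tst["Description"].values,
              'prediction': preds}).to_csv('preds.csv', index=False)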

Great thanks! I got it working.

I had previously missed the first line where you have to do m = m3.model before calling m(t). I guess that is easy to miss since m3 is the result of a call to get_model().

Anyway, good on you for getting this to work. It must mean you understand the code quite well.

This is much easier than trying to modify the TextData class! I also like that training/validation is completely decoupled from testing this way.


I think I’m getting close to solving this by replacing BucketIterator with Iterator.
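
Roughly speaking, the change looks something like this (a sketch, not my exact notebook code; test_ds stands for the torchtext test Dataset):

import torchtext

# Sketch: a plain Iterator instead of a BucketIterator for the test set.
# With train=False, sort=False and shuffle=False it walks the dataset in
# its original order; `test_ds` is a stand-in for the torchtext test Dataset.
test_iter = torchtext.data.Iterator(test_ds, batch_size=1,
                                    train=False, sort=False, shuffle=False)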

I’m getting predictions from my test dataset, BUT there are fewer predictions than there are examples in the test dataset for some reason (8 fewer, to be specific).

Any ideas about why I don’t have a prediction for every example?

BucketIterator is important for sorting by length. Otherwise it can be very inefficient if you have stuff of different lengths (which is v common).

You’re missing some predictions because it always uses an integer multiple of the batch size. So I guess bs=1 is important for that reason too!
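
To make that concrete with made-up numbers (I don't know your exact test size or batch size):

# Made-up numbers, just to illustrate why a remainder gets dropped:
n_test = 1000                        # hypothetical test-set size
bs = 64                              # hypothetical batch size
n_full_batches = n_test // bs        # 15 full batches
n_predicted = n_full_batches * bs    # 960 predictions
n_dropped = n_test % bs              # 40 examples left without a prediction
# With bs=1 the remainder is always 0, so every example gets a prediction.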

Hmmm … ok.

So basically, we need BucketIterator for performance reasons … and BucketIterator is going to do some kind of sorting (either over the entire dataset or within each minibatch). So for validation and test datasets, we’re going to have to set bs=1 in order for the results to come back in order.

Does that sound about right?