[imdb.ipynb] Probabilities greater than 1 when making predictions

After running imdb.ipynb, I save the classifier

learn.save(final_clas_file)

and then reload the classifier and test on some examples.

bptt,em_sz,nh,nl = 70,400,1150,3
vs = len(itos)
opt_fn = partial(optim.Adam, betas=(0.8, 0.99))
bs = 64

min_lbl = trn_labels.min()
trn_labels -= min_lbl
val_labels -= min_lbl
c=int(trn_labels.max())+1

val_ds = TextDataset(val_clas, val_labels)
val_dl = DataLoader(val_ds, bs, transpose=True, num_workers=1, pad_idx=1, sampler=None)
md = ModelData(PATH, val_dl, None)

dps = np.array([0.4, 0.5, 0.05, 0.3, 0.4]) * 0.5  # dropout values, halved

m = get_rnn_classifier(bptt, 20*70, c, vs, emb_sz=em_sz, n_hid=nh, n_layers=nl, pad_token=1,
                       layers=[em_sz*3, 50, c], drops=[dps[4], 0.1],
                       dropouti=dps[0], wdrop=dps[1], dropoute=dps[2], dropouth=dps[3])

opt_fn = partial(optim.Adam, betas=(0.7, 0.99))
learn = RNN_Learner(md, TextModel(to_gpu(m)), opt_fn=opt_fn)
learn.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)
learn.clip=25.
learn.metrics = [accuracy]

learn.load(final_clas_file)
learn.data.test_dl = val_dl

log_preds = learn.predict(is_test=True)
print(log_preds)

[[ -4.18671 3.82241]
[ -3.06313 2.74281]
[ -4.70322 4.35709]
[ 5.09323 -5.18428]
[ 15.44124 -16.15837]
[ -2.26712 1.99009]
[ 1.23772 -1.40407]
[ -3.85265 3.52952]
[ -3.32406 3.0189 ]
[ -7.05887 6.56664]]

preds = np.argmax(log_preds, axis=1) # from log probabilities to 0 or 1
probs = np.exp(log_preds[:,1])

print(preds)
[1 1 1 0 0 1 0 1 1 1]

print(probs)
[ 45.71418 15.53057 78.02974 0.0056 0. 7.31617 0.2456 34.10776 20.4687 710.9803 ]

I don’t know why I get probability values greater than 1. Does anyone know the reason?

The last outputs of the classifier aren’t log_preds: the log_softmax is done inside the loss function, so what you get are the raw activations before the softmax layer.
To get probabilities, you should apply the softmax to your predictions:

probs = np.exp(log_preds) / np.exp(log_preds).sum(1)[:,None]
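
As an aside, if the activations ever get much larger than the ones shown here, np.exp can overflow; a common trick is to subtract each row’s maximum before exponentiating, which leaves the softmax unchanged. A minimal sketch of that stable variant, assuming log_preds is the (n, 2) array of raw activations printed above:

import numpy as np

def softmax(logits):
    # Shift each row by its max before exponentiating so np.exp
    # never sees a large positive argument (avoids overflow);
    # softmax is invariant to adding a constant per row.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

probs = softmax(log_preds)
print(probs.sum(axis=1))  # each row now sums to 1
print(probs[:, 1])        # P(class 1), always within [0, 1]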

I got it. Thanks a lot, @sgugger!