Why use `np.argmax` to get predictions?

Question about why np.argmax is used.

I think this part is a little ambiguous.
Jeremy said this gives the predictions, but didn't mention why we need np.argmax here. Or maybe I misunderstood something? (Around 25:55, Jeremy just said we can retrieve them through np.argmax.)

preds = np.argmax(log_preds, axis=1)  # from log probabilities to 0 or 1
probs = np.exp(log_preds[:,1])        # pr(dog)

Using np.exp to transform log_preds is relatively straightforward.


log_preds[:10] returns the following numpy ndarray (tensor), which holds the first 10 predictions in log scale:

#        0         1
array([[ -0.00002, -11.07446],   # 0
       [ -0.00138,  -6.58385],   # 1
       [ -0.00083,  -7.09025],   # 2
       [ -0.00029,  -8.13645],   # 3
       [ -0.00035,  -7.9663 ],   # 4
       [ -0.00029,  -8.15125],   # 5
       [ -0.00002, -10.82139],   # 6
       [ -0.00003, -10.33846],   # 7
       [ -5.00323,  -0.73731],   # 8
       [ -0.0001 ,  -9.21326]],  # 9
   dtype=float32)

np.argmax(log_preds, axis=1)

The axis argument tells numpy which dimension to reduce over.
axis=1 means the maximum is searched along the columns within each row, so you get one index per row of log_preds.

That means np.argmax(log_preds, axis=1) returns [0, 0, 0, 0, 0, 0, 0, 0, 1, 0] for the 10 rows shown above. The index of the maximum value in the first row is 0, and the index of the maximum value in the ninth row is 1. This is what the code comment 'from log probabilities to 0 or 1' means.
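To check this yourself, here is a minimal sketch that rebuilds the 10 rows printed above as a plain numpy array (the values are copied from the output, not taken from a live model) and applies np.argmax along axis=1:

```python
import numpy as np

# First 10 rows of log_preds, copied from the printed output above.
log_preds = np.array([
    [-0.00002, -11.07446],
    [-0.00138,  -6.58385],
    [-0.00083,  -7.09025],
    [-0.00029,  -8.13645],
    [-0.00035,  -7.9663 ],
    [-0.00029,  -8.15125],
    [-0.00002, -10.82139],
    [-0.00003, -10.33846],
    [-5.00323,  -0.73731],
    [-0.0001 ,  -9.21326],
], dtype=np.float32)

# For each row, return the column index holding the largest value.
preds = np.argmax(log_preds, axis=1)
print(preds)  # [0 0 0 0 0 0 0 0 1 0]
```

Only the ninth row (index 8) has its larger value in column 1, so it is the only row labelled 1.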

# from here we know that 'cats' is label 0 and 'dogs' is label 1.
data.classes

As for what the 0 or 1 means: each row in the log_preds tensor holds the predictions (log scale) for one sample (one image). For the first row, the predicted label is 0, so it's a cat. For the ninth row, the predicted label is 1, so it's a dog.
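The label-to-name lookup can be sketched like this. Note that the classes list here is a hand-written stand-in for data.classes (assumed to be ['cats', 'dogs'], as the comment above says), and preds is the argmax result for the 10 rows shown earlier:

```python
import numpy as np

# Assumption: stand-in for data.classes, with 'cats' as label 0 and 'dogs' as label 1.
classes = ['cats', 'dogs']

# Predicted labels for the first 10 images (the argmax result from above).
preds = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0])

# Use each predicted label as an index into the class names.
labels = [classes[i] for i in preds]
print(labels[0], labels[8])  # cats dogs
```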

Using np.exp to transform log_preds is relatively straightforward.

Yeah, log_preds[:,1] returns the second column of the numpy ndarray, which holds the log predictions that the image is a dog (shown below). np.exp then turns those log predictions into probabilities.

array([-11.07446,  -6.58385,  -7.09025,  -8.13645,  -7.9663 ,  -8.15125,
   -10.82139, -10.33846,  -0.73731,  -9.21326])
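As a rough sketch of that last step, the snippet below applies np.exp to the column printed above (values copied by hand, so the variable name log_dog is just for illustration):

```python
import numpy as np

# Second column of log_preds for the first 10 images, copied from above.
log_dog = np.array([-11.07446,  -6.58385,  -7.09025,  -8.13645,  -7.9663,
                     -8.15125, -10.82139, -10.33846,  -0.73731,  -9.21326])

# exp undoes the log, giving pr(dog) for each image.
probs = np.exp(log_dog)
print(np.round(probs, 5))
```

The ninth image comes out at roughly 0.48, while every other image is below 0.01, which matches the argmax labels above.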