How to predict a single digit

Referring to the code in the 04_mnist_basics notebook:

I’m trying to use the trained model to predict whether a single image shows a 3 or a 7.

Right after the line:
learn.fit(40, 0.1)

I typed:
learn.predict(train_x[0])

I get the following error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-382-cd610cd9f076> in <module>()
----> 1 learn.predict(train_x[0])

... 24 frames omitted ...
/usr/local/lib/python3.6/dist-packages/PIL/Image.py in fromarray(obj, mode)
   2692         raise ValueError("Too many dimensions: %d > %d." % (ndim, ndmax))
   2693 
-> 2694     size = shape[1], shape[0]
   2695     if strides is not None:
   2696         if hasattr(obj, "tobytes"):

IndexError: tuple index out of range
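A likely reading of the last frame, sketched in plain Python: PIL computes `size = shape[1], shape[0]`, so it needs at least two dimensions. In this notebook `train_x` is built with `.view(-1, 28*28)`, so `train_x[0]` is a flat vector of shape `(784,)`, and indexing `shape[1]` on a one-element tuple raises exactly this error. The helper `pil_size_from_shape` is hypothetical and only mirrors that one line; it is not PIL's code.

```python
def pil_size_from_shape(shape):
    # Mirrors the line PIL's Image.fromarray fails on: size = shape[1], shape[0]
    return shape[1], shape[0]

square_shape = (28, 28)  # shape of an original 2-D MNIST image
flat_shape = (784,)      # shape of train_x[0] after .view(-1, 28*28)

print(pil_size_from_shape(square_shape))  # (28, 28)
try:
    pil_size_from_shape(flat_shape)
except IndexError as e:
    print("IndexError:", e)  # IndexError: tuple index out of range
```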

Can someone explain why this is not working and what I should type to get the prediction?
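An untested workaround, under assumptions: skip `learn.predict` and call the underlying PyTorch module, `learn.model`, on the flat vector. That should return one raw score (a logit). Turning the score into a digit would follow the same rule as the notebook's `batch_accuracy`: apply a sigmoid and compare with 0.5, where a target of 1 means 3 and 0 means 7. The plain-Python sketch below covers only that last step; `label_from_logit` is a hypothetical helper, not part of fastai.

```python
import math

def sigmoid(x):
    # Squash a raw model output (logit) into the range 0..1, as mnist_loss does
    return 1 / (1 + math.exp(-x))

def label_from_logit(logit, threshold=0.5):
    # In 04_mnist_basics the targets are 1 for a 3 and 0 for a 7,
    # so a probability above the threshold is read as "3"
    return 3 if sigmoid(logit) > threshold else 7

print(label_from_logit(2.0))   # 3
print(label_from_logit(-1.5))  # 7
```

With the trained model this would be something like `label_from_logit(learn.model(train_x[0]).item())`. If the model sits on the GPU, the input may first need moving with `train_x[0].to(next(learn.model.parameters()).device)`.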