How does the function fit one-hot encode the labels?

(I guess I can answer my own question by staring at the code long enough, but I was wondering if anybody already has the answer.)

Note. I am following Introduction to Machine Learning for Coders and using fastai v0.7.

In lesson4-mnist_sgd.ipynb, the lesson builds the simplest neural network, net, which has one layer and 10 outputs. The model data, md, contains images and their labels; the labels are a one-dimensional array of integers between 0 and 9. Then we call the function fit, which takes net and md as arguments.

Question. Where (in the code) does the function fit make a connection between the integers in the labels and the 10 values at the output of the net?

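It’s all in the loss function in PyTorch. The one used here (NLL loss, I believe) doesn’t require one-hot encoded targets, only class indices: for each sample it takes the net’s output at the target index and negates it, which gives the same result as multiplying by a one-hot vector and summing.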
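To make that concrete, here is a minimal NumPy sketch of what a negative log-likelihood loss does with integer targets. It assumes PyTorch’s default mean reduction with no class weights. The helper names (log_softmax, nll_loss, nll_loss_one_hot) and the random logits are hypothetical illustrations, not fastai or PyTorch API. In PyTorch itself the corresponding call would be torch.nn.functional.nll_loss(log_probs, targets) with targets as an integer (long) tensor.

```python
import numpy as np

def log_softmax(logits):
    # Subtract the row-wise max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def nll_loss(log_probs, targets):
    # targets holds class indices (0..9), not one-hot vectors:
    # each index selects one column from its row of log-probabilities.
    picked = log_probs[np.arange(len(targets)), targets]
    # Mean reduction, assumed to match PyTorch's default.
    return -picked.mean()

def nll_loss_one_hot(log_probs, targets, num_classes=10):
    # The same loss written with explicit one-hot targets, for comparison only.
    one_hot = np.eye(num_classes)[targets]
    return -(one_hot * log_probs).sum(axis=1).mean()

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 10))   # stand-in for the net's 10 outputs per image
targets = np.array([3, 0, 9, 7])    # integer labels, as in the MNIST data
log_probs = log_softmax(logits)

print(nll_loss(log_probs, targets))
print(np.isclose(nll_loss(log_probs, targets), nll_loss_one_hot(log_probs, targets)))  # True
```

The indexing version never builds the one-hot matrix, which is why the labels can stay as plain integers all the way into the loss.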
