CrossEntropy grad calculation


I’m trying to implement the cross-entropy module from scratch (without batches, for simplicity), but I’m a bit lost with the gradient calculation (you can find some notes about this here, for instance).

With the MNIST data, using the same layers as in the Part 2 Lessons 8 and 9 [ Linear([784, 50]), ReLU() and Linear([50, 10]) ], I calculate the loss with CrossEntropy and call backward(). But when backward calls CrossEntropy::bwd, input.grad must be a [50000, 10] tensor, and I don’t know how to compute the derivative given:

  • output: just a number
  • input: a [50000, 10] tensor
  • target: a [50000] tensor

Here is my CrossEntropy module source code:

def logsumexp(x):
    m = x.max(-1)[0]
    return m + (x - m[:,None]).exp().sum(-1).log()

def log_softmax(x): 
    return x - x.logsumexp(-1, keepdim=True)

def nll(input, target): 
    return -input[range(target.shape[0]), target].mean()

class CrossEntropy(Module):
    def forward(self, input, target):
        return nll(log_softmax(input), target)

    def bwd(self, output, input, target):
        aux = torch.argmax(input, dim=1) - target
        input.grad = ???

In aux we have the difference between the predicted class indices and the target classes as a [50000] tensor.
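As a sanity check (not part of the original post), the forward pass built from log_softmax and nll can be compared against PyTorch’s built-in cross entropy on random logits:

```python
import torch
import torch.nn.functional as F

def log_softmax(x):
    return x - x.logsumexp(-1, keepdim=True)

def nll(input, target):
    return -input[range(target.shape[0]), target].mean()

torch.manual_seed(0)
logits = torch.randn(64, 10)
target = torch.randint(0, 10, (64,))

mine = nll(log_softmax(logits), target)
ref = F.cross_entropy(logits, target)  # also mean-reduced by default
assert torch.allclose(mine, ref, atol=1e-5)
```

If this passes, the remaining problem really is only the backward pass.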

Here is the fit method (without batches) from the class Model():

def fit(self, x, target, epochs, lr=0.1):
    for epoch in range(epochs):
        out = self(x, target)
        loss = self.loss(out, target)
        self.backward()  # manual backward pass: loss bwd, then each layer's bwd
        with torch.no_grad():
            for layer in self.layers:
                if hasattr(layer, '_parameters'):
                    for p in layer._parameters.values():
                        p -= p.grad * lr

Could someone help me with the CrossEntropy::bwd implementation?

Finally, I got it.

I put the solution here in case anyone else has the same problem.

class CrossEntropy(Module):
    def forward(self, input, target):
        # Stable softmax:
        # Subtracting each row's maximum shifts every element into (-inf, 0],
        # so large values exponentiate to at most 1 and very negative values
        # saturate to zero instead of overflowing to infinity and producing nan.
        # input [N, C] -> row maxima [N] -> broadcast back to [N, C]
        aux = input - input.max(1).values.unsqueeze(1).expand_as(input)

        exp_input = aux.exp()

        # exp_input [N, C] -> row sums [N] -> broadcast back to [N, C]
        denom = exp_input.sum(1).unsqueeze(1).expand_as(input)

        self.softMax = exp_input / denom

        # num_classes keeps the one-hot shape [N, C] even if some class
        # never appears in target
        self.targetOneHot = torch.nn.functional.one_hot(target, num_classes=input.size(1))

        loss = -(self.targetOneHot * self.softMax.log()).sum() / input.size(0)

        return loss

    def bwd(self, output, input, target):
        input.grad = (self.softMax - self.targetOneHot) / input.size(0)
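For completeness, the closed form in bwd follows from differentiating the mean negative log-likelihood of the softmax (a standard derivation, not spelled out in the original post):

$$L = -\frac{1}{N}\sum_{n}\log p_{n,t_n},
\qquad p_{n,c} = \frac{e^{x_{n,c}}}{\sum_{k} e^{x_{n,k}}}$$

$$\frac{\partial L}{\partial x_{n,c}} = \frac{1}{N}\left(p_{n,c} - \mathbb{1}[c = t_n]\right)$$

which is exactly (self.softMax - self.targetOneHot) / input.size(0).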

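One way to verify the gradient formula (a standalone check, not from the original post) is to compare it against autograd’s gradient of F.cross_entropy:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(32, 10, requires_grad=True)
target = torch.randint(0, 10, (32,))

# autograd reference gradient
F.cross_entropy(logits, target).backward()

# manual gradient: (softmax - one_hot) / N
with torch.no_grad():
    manual = (F.softmax(logits, dim=1)
              - F.one_hot(target, num_classes=10)) / logits.size(0)

assert torch.allclose(logits.grad, manual, atol=1e-6)
```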
In any case, the torch built-in cross_entropy function trains faster (better loss with fewer epochs)…

Any clues?
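One possible clue (an assumption, not verified against the training runs above): taking softMax.log() after the division is less numerically stable than working in log space the whole way, which is what torch’s cross_entropy does via log-sum-exp. A small sketch of where the softmax-then-log form breaks down:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[0.0, 200.0]])
target = torch.tensor([0])

# softmax-then-log, as in the forward above:
shifted = logits - logits.max(1, keepdim=True).values
p = shifted.exp() / shifted.exp().sum(1, keepdim=True)
manual = -(F.one_hot(target, num_classes=2) * p.log()).sum() / logits.size(0)
# p[0, 0] underflows to 0 in float32, so the loss blows up to inf

# the log-sum-exp form stays finite:
stable = F.cross_entropy(logits, target)
assert torch.isinf(manual).item() and torch.isfinite(stable).item()
```

That wouldn’t change the gradients much in a healthy run, but it can distort the loss values being compared.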