If anyone just wants code:
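```python
import torch
from torch import Tensor

# fastai v1's accuracy annotates the return type as Rank0Tensor; plain Tensor is used here
def accuracy_1(input: Tensor, targs: Tensor) -> Tensor:
    "Compute accuracy with `targs` when `input` is bs * n_classes."
    targs = targs.view(-1).long()             # flatten targets to 1-D and cast them to integer class ids
    n = targs.shape[0]                        # batch size
    input = input.argmax(dim=-1).view(n, -1)  # predicted class per sample, shape (bs, 1)
    targs = targs.view(n, -1)                 # reshape targets to (bs, 1) to match
    return (input == targs).float().mean()    # fraction of correct predictions
```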
Then pass `metrics=accuracy_1` instead of `metrics=accuracy`.
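As a quick sanity check outside fastai, here is a minimal sketch that runs the function on a made-up batch (the logits and targets are hypothetical toy values, not real data). It shows that targets stored as a float column of shape (bs, 1) and targets stored as a flat integer vector give the same result, because `accuracy_1` flattens and casts them to `long` first.

```python
import torch
from torch import Tensor

def accuracy_1(input: Tensor, targs: Tensor) -> Tensor:
    "Compute accuracy with `targs` when `input` is bs * n_classes."
    targs = targs.view(-1).long()             # flatten targets to 1-D and cast to integer class ids
    n = targs.shape[0]
    input = input.argmax(dim=-1).view(n, -1)  # predicted class per sample, shape (bs, 1)
    targs = targs.view(n, -1)                 # targets reshaped to (bs, 1) to match
    return (input == targs).float().mean()

# Hypothetical toy batch: 4 samples, 2 classes. Predicted classes are [0, 1, 0, 1].
logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [3.0, 0.0], [0.1, 0.9]])
# Same labels [0, 1, 1, 1] in two layouts: a float column and a flat integer vector.
targets_col = torch.tensor([[0.0], [1.0], [1.0], [1.0]])
targets_flat = torch.tensor([0, 1, 1, 1])

acc_col = accuracy_1(logits, targets_col)
acc_flat = accuracy_1(logits, targets_flat)
print(acc_col.item(), acc_flat.item())  # 0.75 0.75
```

Three of the four predictions match, so both calls report 0.75.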