Error (pytorch?): Expected object of scalar type Long but got scalar type Float for argument #2 'other'

If anyone just wants code:

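```python
from torch import Tensor
from fastai.torch_core import Rank0Tensor  # fastai v1 alias for a one-element tensor

def accuracy_1(input:Tensor, targs:Tensor)->Rank0Tensor:
    "Compute accuracy with `targs` when `input` is bs * n_classes."
    # Targets may arrive as Float; cast to Long so they match argmax's int64 output
    targs = targs.view(-1).long()
    n = targs.shape[0]
    input = input.argmax(dim=-1).view(n,-1)
    targs = targs.view(n,-1)
    return (input==targs).float().mean()
```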

Then pass `metrics=accuracy_1` instead of `metrics=accuracy` when you create the learner.
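If you want to sanity-check the fix outside fastai, here is a minimal plain-PyTorch sketch (no fastai type aliases, so it runs with just `torch`). It shows where the mismatch comes from: `argmax` always returns int64 (Long) indices, while the targets here are stored as float32, and older PyTorch releases refused to compare the two with `==`. The toy logits and labels are made up purely for illustration.

```python
import torch
from torch import Tensor

def accuracy_1(input: Tensor, targs: Tensor) -> Tensor:
    """Accuracy when `input` is (bs, n_classes) and `targs` holds class indices, possibly as floats."""
    targs = targs.view(-1).long()              # float labels -> int64, same dtype as argmax
    n = targs.shape[0]
    preds = input.argmax(dim=-1).view(n, -1)   # predicted class index per sample
    return (preds == targs.view(n, -1)).float().mean()

# Toy batch: 4 samples, 3 classes; targets stored as float32 (the case that raised the error)
logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1],
                       [0.1, 0.2, 3.0],
                       [1.0, 0.9, 0.0]])
targs = torch.tensor([0.0, 1.0, 2.0, 1.0])

print(logits.argmax(dim=-1).dtype)     # torch.int64
print(accuracy_1(logits, targs).item())  # 0.75 (last sample predicts class 0, label is 1)
```

Because the cast happens inside the metric, it gives the same result whether your labels come in as Float or Long.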
