Regressing eyes to floats: RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'other'


(See the Jupyter notebook linked above.)

I'm trying to train a model to predict the pupil-limbus ratio (PLR), which means regressing each eye image to a single float.

What do you think is causing this error? I wish I understood tensors and the Long type better.

```python
from fastai.vision import *
from fastai.metrics import error_rate

path = Path('Database')
fnames = get_image_files(path)

def get_float_labels(y):
    # The PLR value is the second '_'-separated field of the file name
    return float(y.parts[-1].split('_')[1])

np.random.seed(42)
bs = 40
tfms = get_transforms(do_flip=False)

data = (ImageItemList(fnames)
        .random_split_by_pct(0.3)
        .label_from_func(get_float_labels)
        .transform(tfms, size=224)
        .databunch(bs=bs))  # pass the batch size here instead of setting data.bs afterwards

learn = create_cnn(data, models.resnet50, metrics=[mean_squared_error])

learn.fit_one_cycle(4)
```

```
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'other'
```
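The labelling function reads the target straight from the file name. A quick check outside fastai, using a made-up file name `eye_0.42_007.jpg` (the real naming scheme in `Database` is not shown in the thread), confirms that it returns a plain Python `float`, which is what a regression target needs. It also shows the assumption baked into the split: the value must sit between the first and second underscore, so a name like `eye_0.42.jpg` makes `float()` fail.

```python
from pathlib import Path

def get_float_labels(y):
    # Second '_'-separated field of the file name, converted to float
    return float(y.parts[-1].split('_')[1])

# Hypothetical file name; the thread does not show the actual naming scheme
print(get_float_labels(Path('Database') / 'eye_0.42_007.jpg'))  # 0.42

# Without a second underscore the field is '0.42.jpg', which float() rejects
try:
    get_float_labels(Path('Database') / 'eye_0.42.jpg')
except ValueError as e:
    print('ValueError:', e)
```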

Hi,

I think this error comes from the metric. `error_rate` calls `accuracy`, which is defined as:

```python
def accuracy(input:Tensor, targs:Tensor)->Rank0Tensor:
    "Compute accuracy with `targs` when `input` is bs * n_classes."
    n = targs.shape[0]
    input = input.argmax(dim=-1).view(n,-1)
    targs = targs.view(n,-1)
    return (input==targs).float().mean()
```
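A small PyTorch sketch with made-up numbers (the tensors are hypothetical, not taken from the notebook) shows why `accuracy` does not fit a regression target. `argmax` returns a Long (int64) tensor while the PLR targets are Float, which is the mismatch the error message complains about in the PyTorch version used here; more recent PyTorch releases apply type promotion to `==` and may not raise at all. And with a single output column, `argmax` is always 0, so the comparison would be meaningless even if the types matched.

```python
import torch

# Hypothetical regression batch: 4 images, one float output each
preds = torch.tensor([[0.41], [0.55], [0.38], [0.62]])  # model output, shape (bs, 1)
targs = torch.tensor([0.40, 0.50, 0.35, 0.60])          # float targets, shape (bs,)

# Same steps as accuracy(): argmax over the last dimension, then reshape
n = targs.shape[0]
pred_idx = preds.argmax(dim=-1).view(n, -1)

print(pred_idx.dtype)               # torch.int64 (Long)
print(targs.dtype)                  # torch.float32 (Float)
print(pred_idx.flatten().tolist())  # [0, 0, 0, 0]: only one column, so argmax is always 0
```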

Try creating your own accuracy function:

```python
def my_accuracy(input:Tensor, targs:Tensor)->Rank0Tensor:
    "Compute accuracy with `targs` when `input` is bs * n_classes."
    n = targs.shape[0]
    input = input.argmax(dim=-1).view(n,-1)
    targs = targs.view(n,-1)
    # The error is raised here: `input` (from argmax) is Long, `targs` is Float.
    # Adjust the conversion for your case.
    return (input==targs).float().mean()

def my_error_rate(input:Tensor, targs:Tensor)->Rank0Tensor:
    "1 - `accuracy`"
    return 1 - my_accuracy(input, targs)

learn = create_cnn(data, models.resnet50, metrics=my_error_rate)
```
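One way to "correct the conversion" for a float target is to drop `argmax` entirely and compare the raw predictions with the targets within a tolerance. The sketch below is hypothetical: `within_tol_accuracy`, `within_tol_error` and the `tol=0.05` default are illustrative names and values, not fastai functions, and a suitable tolerance for PLR would depend on your data. Because both tensors stay Float, there is no Long/Float comparison.

```python
import torch
from torch import Tensor

def within_tol_accuracy(preds: Tensor, targs: Tensor, tol: float = 0.05) -> Tensor:
    "Hypothetical metric: fraction of predictions within `tol` of the float target."
    n = targs.shape[0]
    # No argmax: flatten (bs, 1) predictions to (bs,) and keep them as floats
    preds = preds.view(n).float()
    targs = targs.view(n).float()
    return ((preds - targs).abs() <= tol).float().mean()

def within_tol_error(preds: Tensor, targs: Tensor, tol: float = 0.05) -> Tensor:
    "Hypothetical metric: 1 - `within_tol_accuracy`."
    return 1 - within_tol_accuracy(preds, targs, tol)

preds = torch.tensor([[0.41], [0.70], [0.38]])  # made-up model outputs
targs = torch.tensor([0.40, 0.50, 0.35])        # made-up PLR targets
print(within_tol_accuracy(preds, targs).item())  # 2 of 3 predictions are within 0.05
```

Such a function could be passed the same way as above (`metrics=within_tol_error`), since fastai calls each metric with the batch's predictions and targets. For a plain regression, though, a built-in error metric such as `mean_squared_error` is the simpler choice.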


Thanks! I meant to run this line instead of the other (wrong) `learn`:

```python
learn = create_cnn(data, models.resnet50, metrics=[mean_squared_error])
```

Now it runs. Check out the notebook on nbviewer.
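`mean_squared_error` works here because it never takes an `argmax`: it compares float predictions with float targets directly. The sketch below re-implements roughly what it computes (flatten both tensors, then average the squared differences) in plain PyTorch; `mse_metric` is a hypothetical name, not fastai's source, and the numbers are made up. It is cross-checked against `torch.nn.functional.mse_loss`.

```python
import torch
import torch.nn.functional as F

def mse_metric(pred, targ):
    "Hypothetical re-implementation of a mean-squared-error metric."
    # Flatten (bs, 1) predictions and (bs,) targets to the same 1-D shape
    pred, targ = pred.contiguous().view(-1), targ.contiguous().view(-1)
    assert len(pred) == len(targ), "pred and targ must have the same number of elements"
    # Float minus Float: no Long tensor is involved anywhere
    return ((pred - targ) ** 2).mean()

preds = torch.tensor([[0.5], [1.0], [2.0]])  # made-up outputs, shape (bs, 1)
targs = torch.tensor([0.0, 1.0, 1.0])        # made-up targets, shape (bs,)

print(mse_metric(preds, targs).item())             # (0.25 + 0 + 1) / 3
print(F.mse_loss(preds.view(-1), targs).item())    # same value
```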