Why do I get the error in the subject from the metric I added? It is thrown during validation iterations.
def pos_multi(logit, truth, threshold=0.5):
    """Return the fraction of positive pixels whose predicted class is correct.

    A pixel is "positive" when its ground-truth label is one of 1, 2, 3 or 4
    (label 0 is treated as negative/background and ignored). The predicted
    class of a pixel is the argmax over the class dimension of ``logit``.

    Args:
        logit: Float tensor of shape (batch, num_class, H, W) with raw class
            scores. Need not be contiguous.
        truth: Integer-valued label tensor holding exactly one label per pixel,
            i.e. batch * H * W elements (e.g. shape (batch, H, W) or
            (batch, 1, H, W)). Need not be contiguous.
        threshold: Unused; kept only so existing callers keep working.

    Returns:
        0-dim float tensor: (# correctly classified positive pixels) /
        (# positive pixels). Returns 0.0 when the batch contains no positive
        pixels, instead of the NaN that 0/0 would produce.

    Raises:
        ValueError: if ``truth`` does not contain one label per pixel of
            ``logit`` (for example, a one-hot encoded mask).
    """
    batch_size, num_class, H, W = logit.shape
    with torch.no_grad():
        # reshape (not view): view() raises a RuntimeError when the input is
        # non-contiguous (e.g. after transpose/permute in a TTA pipeline).
        logit = logit.reshape(batch_size, num_class, -1)
        truth = truth.reshape(batch_size, -1).long()
        if truth.shape != (batch_size, H * W):
            raise ValueError(
                f"truth must hold one label per pixel: expected "
                f"{batch_size * H * W} elements, got {truth.numel()}"
            )

        # softmax is monotonic, so argmax over raw logits gives the same class.
        predicted = logit.argmax(dim=1)

        positive = (truth >= 1) & (truth <= 4)
        actual_pos = positive.sum().float()
        total_correct = (predicted[positive] == truth[positive]).sum().float()

        # clamp avoids 0/0 -> NaN on batches without any positive pixel;
        # total_correct is 0 in that case, so the result is 0.0.
        return total_correct / actual_pos.clamp(min=1.0)