Thank you so much, Kornel — I greatly appreciate the advice and assistance with this. Is this what you mean?
def your_custom_loss(out, label):
raw_logits, concat_logits, part_logits, _, top_n_prob = out
creterion = torch.nn.CrossEntropyLoss()
part_loss = list_loss(part_logits.view(4 * 6, -1), label.unsqueeze(1).repeat(1, 6).view(-1)).view(4, 6)
raw_loss = creterion(raw_logits, label)
concat_loss = creterion(concat_logits, label)
rank_loss = ranking_loss(top_n_prob, part_loss)
partcls_loss = creterion(part_logits.view(4 * 6, -1),
label.unsqueeze(1).repeat(1, 6).view(-1))
total_loss = rank_loss + raw_loss + concat_loss + partcls_loss
total_loss = torch.FloatTensor(total_loss)
total_loss.type() == torch.FloatTensor
total_loss.size() == torch.Size([1])
return total_loss.squeeze(0)
Edit:
It seems there are two things going on in the model’s training:
raw_optimizer.zero_grad()
part_optimizer.zero_grad()
concat_optimizer.zero_grad()
partcls_optimizer.zero_grad()
raw_logits, concat_logits, part_logits, _, top_n_prob = net(img)
part_loss = model.list_loss(part_logits.view(batch_size * PROPOSAL_NUM, -1),
label.unsqueeze(1).repeat(1, PROPOSAL_NUM).view(-1)).view(batch_size, PROPOSAL_NUM)
raw_loss = creterion(raw_logits, label)
concat_loss = creterion(concat_logits, label)
rank_loss = model.ranking_loss(top_n_prob, part_loss)
partcls_loss = creterion(part_logits.view(batch_size * PROPOSAL_NUM, -1),
label.unsqueeze(1).repeat(1, PROPOSAL_NUM).view(-1))
total_loss = raw_loss + rank_loss + concat_loss + partcls_loss
total_loss.backward()
raw_optimizer.step()
part_optimizer.step()
concat_optimizer.step()
partcls_optimizer.step()
First there is this. Then:
for i, data in enumerate(trainloader):
with torch.no_grad():
img, label = data[0].cuda(), data[1].cuda()
batch_size = img.size(0)
_, concat_logits, _, _, _ = net(img)
# calculate loss
concat_loss = creterion(concat_logits, label)
# calculate accuracy
_, concat_predict = torch.max(concat_logits, 1)
total += batch_size
train_correct += torch.sum(concat_predict.data == label.data)
train_loss += concat_loss.item() * batch_size
progress_bar(i, len(trainloader), 'eval train set')
train_acc = float(train_correct) / total
train_loss = train_loss / total
The first is the partial unsupervised learning this model has. The second is the actual accuracy.
Edit 2: the first snippet is what the loss_fn should be; the second is the metric.