Learn.get_preds() argmax gives different result from decoded

Running a `unet_learner` with fastai v2.5.3 to perform segmentation, I noticed that when calling `learn.get_preds(with_decoded=True)`, the decoded results do not match the argmax of the predicted probabilities.
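For context, the call in question looks roughly like this (a sketch, not the exact notebook code; it assumes `learn` is the trained `unet_learner` from the notebook):

```python
# Sketch only: `learn` is assumed to be a trained unet_learner for binary segmentation
preds, targs, decoded = learn.get_preds(with_decoded=True)

print(preds.shape)    # (n, n_classes, h, w) - per-pixel class probabilities
print(targs.shape)    # (n, h, w)            - ground-truth masks
print(decoded.shape)  # (n, h, w)            - decoded predictions

# The comparison I expected to hold (near 100% agreement):
same = (preds.argmax(dim=1) == decoded).float().mean()
print(f"{same:.2%} of pixels agree")
```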

To demonstrate this, I made a copy of @muellerzr's Binary Segmentation notebook. You can see the output here.

As seen in the notebook, `torch.argmax` and the decoded values disagree on about 6% of pixels. Additionally, the `decodes` output looks much more like the target mask than the argmax result does (see below).

The loss function is just `FlattenedLoss of CrossEntropyLoss()`, and its `decodes` method is simply `def decodes(self, x): return x.argmax(dim=self.axis)`. In other words, I don't see why `decodes` should behave any differently from just taking the argmax of the predictions.
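To illustrate, here is a small self-contained check (a sketch; as far as I can tell, `CrossEntropyLossFlat(axis=1)` is what `unet_learner` uses by default for segmentation):

```python
import torch
from fastai.losses import CrossEntropyLossFlat

loss_func = CrossEntropyLossFlat(axis=1)   # axis=1 is the class dimension for segmentation

probs = torch.randn(2, 3, 4, 4)            # (batch, classes, h, w) dummy "predictions"
# decodes simply argmaxes over the class axis, so these two should be identical
assert torch.equal(loss_func.decodes(probs), probs.argmax(dim=1))
```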

I suspect I’m missing something obvious, but figured I’d raise the question here.

Target mask: [image]

Decodes mask: [image]

Argmax mask: [image]

Given the `Learner.get_preds` documentation, the order of items in the returned tuple was not what I expected. When requesting both `with_decoded=True` and `with_input=True`, the tuple returned by `get_preds` is (in order):

  • [0]: the input (the raw image)
  • [1]: the prediction probabilities
  • [2]: the target (the ground-truth mask used for training)
  • [3]: the decoded output (the argmax of the prediction probabilities from index 1)

Therefore, because of this index mix-up, the “decodes mask” from my initial question was actually the ground-truth mask, not the decoded predictions.
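With both flags set, unpacking all four items explicitly avoids the confusion (a sketch, assuming the same `learn` as above):

```python
# Unpack all four items explicitly rather than indexing into the tuple
inputs, preds, targets, decoded = learn.get_preds(with_input=True, with_decoded=True)

# decoded really is just the argmax of the prediction probabilities...
assert (decoded == preds.argmax(dim=1)).all()

# ...whereas index [2] (the targets) is the ground-truth mask I was mistakenly
# treating as the "decoded" output
```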