I’m slightly unsure how to read the probabilities output by a semantic segmentation model.
>>> mask, decoded, probs = learn.predict(x)
>>> probs.shape
torch.Size([4, 360, 480])
My first guess was these are 3 channels of image and an alpha channel, but that doesn’t make sense when the prediction mask is only one channel.
My use case: I’d like to set any pixel with low probability to the background class (i.e., change its label to 0).
Thanks for your help!
Each channel is a class, so I presume you have 4 classes in your problem?
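To make the channel-per-class layout concrete, here is a minimal sketch that uses a random tensor in place of the real probs. The shape matches the one posted above, but the values are made up, and it assumes the scores come from a softmax over the class dimension, which is worth checking on your own output.

```python
import torch

torch.manual_seed(0)

# Stand-in for the `probs` returned by learn.predict: (classes, height, width).
# Random logits are pushed through softmax so each pixel's scores sum to 1.
probs = torch.softmax(torch.randn(4, 360, 480), dim=0)
print(probs.shape)  # torch.Size([4, 360, 480])

# Summing over the class dimension gives one total per pixel, each close to 1.
per_pixel_total = probs.sum(dim=0)

# Taking the highest-scoring class per pixel collapses the class dimension,
# which is why the prediction mask has only one channel.
mask = probs.argmax(dim=0)
print(mask.shape)  # torch.Size([360, 480])
```

So the 4 in the shape is the number of classes, not RGB plus alpha.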
Ah ok. Yes I do. So if I were to set an arbitrary probability threshold (say > .5), I’d essentially need to make a new mask?
Something like this?
mask, decoded, probs = learn.predict(x)
new_mask = torch.zeros_like(mask)
for i in range(probs.shape[0]):  # loop over the class channels
    new_mask[probs[i] > .5] = i
Or you could just clamp the values inside the mask itself. First I’d check what the values inside one of those masks are (whether it’s actually argmaxed or not; I don’t believe it is). Then I’d do something like using np.where to modify the raw values themselves inside of probs (I can provide a code snippet if you don’t know how to do that or can’t figure it out, though it will take me a while to get to it).
I noodled around with this for quite some time and couldn’t figure out how to get it to work without writing to a new mask. This is because a pixel that has low or zero probability for one class could have a positive probability for the next class.
So even if I did something like this to modify probs:
probs = torch.where(probs > .5, probs, torch.zeros_like(probs))
I still have 4 channels of classes, so I think I have to write to a new mask to avoid accidentally nuking a prediction for another class in another channel. Hope that makes sense.
If you have the extra brain cells and know a better way, I’d love to hear it. I’m admittedly not the best at vectorizing things.
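One possible vectorized take on this, as a sketch only: take the per-pixel maximum over the class dimension, then keep the winning class only where that maximum clears the threshold. The helper name threshold_mask is made up for illustration, and it assumes probs has shape (classes, height, width) with class 0 as background.

```python
import torch

def threshold_mask(probs: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: class index per pixel, or 0 where confidence is low.

    Assumes `probs` is (classes, height, width) and class 0 is background.
    """
    # Best score and its class index for every pixel, in a single call.
    max_prob, best_class = probs.max(dim=0)
    # Keep the winning class only where its score reaches the threshold.
    return torch.where(max_prob >= threshold, best_class, torch.zeros_like(best_class))

# Tiny made-up example: 3 classes, a 1x2 image.
probs = torch.tensor([[[0.2, 0.3]],
                      [[0.7, 0.3]],
                      [[0.1, 0.4]]])
new_mask = threshold_mask(probs)
print(new_mask)  # tensor([[1, 0]])
```

The first pixel keeps class 1 (score 0.7). The second pixel's best class is 2, but at 0.4 it falls under the threshold and is reset to background. Nothing in probs is overwritten, so no other channel's prediction gets nuked.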
Why not just iterate through them?
for i, msk in enumerate(probs):
    probs[i] = torch.where(msk > 0.5, msk, torch.zeros_like(msk))
Oh, the problem isn’t changing probs. It’s changing the prediction mask. My obstacle is that probs has 4 channels and mask has only 1.
At the end of the day, I’m trying to end up with a mask in which any pixel with prob lower than a threshold is set to 0.
Hope that makes sense. Thanks for your help!
Mask is the result of a torch.argmax() being done with dim=1. So we do this in three steps:
- Apply your where filter
- Do a torch.softmax(dim=1) on our raw probabilities so they sum to 1
- From there do torch.argmax(dim=1) to get our class values.
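The three steps could be sketched roughly like this. The function name filtered_argmax and the sample tensor are made up for illustration. The class dimension is 1 for a batched tensor shaped (batch, classes, H, W), and 0 if your tensor has no batch dimension. The sketch also relies on argmax returning the first index when all classes tie, so fully zeroed pixels fall back to class 0; it is worth confirming that on your setup.

```python
import torch

def filtered_argmax(probs: torch.Tensor, threshold: float = 0.5, dim: int = 1) -> torch.Tensor:
    # Step 1: zero out every score below the threshold.
    kept = torch.where(probs >= threshold, probs, torch.zeros_like(probs))
    # Step 2: renormalise along the class dimension so the scores sum to 1.
    # Softmax preserves ordering, so it does not change which class wins.
    renormed = torch.softmax(kept, dim=dim)
    # Step 3: pick the winning class. Where every score was zeroed, all
    # classes tie and the first index (0, background) is expected back.
    return torch.argmax(renormed, dim=dim)

# Made-up batch holding one 1x2 image with 3 classes: (batch, classes, H, W).
batch = torch.tensor([[[[0.2, 0.3]],
                       [[0.7, 0.3]],
                       [[0.1, 0.4]]]])

batched_mask = filtered_argmax(batch, dim=1)    # class dimension is 1
single_mask = filtered_argmax(batch[0], dim=0)  # no batch dim: class dimension is 0
print(batched_mask.shape, single_mask.shape)
```

Since softmax keeps the ordering of the scores, step 2 mostly matters if you also want normalised scores afterwards; the argmax result should be the same with or without it.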
Ahhhhhhh ok! Nice. That clicked for me. Just had to change the dim to 0 and now I am off to the races!
This was final result:
torch.where(probs >= .5, probs, torch.zeros_like(probs))\