Or you could just clamp the values inside the mask itself. First I’d check what the values inside one of those masks actually are (whether it’s already argmaxed or not; I don’t believe it is). Then I’d do something like using np.where to modify the raw values directly inside probs (I can provide a code snippet if you don’t know how to do that or can’t figure it out, though it will take me a while to put one together).
I noodled around with this for quite some time and couldn’t figure out how to get it to work without writing to a new mask, because a value that’s low-probability or zero for one class could be positive for the next class.
So even if I did something like this to modify probs:
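say, zeroing out the low-confidence values in place with np.where. A minimal sketch of what I mean (the shape and the 0.5 threshold here are placeholders I made up, not values from the actual model):

```python
import numpy as np

# toy stand-in for the model output: per-class probabilities, 4 channels
rng = np.random.default_rng(0)
probs = rng.random((8, 8, 4))

THRESH = 0.5  # placeholder cutoff, not a real tuned value

# zero everything below the cutoff, keep everything else as-is
probs = np.where(probs < THRESH, 0.0, probs)
```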
I still have 4 channels of classes, so I think I have to write to a new mask to avoid accidentally nuking a prediction for another class in another channel. Hope that makes sense.
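For what it’s worth, writing to a fresh array could look roughly like this (again just a sketch with made-up names and threshold; the point is that the original probs stays untouched, so thresholding one channel can’t clobber another channel’s values):

```python
import numpy as np

rng = np.random.default_rng(1)
probs = rng.random((8, 8, 4))  # stand-in for the 4-channel class probabilities
THRESH = 0.5                   # placeholder cutoff

# build a brand-new mask instead of editing probs in place
new_mask = np.zeros_like(probs)
keep = probs >= THRESH          # boolean mask, evaluated per channel
new_mask[keep] = probs[keep]    # copy over only the confident values
```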
If you have the extra brain cells and know a better way, I’d love to hear it. I’m admittedly not the best at vectorizing things.