SegmentationInterpretation issue

Hello, I’m trying to run SegmentationInterpretation on a dynamic U-Net.

However, the results don’t seem reliable; something is not working as it should.

First, when I follow the documentation, I cannot run this line:

interp = SegmentationInterpretation.from_learner(learn)

Instead, I have to pass in the results from get_preds:

results = get_preds(learn.model, data.train_dl)

so the line becomes:

interp = SegmentationInterpretation.from_learner(learn, results)

Then I try to get the top losses. However, when I pass the sizes as a tuple, I get the following error:

top_losses, top_idxs = interp.top_losses(sizes=(380,380))

RuntimeError: shape '[-1, 144400]' is invalid for input of size 665

I found a way to make it work by setting sizes=665, 665 being the number of items in my validation set. Is this normal?

And the result is:

tensor([0.3652])

Shouldn’t this be a set of results and not a single value?
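The shape arithmetic behind both symptoms can be reproduced with plain PyTorch. This is a minimal sketch, assuming (as the error message implies) that top_losses flattens the losses, reshapes them to (-1, prod(sizes)) and averages each row. The sizes 665 and 380x380 come from the question; the random loss values and the 4x4 toy image size are made up for illustration.

```python
import torch

# Assumption: the loss returned one value per image (665 items), not per pixel.
n_items, h, w = 665, 380, 380
per_item_losses = torch.rand(n_items)

# One row of h*w = 144400 pixel losses per image is impossible with 665 values.
try:
    per_item_losses.view(-1, h * w)
except RuntimeError as err:
    print(err)  # shape '[-1, 144400]' is invalid for input of size 665

# With sizes=665 all values land in a single row, so the row mean is one number.
collapsed = per_item_losses.view(-1, n_items).mean(-1)
print(collapsed.shape)  # torch.Size([1])

# With genuine per-pixel losses (toy 4x4 images here), the same reshape
# gives one mean loss per image, which is what a top-losses ranking needs.
toy_h, toy_w = 4, 4
per_pixel_losses = torch.rand(n_items * toy_h * toy_w)
per_image = per_pixel_losses.view(-1, toy_h * toy_w).mean(-1)
print(per_image.shape)  # torch.Size([665])
```

If this assumption holds, a single value such as tensor([0.3652]) is the mean over all items rather than the loss of one image.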

And what about my confusion matrix?

I guess this is only running on a single image. What am I doing wrong, and why isn’t the interpreter taking all of the results into account?

My results look like this:

[tensor([[[[ 1.4468e-01,  8.7335e-01,  6.1114e-01,  ..., -8.2787e-01,
            -2.5200e-01,  1.7118e+00],
           [ 7.6972e-01,  4.5180e-01,  1.2123e+00,  ..., -6.6132e-01,
            -7.5332e-01,  6.7776e-01],
           [ 4.4170e-01,  3.2671e-01,  2.4824e+00,  ...,  2.2353e+00,
             4.7981e-01,  1.4205e+00],
           ...,
           [-1.3661e-01, -1.3256e+00,  1.3750e+00,  ...,  7.5920e-01,
             5.8397e-01,  2.1016e+00],
           [ 2.2558e-01, -2.6954e-01,  2.8010e+00,  ...,  1.3548e+00,
             1.3598e-02,  1.3101e+00],
           [ 1.1904e+00,  8.7612e-01,  2.5442e+00,  ...,  9.3389e-01,
             4.9314e-01,  5.1853e-01]],
 
          [[-6.7372e+00, -9.0618e+00, -9.9961e+00,  ..., -1.2145e+01,
            -1.0189e+01, -7.1810e+00],
           [-1.0522e+01, -1.4402e+01, -1.6274e+01,  ..., -1.7912e+01,
            -1.4213e+01, -9.0940e+00],
           [-1.2539e+01, -1.7398e+01, -2.0873e+01,  ..., -2.3158e+01,
            -1.7605e+01, -9.9877e+00],
           ...,
           [-1.3156e+01, -1.8105e+01, -1.9474e+01,  ..., -1.5579e+01,
            -1.2477e+01, -8.2640e+00],
           [-1.2645e+01, -1.7074e+01, -1.9606e+01,  ..., -1.2430e+01,
            -1.0209e+01, -6.7426e+00],
           [-9.2591e+00, -1.0878e+01, -1.4676e+01,  ..., -8.1643e+00,
            -7.6539e+00, -5.9465e+00]]],
 
 
         [[[-1.0801e-02, -3.6409e-03, -1.0007e-01,  ..., -6.0957e-01,
            -2.0075e-01,  1.9249e+00],
           [ 8.0788e-01,  3.0190e-02,  5.7790e-01,  ..., -4.8901e-01,
            -6.6968e-01,  8.5799e-01],
           [ 6.2487e-01,  1.4730e-01,  1.9167e+00,  ...,  2.1901e+00,
             5.2001e-01,  1.3926e+00],
           ...,
           [-1.4601e-01, -1.1229e+00,  1.2673e+00,  ...,  1.8509e+00,
             1.3093e+00,  1.7931e+00],
           [ 3.5803e-01, -1.8547e-01,  2.7042e+00,  ...,  1.6967e+00,
             7.3590e-01,  1.0654e+00],
           [ 1.1218e+00,  8.6573e-01,  2.6743e+00,  ...,  6.8270e-01,
             4.4729e-01,  2.9603e-01]],
 
          [[-4.3485e+00, -6.1270e+00, -6.7421e+00,  ..., -1.2293e+01,
            -1.0214e+01, -7.1567e+00],
           [-7.4956e+00, -9.7985e+00, -1.1307e+01,  ..., -1.7831e+01,
            -1.4210e+01, -9.2019e+00],
           [-8.9132e+00, -1.1842e+01, -1.4291e+01,  ..., -2.2266e+01,
            -1.7019e+01, -9.8177e+00],
           ...,
           [-1.4164e+01, -1.9279e+01, -1.9562e+01,  ..., -1.1419e+01,
            -8.6349e+00, -4.9107e+00],
           [-1.3463e+01, -1.7723e+01, -1.9401e+01,  ..., -8.9227e+00,
            -7.2754e+00, -4.2319e+00],
           [-9.7827e+00, -1.1154e+01, -1.4544e+01,  ..., -5.8314e+00,
            -5.4231e+00, -3.9991e+00]]],
 
 
         [[[ 1.7133e-02,  6.3035e-01,  5.2401e-01,  ..., -1.0465e+00,
            -4.2758e-01,  1.6402e+00],
           [ 6.2300e-01,  3.2023e-01,  1.0300e+00,  ..., -9.3723e-01,
            -9.6673e-01,  4.8233e-01],
           [ 2.9155e-01,  9.3576e-02,  2.0477e+00,  ...,  1.7876e+00,
             2.7737e-01,  1.2000e+00],
           ...,
           [-2.5892e-01, -1.5989e+00,  7.2832e-01,  ...,  5.2588e-01,
             2.2872e-01,  1.8319e+00],
           [ 1.4451e-01, -7.8453e-01,  2.2239e+00,  ...,  1.0246e+00,
            -2.4499e-01,  1.1318e+00],
           [ 1.2279e+00,  8.3053e-01,  2.3797e+00,  ...,  6.4187e-01,
             2.2986e-01,  2.8025e-01]],
 
          [[-6.1538e+00, -8.2692e+00, -9.1709e+00,  ..., -1.1667e+01,
            -9.9958e+00, -6.8374e+00],
           [-9.5276e+00, -1.2810e+01, -1.4394e+01,  ..., -1.7223e+01,
            -1.3918e+01, -8.7715e+00],
           [-1.0831e+01, -1.4577e+01, -1.7651e+01,  ..., -2.2257e+01,
            -1.6910e+01, -9.4540e+00],
           ...,
           [-1.2165e+01, -1.6311e+01, -1.6999e+01,  ..., -1.4018e+01,
            -1.1190e+01, -7.4252e+00],
           [-1.1839e+01, -1.5454e+01, -1.7523e+01,  ..., -1.0939e+01,
            -8.9282e+00, -6.0599e+00],
           [-8.7769e+00, -1.0224e+01, -1.3525e+01,  ..., -7.0479e+00,
            -6.5639e+00, -5.3908e+00]]],
 
 
         ...,
 
 
         [[[ 2.1967e-02,  6.5942e-01,  4.6379e-01,  ..., -2.7174e-01,
             3.0500e-02,  1.9285e+00],
           [ 6.1972e-01,  1.9643e-01,  8.3483e-01,  ..., -2.1242e-01,
            -4.7280e-01,  8.5668e-01],
           [ 2.2958e-01, -1.1480e-01,  1.8891e+00,  ...,  2.4325e+00,
             6.3936e-01,  1.7032e+00],
           ...,
           [-2.8645e-01, -1.6853e+00,  7.1063e-01,  ...,  1.1244e+00,
             1.0763e+00,  2.2782e+00],
           [ 1.0913e-01, -7.6801e-01,  2.1725e+00,  ...,  1.7374e+00,
             2.4123e-01,  1.3267e+00],
           [ 1.1455e+00,  8.1268e-01,  2.3208e+00,  ...,  1.0990e+00,
             5.4318e-01,  3.9346e-01]],
 
          [[-5.9434e+00, -7.9282e+00, -8.4962e+00,  ..., -1.3081e+01,
            -1.0858e+01, -7.8052e+00],
           [-9.1772e+00, -1.2077e+01, -1.3385e+01,  ..., -1.9339e+01,
            -1.5332e+01, -1.0180e+01],
           [-1.0439e+01, -1.3945e+01, -1.6877e+01,  ..., -2.4809e+01,
            -1.8991e+01, -1.1534e+01],
           ...,
           [-1.1615e+01, -1.5605e+01, -1.6675e+01,  ..., -1.7378e+01,
            -1.3956e+01, -9.0995e+00],
           [-1.1260e+01, -1.4682e+01, -1.7012e+01,  ..., -1.3805e+01,
            -1.1143e+01, -7.2914e+00],
           [-8.3269e+00, -9.6404e+00, -1.3110e+01,  ..., -9.1186e+00,
            -8.1994e+00, -6.2206e+00]]],
 
 
         [[[ 1.6317e-02,  5.7076e-01,  5.6015e-01,  ..., -4.0745e-01,
            -7.5681e-02,  1.8005e+00],
           [ 6.5834e-01,  2.8126e-01,  1.0710e+00,  ..., -3.1594e-01,
            -5.3047e-01,  7.8265e-01],
           [ 2.8048e-01, -1.3096e-02,  2.0208e+00,  ...,  2.4525e+00,
             7.2436e-01,  1.6541e+00],
           ...,
           [ 9.4154e-02, -8.5906e-01,  1.9035e+00,  ...,  1.1086e+00,
             4.3990e-01,  1.8993e+00],
           [ 4.0548e-01, -6.7279e-02,  2.9828e+00,  ...,  1.4265e+00,
            -1.4418e-01,  9.9297e-01],
           [ 1.2256e+00,  8.7994e-01,  2.7080e+00,  ...,  8.4020e-01,
             2.1204e-01,  2.6449e-01]],
 
          [[-6.1371e+00, -8.1564e+00, -9.1983e+00,  ..., -1.2900e+01,
            -1.0808e+01, -7.8114e+00],
           [-9.3913e+00, -1.2761e+01, -1.4697e+01,  ..., -1.9034e+01,
            -1.5068e+01, -9.9818e+00],
           [-1.0695e+01, -1.4642e+01, -1.8090e+01,  ..., -2.4197e+01,
            -1.8492e+01, -1.1069e+01],
           ...,
           [-1.4588e+01, -2.0322e+01, -2.1545e+01,  ..., -1.7695e+01,
            -1.3373e+01, -8.2065e+00],
           [-1.3971e+01, -1.8835e+01, -2.1342e+01,  ..., -1.4134e+01,
            -1.1070e+01, -6.7222e+00],
           [-1.0279e+01, -1.1996e+01, -1.5651e+01,  ..., -9.2463e+00,
            -8.1208e+00, -5.8158e+00]]],
 
 
         [[[ 3.7075e-02,  6.8498e-01,  7.4305e-01,  ..., -8.5988e-02,
             1.2866e-01,  2.0257e+00],
           [ 7.6833e-01,  3.1818e-01,  1.2825e+00,  ...,  9.8875e-03,
            -3.2269e-01,  9.8736e-01],
           [ 4.5415e-01,  1.5833e-01,  2.2943e+00,  ...,  2.8398e+00,
             9.8516e-01,  1.9031e+00],
           ...,
           [ 5.5511e-02, -1.1339e+00,  1.4711e+00,  ...,  1.2982e+00,
             1.0123e+00,  2.3091e+00],
           [ 3.4406e-01, -2.8910e-01,  2.7385e+00,  ...,  1.7025e+00,
             1.8388e-01,  1.3248e+00],
           [ 1.3326e+00,  8.9612e-01,  2.5960e+00,  ...,  8.9384e-01,
             4.6546e-01,  4.6095e-01]],
 
          [[-6.4339e+00, -8.7184e+00, -9.6035e+00,  ..., -1.4029e+01,
            -1.1557e+01, -8.2306e+00],
           [-9.9748e+00, -1.3727e+01, -1.5678e+01,  ..., -2.0983e+01,
            -1.6389e+01, -1.0778e+01],
           [-1.1640e+01, -1.6150e+01, -1.9516e+01,  ..., -2.6595e+01,
            -2.0447e+01, -1.2364e+01],
           ...,
           [-1.3760e+01, -1.8914e+01, -2.0344e+01,  ..., -1.7563e+01,
            -1.3633e+01, -8.6760e+00],
           [-1.3337e+01, -1.7752e+01, -2.0119e+01,  ..., -1.4218e+01,
            -1.1307e+01, -7.1252e+00],
           [-9.7367e+00, -1.1333e+01, -1.4941e+01,  ..., -9.3603e+00,
            -8.4278e+00, -6.1551e+00]]]]), tensor([[[[0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           ...,
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0]]],
 
 
         [[[0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           ...,
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0]]],
 
 
         [[[0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           ...,
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0]]],
 
 
         ...,
 
 
         [[[0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           ...,
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0]]],
 
 
         [[[0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           ...,
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0]]],
 
 
         [[[0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           ...,
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0]]]])]
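The printout suggests that results is [preds, targets], with preds holding two score maps per image, shape (N, 2, H, W), and targets holding integer labels, shape (N, 1, H, W). Under that assumption, a pixel-level confusion matrix can be computed directly from the pair without going through the interpreter. pixel_confusion_matrix and row_normalize are hypothetical helpers, not fastai API; they pool pixels across all images, which may differ from how fastai's own _generate_confusion aggregates.

```python
import torch

def pixel_confusion_matrix(preds, targets, n_classes):
    """Count pixels for every (true class, predicted class) pair.

    preds:   (N, C, H, W) raw scores (first tensor from get_preds)
    targets: (N, 1, H, W) integer class indices (second tensor)
    """
    pred_cls = preds.argmax(dim=1).flatten()         # predicted class per pixel
    true_cls = targets.squeeze(1).flatten().long()   # true class per pixel
    # Encode each (true, pred) pair as a single index, then count occurrences.
    idx = true_cls * n_classes + pred_cls
    counts = torch.bincount(idx, minlength=n_classes * n_classes)
    return counts.view(n_classes, n_classes)         # rows: true, cols: predicted

def row_normalize(cm):
    """Fraction of each true class's pixels assigned to each predicted class."""
    totals = cm.sum(dim=1, keepdim=True).clamp(min=1)  # avoid division by zero
    return cm.float() / totals

# Toy example: 1 image, 2 classes, 2x2 pixels.
preds = torch.tensor([[[[2.0, -1.0], [0.5, -3.0]],    # class-0 scores
                       [[1.0,  1.0], [0.0,  2.0]]]])  # class-1 scores
targets = torch.tensor([[[[0, 1], [1, 1]]]])

cm = pixel_confusion_matrix(preds, targets, n_classes=2)
print(cm.tolist())                 # [[1, 0], [1, 2]]
print(row_normalize(cm)[0].tolist())  # [1.0, 0.0]
```

On the real results this would be called as pixel_confusion_matrix(results[0], results[1], n_classes=2).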

Any suggestion is highly appreciated, since I’ve been stuck on this issue for days.

Kind regards

I ran into the same error with your approach, but then I changed my code to this:
```python
import matplotlib.pyplot as plt
from fastai.vision.interpret import SegmentationInterpretation

def seg_confution(learn):
    # results = get_preds(learn.model, data.train_dl)
    interp = SegmentationInterpretation.from_learner(learn)
    # top_losses, top_idxs = interp.top_losses(sizes=(64, 64))
    mean_cm, single_img_cm = interp._generate_confusion()
    df1 = interp._plot_intersect_cm(mean_cm, "Mean of Ratio of Intersection given True Label")
    # Save the plotted matrix and its values
    plt.savefig('globel_confution1.png')
    df1.to_csv('data_globel1.csv')
```

And my loss function:
```python
def dice_loss(inputX, target, reduction='mean'):
    # import pdb; pdb.set_trace()  # debugging breakpoint, disabled
    smooth = 1.
    n = inputX.size(0)
    # Sigmoid of the channel-1 (foreground) logits, keeping a channel dim
    inputX = inputX[:, 1, None].sigmoid()
    iflat = inputX.contiguous().view(n, -1).float()
    tflat = target.view(n, -1).float()
    intersection = (iflat * tflat).sum(-1)
    # One soft-dice loss per image
    dice = 1 - ((2. * intersection + smooth) / ((iflat + tflat).sum(-1) + smooth))
    if reduction == 'mean':
        return dice.mean()
    elif reduction == 'sum':
        return dice.sum()
    else:
        return dice
```

It works fine.
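Whether top_losses works seems to depend on how many values the loss returns with reduction='none': the reshape error in the question divides the losses into rows of 380*380, which suggests one value per pixel is expected. A minimal check, assuming logits of shape (N, C, H, W) and targets of shape (N, 1, H, W), is to call the loss and count the outputs. loss_granularity, pixel_ce and image_ce are hypothetical helpers for illustration. dice_loss above returns one value per image when reduction is neither 'mean' nor 'sum', so it would be reported as 'per-image'.

```python
import torch
import torch.nn.functional as F

def loss_granularity(loss_fn, logits, target):
    """Classify loss_fn(..., reduction='none') output as per-pixel or per-image."""
    n, _, h, w = logits.shape
    numel = loss_fn(logits, target, reduction='none').numel()
    if numel == n * h * w:
        return 'per-pixel'
    if numel == n:
        return 'per-image'
    return f'other ({numel} values)'

def pixel_ce(logits, target, reduction='mean'):
    # Per-pixel cross-entropy; F.cross_entropy wants targets without a channel dim.
    return F.cross_entropy(logits, target.squeeze(1), reduction=reduction)

def image_ce(logits, target, reduction='mean'):
    # Per-image loss: pixel losses averaged within each image.
    per_image = pixel_ce(logits, target, reduction='none').mean(dim=(1, 2))
    return per_image if reduction == 'none' else per_image.mean()

logits = torch.randn(4, 2, 8, 8)              # toy batch: 4 images, 2 classes
target = torch.randint(0, 2, (4, 1, 8, 8))
print(loss_granularity(pixel_ce, logits, target))  # per-pixel
print(loss_granularity(image_ce, logits, target))  # per-image
```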