How to get prediction on new test image data using fastai2 model?

I have the below code to train a model and save it as model_7.pth using fastai.

from fastai import *
from wwf.utils import state_versions
state_versions(['fastai', 'fastcore', 'wwf'])

from import *
path = Path('data')

def n_codes(fnames, is_partial=True):
  "Gather the codes from a list of `fnames`"
  vals = set()
  if is_partial:
    fnames = fnames[:10]
  for fname in fnames:
    msk = np.array(PILMask.create(fname))
    for val in np.unique(msk):
      if val not in vals:
  vals = list(vals)
  p2c = dict()
  for i,val in enumerate(vals):
    p2c[i] = vals[i]
  return p2c

p2c = n_codes(lbl_names)

def get_msk(fn, pix2class):
  "Grab a mask from a `filename` and adjust the pixels based on `pix2class`"
  #fn = path/'GT_png'/f'{fn.stem}_mask.png'
  fn = path/f'trainval/train/tile/label/{}'
  msk = np.array(PILMask.create(fn))
  mx = np.max(msk)
  for i, val in enumerate(p2c):
    msk[msk==p2c[i]] = val
  return PILMask.create(msk)

codes = ['Background', 'car']

get_y = lambda o: get_msk(o, p2c)

binary = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),

dls = binary.dataloaders(path/'trainval/train/tile/label', bs=4, num_workers=0)#num_workers=0 to make it work in windows

dls.show_batch(cmap='Blues', vmin=0, vmax=1)

learn = unet_learner(dls, resnet34)
learn.load(r'model_7'), cbs=SaveModelCallback(every_epoch=True))

Now I have a test image folder called TEST_FOLDER.

I tried the below code.

TEST_FOLDER = r"test\tile\image"
files = get_image_files(TEST_FOLDER)[:1]
test_dl = learn.dls.test_dl(files)
preds = learn.get_preds(dl=test_dl)

but the output preds is of shape

Out[43]: torch.Size([1, 2, 224, 224])

I use this code to see the expected output prediction

p = preds[0][0]

but i get a image with 4 subplots and not sure which is the predicted mask.
Also the bottom left image doesnt correspond to the input image.
enter image description here


  1. How do I make prediction of the segment mask for each test image?
  2. How to get the pixelwise probability for each test image?