I have the below code to train a model and save it as model_7.pth using fastai.
from fastai import *
from wwf.utils import state_versions
state_versions(['fastai', 'fastcore', 'wwf'])
from fastai.vision.all import *
path = Path('data')
def n_codes(fnames, is_partial=True):
"Gather the codes from a list of `fnames`"
vals = set()
if is_partial:
random.shuffle(fnames)
fnames = fnames[:10]
for fname in fnames:
msk = np.array(PILMask.create(fname))
for val in np.unique(msk):
if val not in vals:
vals.add(val)
vals = list(vals)
p2c = dict()
for i,val in enumerate(vals):
p2c[i] = vals[i]
return p2c
p2c = n_codes(lbl_names)
def get_msk(fn, pix2class):
"Grab a mask from a `filename` and adjust the pixels based on `pix2class`"
#fn = path/'GT_png'/f'{fn.stem}_mask.png'
fn = path/f'trainval/train/tile/label/{fn.name}'
msk = np.array(PILMask.create(fn))
mx = np.max(msk)
for i, val in enumerate(p2c):
msk[msk==p2c[i]] = val
return PILMask.create(msk)
codes = ['Background', 'car']
get_y = lambda o: get_msk(o, p2c)
binary = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
get_items=get_image_files,
splitter=RandomSplitter(),
get_y=get_y,
item_tfms=Resize(224),
batch_tfms=[Normalize.from_stats(*imagenet_stats)])
dls = binary.dataloaders(path/'trainval/train/tile/label', bs=4, num_workers=0)#num_workers=0 to make it work in windows
dls.show_batch(cmap='Blues', vmin=0, vmax=1)
learn = unet_learner(dls, resnet34)
learn.load(r'model_7')
learn.fit(n_epoch=8, cbs=SaveModelCallback(every_epoch=True))
Now I have a test image folder called TEST_FOLDER.
I tried the below code.
TEST_FOLDER = r"test\tile\image"
files = get_image_files(TEST_FOLDER)[:1]
test_dl = learn.dls.test_dl(files)
preds = learn.get_preds(dl=test_dl)
but the output preds is of shape
preds[0].shape
Out[43]: torch.Size([1, 2, 224, 224])
I use this code to see the expected output prediction
p = preds[0][0]
plt.plot(p[0])
plt.savefig('myfig_p0')
but i get a image with 4 subplots and not sure which is the predicted mask.
Also the bottom left image doesnt correspond to the input image.
Question
- How do I make prediction of the segment mask for each test image?
- How to get the pixelwise probability for each test image?