Unet/ResNet doing multiclass image segmentation - I only have 2 classes right now, may add more later. I’ve trained the model elsewhere and saved it to disk. Here I’m loading it, and trying to do predictions on a different dataset.
class CustomTransform(DisplayedTransform):
    def __init__(self, resize_dim):
        self.resize_dim = resize_dim

    def encodes(self, x):
        return x.resize(size=self.resize_dim)
src_datablock = DataBlock(
    blocks=(ImageBlock, MaskBlock),
    getters=[ColReader("image"), ColReader("image")],
    item_tfms=[CustomTransform(resize_dim=(input_image_size, input_image_size))],
)
src_dataloader = src_datablock.dataloaders(src_df, bs=block_size)
src_learner = unet_learner(
    src_dataloader,
    cnn_arch,
    n_out=256,
    path=code_path,
    model_dir=model_dir,
    loss_func=CrossEntropyLossFlat(axis=1),
)
src_learner.load(model_name)
I could do predictions one by one, which is slow but otherwise works well:

for i in tqdm(range(src_df.shape[0])):
    output = src_learner.predict(src_df.loc[i, "image"], with_input=True)
    predicted_mask = PILImage.create(output[1])
I could do batch predictions with fastai:

pred_dl = src_learner.dls.test_dl(src_df["image"])
output = src_learner.get_preds(dl=pred_dl, with_input=True)
for i in tqdm(range(output[1].shape[0])):
    predicted_mask = np.array(output[1][i].argmax(dim=0)).astype("uint8")
    display(Image.fromarray(predicted_mask))
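For reference, the argmax step above just picks the highest-scoring class per pixel along the class dimension. A minimal numpy illustration with made-up logits (2 classes, 2x2 pixels; the values are hypothetical, not from my model):

```python
import numpy as np

# Hypothetical per-class logits for one image: shape (n_classes, H, W).
logits = np.array([
    [[0.9, 0.1],
     [0.2, 0.8]],   # class 0 scores
    [[0.1, 0.9],
     [0.8, 0.2]],   # class 1 scores
])

# argmax over the class axis yields a (H, W) class mask.
mask = logits.argmax(axis=0).astype("uint8")
print(mask)  # [[0 1]
             #  [1 0]]
```

This is the same operation as `output[1][i].argmax(dim=0)` in the fastai loop, just in numpy instead of torch.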
Or I could try to use plain PyTorch for predictions:

pred_dl = src_learner.dls.test_dl(src_df["image"], bs=block_size)
for data in tqdm(pred_dl):
    output = src_learner.model(data[0])
    # Iterate over the actual batch size - the last batch
    # may be smaller than block_size.
    for i in range(output.shape[0]):
        predicted_mask = np.array(output[i].argmax(dim=0).cpu()).astype("uint8")
        display(Image.fromarray(predicted_mask))
The fastai predict() and get_preds() agree on the predictions - the predicted classes are identical down to the pixel level.
The plain PyTorch predicted masks are somewhat different. They are clearly predictions on the same images - the contours are quite close to the fastai output - but the per-pixel class assignments differ substantially, enough to noticeably change the pixel count per class. The model weights are the same, and the dataloader is essentially the same between fastai get_preds() and plain PyTorch.
I apply the same argmax() for fastai get_preds() and for plain PyTorch.
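To put a number on "contours are close but classes differ", I compare the masks pixel by pixel (mask_agreement is a small helper of my own, not a fastai function):

```python
import numpy as np

def mask_agreement(mask_a: np.ndarray, mask_b: np.ndarray) -> float:
    """Fraction of pixels where two predicted class masks agree."""
    assert mask_a.shape == mask_b.shape
    return float((mask_a == mask_b).mean())

# Hypothetical example with two 2x2 class masks:
a = np.array([[0, 1], [1, 1]], dtype="uint8")
b = np.array([[0, 1], [0, 1]], dtype="uint8")
print(mask_agreement(a, b))  # 0.75
```

Running this on get_preds vs. the plain PyTorch masks is how I concluded the two methods disagree on a meaningful fraction of pixels.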
Another thing I’ve noticed is that the inputs returned by the fastai predict() are basically the source images, whereas the inputs from get_preds() and data[0] from the plain PyTorch loader appear to be images with normalized pixel values.
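My assumption (which I haven't fully verified - it can be checked by inspecting src_learner.dls.after_batch) is that unet_learner with a pretrained arch adds a Normalize transform with ImageNet statistics, which would explain why those inputs no longer look like [0, 1] pixel values:

```python
import numpy as np

# ImageNet mean/std, as used by Normalize.from_stats(*imagenet_stats)
# in fastai - an assumption about my pipeline, not confirmed.
imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
imagenet_std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

def normalize(img_chw: np.ndarray) -> np.ndarray:
    """img_chw: float CHW image with values scaled to [0, 1]."""
    return (img_chw - imagenet_mean) / imagenet_std

# A flat mid-gray image ends up with values outside [0, 1] per channel.
raw = np.full((3, 2, 2), 0.5, dtype="float32")
print(normalize(raw)[:, 0, 0])
```

If that assumption holds, the normalization itself is applied by the test_dl in both the get_preds and plain PyTorch paths, so it shouldn't by itself explain the disagreement.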
Anyway, the main question is - what accounts for the differences, and how do I get all three prediction methods to agree on the predictions?
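One difference I can think of but haven't ruled out: get_preds() presumably puts the model in eval mode and disables gradients, while my plain loop calls src_learner.model directly and may be running it in training mode, where BatchNorm uses per-batch statistics instead of the stored running statistics. A self-contained sketch of why that alone changes the outputs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
x = torch.randn(4, 3, 8, 8)

bn.train()
out_train = bn(x)   # normalizes with this batch's own mean/var

bn.eval()
out_eval = bn(x)    # normalizes with the stored running mean/var

print(torch.allclose(out_train, out_eval))  # False
```

If this is the culprit, calling src_learner.model.eval() and wrapping the loop in torch.no_grad() should make the plain PyTorch masks match - but I'd like confirmation that this is actually what accounts for the difference.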