And then: `pred_mask.apply_tfms(tfms[0], size=(1, img.size[0], img.size[1]), resize_method=ResizeMethod.SQUISH)`
Ref:
Image segment data should be stored internally as floats, and the resampling methods can be found here.
Even though `ImageSegment.data` returns a long tensor, the object stores a float tensor in `px`, so I need to create the `ImageSegment` from a float tensor instead of a long one.
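To show why the float storage matters, here is a minimal sketch in plain PyTorch, not the fastai API. It assumes a single-channel integer label mask (the `mask_long` values and the `resize_mask` helper are hypothetical). `F.grid_sample` requires the input to have the same floating-point dtype as the sampling grid, so the mask is cast to float before resampling. Nearest-neighbour sampling is used so no fractional labels appear, and the result is cast back to long, much like `ImageSegment.data` does.

```python
import torch
import torch.nn.functional as F

# Hypothetical 1-channel 4x4 mask with integer class labels 0..2 (dtype int64)
mask_long = torch.tensor([[[0, 0, 1, 1],
                           [0, 0, 1, 1],
                           [2, 2, 1, 1],
                           [2, 2, 0, 0]]])


def resize_mask(mask: torch.Tensor, size) -> torch.Tensor:
    """Resize a (C, H, W) label mask to `size` = (H_out, W_out).

    grid_sample only works on floating-point input, so the mask is stored
    as float while sampling (like fastai's float `px`), sampled with
    mode='nearest' so labels are never blended, then cast back to long.
    """
    px = mask.float().unsqueeze(0)                # (1, C, H, W), float32
    theta = torch.eye(2, 3).unsqueeze(0)          # identity affine transform
    grid = F.affine_grid(theta, [1, px.shape[1], size[0], size[1]],
                         align_corners=False)
    out = F.grid_sample(px, grid, mode="nearest", align_corners=False)
    return out.squeeze(0).long()                  # back to (C, H_out, W_out) labels


resized = resize_mask(mask_long, (8, 8))
print(resized.shape, resized.dtype)
```

Because the identity grid is simply stretched to the requested output size, this behaves like a squish resize: the aspect ratio is not preserved. Passing `mask_long` straight to `F.grid_sample` without the `.float()` cast fails with a dtype error.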