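I got torchvision's DeepLabV3 (ResNet-50) model training with a fastai `Learner` as follows. The model returns a dict of outputs rather than a plain tensor, so a small callback extracts the `"out"` tensor after each forward pass: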
import torchvision

# DeepLabV3 segmentation network with a ResNet-50 backbone, built without
# pretrained segmentation weights and with a 2-class output head.
# NOTE(review): depending on the torchvision version, the ResNet-50 backbone
# may still load ImageNet weights by default, and `pretrained=` may emit a
# deprecation warning in newer releases — confirm against the installed version.
model = torchvision.models.segmentation.deeplabv3_resnet50(
    pretrained=False,
    num_classes=2,
)

# Switch the module to training mode.
# NOTE(review): likely redundant, since fastai's training loop sets the mode
# itself — kept to preserve the original behaviour.
model.train()
class GetResult(Callback):
    """Unwrap the dict output of a torchvision segmentation model.

    torchvision's DeepLabV3 returns a mapping of outputs, with the main
    prediction stored under ``"out"``, instead of a plain tensor. After each
    forward pass this callback replaces ``learn.pred`` with that tensor so the
    loss function and metrics receive what they expect.
    """

    def after_pred(self):
        """Replace the raw model output with its ``"out"`` tensor.

        Any other entries are discarded, such as an ``"aux"`` head if the
        model was built with an auxiliary classifier. Predictions that are not
        dicts are left untouched, so the callback is harmless with models that
        already return a tensor.
        """
        pred = self.pred
        # Reading ``self.pred`` forwards to the Learner, but assigning to it
        # on the callback would only shadow the attribute, so write through
        # ``self.learn`` explicitly.
        if isinstance(pred, dict):
            self.learn.pred = pred["out"]
# Wrap the model in a fastai Learner configured as follows:
# - Dice and Jaccard metrics, with weight decay of 1e-2.
# - The GetResult callback, which pulls the "out" tensor from the model's
#   dict output.
# - Mixed-precision training, enabled by the chained to_fp16() call.
learn = Learner(
    dls=dls,
    model=model,
    metrics=[Dice(), JaccardCoeff()],
    wd=1e-2,
    cbs=GetResult(),
).to_fp16()