Intersection over Union with the CamVid dataset


I am working through lesson 3 of the MOOC and have reached the part where we train a U-Net on CamVid. During training, I would like to track the Intersection over Union (IoU) metric on the validation set in addition to accuracy. After a bit of searching, I found that fastai’s IoU is implemented inside the dice() function [1], but that function only works for a binary target, whereas CamVid has many classes. Is there an easy way to get fastai’s IoU metric to work when training on CamVid, or will I have to implement the IoU metric myself from scratch?
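
To be concrete about what I am after, here is a rough, framework-free sketch of the multi-class (mean) IoU I would like to track. The name mean_iou and its n_classes and ignore_index parameters are my own and are not part of fastai. It works on flat Python lists of per-pixel class ids, so a real metric would still need to be adapted to the model's output tensors (for example, by taking an argmax over the class dimension first).

```python
from typing import Optional, Sequence


def mean_iou(pred: Sequence[int], target: Sequence[int], n_classes: int,
             ignore_index: Optional[int] = None) -> float:
    """Mean IoU over classes, given flat sequences of per-pixel class ids.

    Hypothetical helper, not a fastai function. Pixels whose target equals
    ignore_index are skipped, and that class is left out of the average.
    """
    intersection = [0] * n_classes
    union = [0] * n_classes
    for p, t in zip(pred, target):
        if t == ignore_index:
            continue  # skip pixels labelled with the ignored class
        if p == t:
            # Pixel is in both the predicted and the true mask of class t.
            intersection[t] += 1
            union[t] += 1
        else:
            # Pixel is in the predicted mask of p and the true mask of t.
            union[p] += 1
            union[t] += 1
    # Average only over classes that actually occur in pred or target.
    ious = [intersection[c] / union[c]
            for c in range(n_classes)
            if c != ignore_index and union[c] > 0]
    return sum(ious) / len(ious) if ious else 0.0


pred = [0, 0, 1, 1, 2, 2]
target = [0, 1, 1, 1, 2, 0]
print(mean_iou(pred, target, n_classes=3))
```

In this toy example the per-class IoUs are 1/3, 2/3 and 1/2. Whether averaging over classes like this is the right convention for CamVid is part of what I am unsure about.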

Also, I am a beginner in all of this, so if I have any misunderstandings, feel free to correct me. Thank you!