Try: learn.loss_func = MyCrossEntropyFlat(axis=1)
That's because axis 1 is the channel that holds the per-class scores for the labels.
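For segmentation-style output shaped (batch, classes, height, width), the class scores sit on dimension 1. A "flat" cross-entropy moves that axis to the end and flattens everything else, so each pixel becomes one ordinary classification. The NumPy sketch below illustrates the idea. It assumes an NCHW layout and uses a hypothetical `cross_entropy_flat` helper. It is not fastai's code and not the implementation of `MyCrossEntropyFlat`.

```python
import numpy as np

def cross_entropy_flat(logits, target, axis=1):
    """Hypothetical sketch of a flattened cross-entropy (not fastai's code)."""
    # Move the class/channel axis to the end: (N, C, H, W) -> (N, H, W, C)
    moved = np.moveaxis(logits, axis, -1)
    n_classes = moved.shape[-1]
    flat_logits = moved.reshape(-1, n_classes)     # (N*H*W, C)
    flat_target = np.asarray(target).reshape(-1)   # (N*H*W,) integer labels
    # Numerically stable log-softmax over the class dimension
    shifted = flat_logits - flat_logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Log-probability of the true class at every pixel, averaged
    picked = log_probs[np.arange(flat_target.size), flat_target]
    return -picked.mean()

# Usage: 3 classes on a 4x4 mask, batch of 2; uniform logits give log(3)
logits = np.zeros((2, 3, 4, 4))
target = np.zeros((2, 4, 4), dtype=int)
print(cross_entropy_flat(logits, target, axis=1))
```

With the wrong axis (for example the last one), the "classes" would be read from the width dimension. The flattened logits and targets then no longer line up, which is why the class channel has to be specified explicitly.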