Implementing Dice Loss
Solution 1:
You need to convert y_true to a one-hot representation in order to apply a per-class Dice loss. TensorFlow's tf.one_hot function does this for you.
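As a minimal sketch of the conversion, the snippet below uses NumPy rather than TensorFlow so it runs on its own. Indexing an identity matrix with the integer labels, np.eye(num_classes)[labels], gives the same array as tf.one_hot(labels, depth=num_classes) for labels in the range [0, num_classes). The helper name to_one_hot and the example label map are illustrative assumptions.

```python
import numpy as np

def to_one_hot(labels, num_classes):
    """Return one-hot vectors for integer labels along a new last axis.

    Hypothetical helper: for labels in [0, num_classes) this matches
    tf.one_hot(labels, depth=num_classes), e.g. (batch, H, W) -> (batch, H, W, C).
    """
    labels = np.asarray(labels, dtype=np.int64)
    # Row k of the identity matrix is the one-hot vector for class k.
    return np.eye(num_classes, dtype=np.float32)[labels]

# Illustrative 1x2x2 label map with 3 classes.
y_true = np.array([[[0, 2], [1, 1]]])
y_true_1hot = to_one_hot(y_true, num_classes=3)
print(y_true_1hot.shape)  # (1, 2, 2, 3)
```

If y_true carries a trailing channel axis of size 1, e.g. shape (batch, H, W, 1), squeeze that axis first; otherwise the one-hot result gains an extra axis and will not match y_pred.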
Once y_true has the same shape as y_pred, you can use your code to compute the Dice score for each class separately, and then combine the per-class scores (for example, by averaging them) into the final scalar loss.
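Given y_true in one-hot form, the per-class score and its reduction to a scalar could look like the NumPy sketch below. It computes the soft Dice score 2 * sum(t * p) / (sum(t) + sum(p)) for each class, summing over every axis except the class axis, and then averages the class scores. The plain mean, the smooth constant that guards against empty classes, and the function names are assumptions, not part of the original code. In an actual Keras loss you would use tf.reduce_sum and tf.reduce_mean in the same places so the computation stays in the TensorFlow graph and remains differentiable.

```python
import numpy as np

def per_class_dice(y_true_1hot, y_pred, smooth=1e-6):
    """Soft Dice score for each class.

    Both arrays have shape (..., num_classes); y_pred holds per-class
    probabilities such as a softmax output. Sums run over every axis
    except the last (class) axis, so the result has shape (num_classes,).
    """
    axes = tuple(range(y_pred.ndim - 1))
    intersection = np.sum(y_true_1hot * y_pred, axis=axes)
    totals = np.sum(y_true_1hot, axis=axes) + np.sum(y_pred, axis=axes)
    # smooth keeps the ratio defined for classes absent from both arrays.
    return (2.0 * intersection + smooth) / (totals + smooth)

def dice_loss(y_true_1hot, y_pred, smooth=1e-6):
    """Combine the per-class scores into one scalar loss (assumed: plain mean)."""
    return 1.0 - float(np.mean(per_class_dice(y_true_1hot, y_pred, smooth)))

# Two classes on a 2x2 map; labels converted with the identity-matrix trick.
y_true_1hot = np.eye(2, dtype=np.float32)[np.array([[0, 1], [1, 1]])]
perfect = dice_loss(y_true_1hot, y_true_1hot)          # close to 0
all_wrong = dice_loss(y_true_1hot, 1.0 - y_true_1hot)  # close to 1
print(perfect, all_wrong)
```

A perfect prediction gives a loss of 0 and a fully wrong one approaches 1. Some formulations square the terms in the denominator, and a class-weighted average is a common alternative to the plain mean when some classes are rare.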