Solution 1:

You need to convert y_true to a one-hot representation in order to apply a per-class dice loss. Since you are using TensorFlow, the tf.one_hot function can do this for you.
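As a minimal sketch of the conversion, assuming y_true holds integer class indices with shape (batch, height, width) and that there are 3 classes (both are assumptions for illustration; your actual shapes and class count may differ):

```python
import tensorflow as tf

num_classes = 3  # assumed number of classes, for illustration only

# Assumed label layout: integer class indices, shape (batch, height, width)
y_true = tf.constant([[[0, 1],
                       [2, 1]]])

# tf.one_hot appends a new last axis of size `depth`
y_true_one_hot = tf.one_hot(y_true, depth=num_classes)

print(y_true_one_hot.shape)  # (1, 2, 2, 3)
```

If your labels carry a trailing channel axis of size 1, you would probably need to drop it first (for example with tf.squeeze(y_true, axis=-1)) so that the one-hot axis lines up with the class axis of y_pred.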

Once y_true has the same shape as y_pred, you can use your existing code to compute the dice score for each class separately, and then combine the per-class scores (for example, by averaging them) to get the final scalar loss.
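Putting the two steps together could look roughly like the sketch below. It is not your original code: the function name multiclass_dice_loss, the smooth term, and the assumed shapes (y_true as (batch, height, width) integer labels, y_pred as (batch, height, width, num_classes) probabilities) are all hypothetical choices made for illustration.

```python
import tensorflow as tf

def multiclass_dice_loss(y_true, y_pred, num_classes, smooth=1e-6):
    """Hypothetical helper: mean per-class dice loss.

    Assumes y_true is (batch, H, W) integer labels and
    y_pred is (batch, H, W, num_classes) class probabilities.
    """
    # Bring y_true to the same shape and dtype as y_pred
    y_true_one_hot = tf.one_hot(tf.cast(y_true, tf.int32),
                                depth=num_classes, dtype=y_pred.dtype)

    # Sum over batch and spatial axes, keeping the class axis
    axes = [0, 1, 2]
    intersection = tf.reduce_sum(y_true_one_hot * y_pred, axis=axes)
    denominator = (tf.reduce_sum(y_true_one_hot, axis=axes)
                   + tf.reduce_sum(y_pred, axis=axes))

    # One dice score per class; smooth avoids division by zero
    dice_per_class = (2.0 * intersection + smooth) / (denominator + smooth)

    # Combine the per-class scores into a scalar loss by averaging
    return 1.0 - tf.reduce_mean(dice_per_class)
```

Averaging is only one way to combine the classes. If some classes are rare, a weighted average of dice_per_class may be a better fit.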