How to get other metrics in Tensorflow 2.0 (not only accuracy)?

Solution 1:

I am adding another answer because this is the cleanest way to compute these metrics correctly on your test set (as of 22 March 2020).

The first thing you need to do is create a custom callback, to which you pass your test data:

import tensorflow as tf
from tensorflow.keras.callbacks import Callback
from sklearn.metrics import classification_report

class MetricsCallback(Callback):
    def __init__(self, test_data, y_true):
        super().__init__()
        # Inputs only (no labels), in any format model.predict() accepts
        self.test_data = test_data
        # Should be the label encoding of your classes (integer class indices)
        self.y_true = y_true

    def on_epoch_end(self, epoch, logs=None):
        # Here we get the class probabilities
        y_pred = self.model.predict(self.test_data)
        # Here we get the actual classes (index of the most probable class);
        # .numpy() turns the tensor into an array that scikit-learn accepts
        y_pred = tf.argmax(y_pred, axis=1).numpy()
        # Actual dictionary, e.g. report_dictionary["macro avg"]["f1-score"]
        report_dictionary = classification_report(self.y_true, y_pred, output_dict=True)
        # Only printing the report
        print(classification_report(self.y_true, y_pred, output_dict=False))
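To see what on_epoch_end does with the predictions, here is a minimal sketch of the same post-processing outside of training. The probability array and labels are made-up values standing in for the output of model.predict() and for my_y_true; np.argmax is used in place of tf.argmax only so the snippet runs without TensorFlow.

```python
import numpy as np
from sklearn.metrics import classification_report

# Hypothetical model output: one row of class probabilities per test sample
probabilities = np.array([
    [0.9, 0.1, 0.0],
    [0.2, 0.7, 0.1],
    [0.1, 0.2, 0.7],
    [0.6, 0.3, 0.1],
    [0.3, 0.3, 0.4],
])
# Hypothetical integer labels (the label encoding the callback expects as y_true)
y_true = np.array([0, 1, 2, 1, 2])

# Same step as in on_epoch_end: pick the most probable class for each row
y_pred = np.argmax(probabilities, axis=1)

# Dictionary form, as stored in report_dictionary
report = classification_report(y_true, y_pred, output_dict=True)
print(report["1"]["precision"], report["1"]["recall"], report["accuracy"])

# Text form, as printed by the callback
print(classification_report(y_true, y_pred))
```

With output_dict=True the class labels become string keys ("0", "1", "2") next to "accuracy", "macro avg" and "weighted avg", which makes report_dictionary convenient for logging.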

In your main script, where you load the dataset, create the callback and pass it to fit():

metrics_callback = MetricsCallback(test_data=my_test_data, y_true=my_y_true)
...
...
# Train the model (cp_callback and tensorboard are other callbacks defined elsewhere in your script)
model.fit(x_train, y_train, callbacks=[cp_callback, metrics_callback, tensorboard], epochs=5)


Solution 2:

Starting from TensorFlow 2.X, precision and recall are both available as built-in metrics.

Therefore, you do not need to implement them by hand. They had been removed in Keras 2.x because they were misleading: they were computed batch by batch, and the average of the per-batch values is generally not the same as the global (true) precision and recall over the whole dataset.
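A small worked example makes the difference concrete. The two batches below are hypothetical, already-thresholded predictions; the sketch computes precision once by averaging per-batch values and once from counts accumulated over both batches.

```python
# Hypothetical (y_true, y_pred) pairs for two batches, predictions already thresholded to 0/1
batches = [
    ([1, 0, 0, 0], [1, 0, 0, 0]),  # batch 1: 1 true positive, 0 false positives
    ([1, 0, 0, 0], [1, 1, 1, 1]),  # batch 2: 1 true positive, 3 false positives
]

def counts(y_true, y_pred):
    """Return (true positives, false positives) for one batch."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    return tp, fp

# Batch-wise: precision per batch, then the mean of those values
per_batch = [tp / (tp + fp) for tp, fp in (counts(t, p) for t, p in batches)]
batchwise_precision = sum(per_batch) / len(per_batch)

# Global: accumulate the counts over all batches, divide once at the end
total_tp = sum(counts(t, p)[0] for t, p in batches)
total_fp = sum(counts(t, p)[1] for t, p in batches)
global_precision = total_tp / (total_tp + total_fp)

print(batchwise_precision, global_precision)
```

Here the per-batch average is (1.0 + 0.25) / 2 = 0.625, while the true precision over all eight samples is 2 / (2 + 3) = 0.4.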

You can have a look here: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Recall

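Now they have a built-in accumulator: the metric objects are stateful and keep running counts (true positives, false positives, false negatives) across batches, so precision and recall are computed from the totals, which ensures the correct values.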
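A minimal sketch of that accumulation, using the metric object directly outside of fit(). It assumes TensorFlow 2.x; update_state() and result() are the standard tf.keras.metrics methods, and the two batches are hypothetical values.

```python
import tensorflow as tf

# Stateful metric: keeps running true-positive / false-positive counts
precision = tf.keras.metrics.Precision()  # default threshold 0.5

# Batch 1: 1 true positive, 0 false positives
precision.update_state([1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0])
# Batch 2: 1 true positive, 3 false positives
precision.update_state([1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0])

# result() divides the accumulated counts once: 2 / (2 + 3)
print(float(precision.result()))
```

Because the division happens only once on the totals, the value is about 0.4, not the mean of the two per-batch precisions (1.0 and 0.25).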

import tensorflow as tf

model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()])
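A self-contained sketch of the compile() call above on toy data, assuming TensorFlow 2.x. The synthetic dataset, the one-layer model, the train/test split and the explicit name= arguments (which fix the keys used in history.history and in the evaluate() dictionary) are illustrative choices, not part of the original answer.

```python
import numpy as np
import tensorflow as tf

# Hypothetical synthetic binary-classification data: label is 1 when the feature sum is positive
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 4)).astype("float32")
y = (x.sum(axis=1, keepdims=True) > 0).astype("float32")
x_train, y_train = x[:200], y[:200]
x_test, y_test = x[200:], y[200:]

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam",
              loss="binary_crossentropy",
              metrics=["accuracy",
                       tf.keras.metrics.Precision(name="precision"),
                       tf.keras.metrics.Recall(name="recall")])

history = model.fit(x_train, y_train, batch_size=32, epochs=2, verbose=0)
# One value per epoch, accumulated over all batches of that epoch
print(history.history["precision"], history.history["recall"])

# evaluate() accumulates the same way over the whole test set
results = model.evaluate(x_test, y_test, verbose=0, return_dict=True)
print(results["precision"], results["recall"])
```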