When building a classifier using TensorFlow Keras, one often monitors model accuracy by specifying metrics=['accuracy'] during the compilation step:
```python
model = tf.keras.Model(...)
model.compile(optimizer=..., loss=..., metrics=['accuracy'])
```
This behaves correctly whether the model outputs logits or class probabilities, and whether the ground truth labels are one-hot-encoded vectors or integer indices (i.e., integers in the interval [0, n_classes)).
This is not the case for cross-entropy loss: each of the four combinations of the cases above requires a different loss argument to be passed during the compilation step (see the sketch after this list):
- If the model outputs probabilities and the ground truth labels are one-hot-encoded, then loss='categorical_crossentropy' works.
- If the model outputs probabilities and the ground truth labels are integer indices, then loss='sparse_categorical_crossentropy' works.
- If the model outputs logits and the ground truth labels are one-hot-encoded, then loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True) works.
- If the model outputs logits and the ground truth labels are integer indices, then loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) works.
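To make the four pairings concrete, here is a minimal sketch putting them side by side (the Dense heads here are hypothetical, chosen only to illustrate the combinations):

```python
import tensorflow as tf

n_classes = 3

# Hypothetical model heads: one emitting probabilities, one emitting raw logits.
probs_model = tf.keras.Sequential([tf.keras.layers.Dense(n_classes, activation='softmax')])
logits_model = tf.keras.Sequential([tf.keras.layers.Dense(n_classes)])  # no activation

# Probabilities + one-hot labels
probs_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Probabilities + integer labels
probs_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Logits + one-hot labels
logits_model.compile(optimizer='adam',
                     loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                     metrics=['accuracy'])

# Logits + integer labels
logits_model.compile(optimizer='adam',
                     loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                     metrics=['accuracy'])
```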
In other words, specifying loss='categorical_crossentropy' alone is not robust enough to handle these four cases, whereas specifying metrics=['accuracy'] is.
Question
What is happening behind the scenes when the user specifies metrics=['accuracy'] in the model compilation step that allows the accuracy computation to be performed correctly regardless of whether the model outputs logits or probabilities and whether the ground truth labels are one-hot-encoded vectors or integer indices?
I suspect that the logits versus probabilities case is simple since the predicted class can be obtained as an argmax either way, but I would ideally like to be pointed to where in the TensorFlow 2 source code the computation is actually done.
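For reference, here is a quick check of that intuition: softmax is monotonic, so taking the argmax of the logits gives the same predicted class as taking the argmax of the probabilities.

```python
import tensorflow as tf

logits = tf.constant([[2.0, -1.0, 0.5],
                      [0.1, 3.0, -2.0]])
probs = tf.nn.softmax(logits)

# softmax preserves ordering, so the argmax is unchanged
tf.debugging.assert_equal(tf.argmax(logits, axis=-1),
                          tf.argmax(probs, axis=-1))
```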
Please note that I am currently using TensorFlow 2.0.0-rc1.
Edit
In pure Keras, metrics=['accuracy'] is handled explicitly in the Model.compile method.
Found it: this is handled in tensorflow.python.keras.engine.training_utils.get_metric_function. In particular, the output shape (together with the loss function) is inspected to determine which accuracy function to use.
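Roughly, the dispatch looks like the following. This is a simplified, self-contained paraphrase, not the actual source: the isinstance checks are stand-ins for the source's comparisons against the crossentropy loss functions, and the real code also handles the 'crossentropy'/'ce' aliases and some edge cases.

```python
import tensorflow as tf

def get_metric_function_sketch(metric, output_shape=None, loss_fn=None):
    """Simplified paraphrase of the dispatch in
    tensorflow.python.keras.engine.training_utils.get_metric_function (TF 2.0)."""
    if metric not in ('accuracy', 'acc'):
        return tf.keras.metrics.get(metric)  # not an accuracy alias: resolve normally
    # A single output unit (or a binary loss) -> binary accuracy.
    if output_shape[-1] == 1 or isinstance(loss_fn, tf.keras.losses.BinaryCrossentropy):
        return tf.keras.metrics.binary_accuracy
    # A sparse loss implies integer labels -> sparse categorical accuracy.
    if isinstance(loss_fn, tf.keras.losses.SparseCategoricalCrossentropy):
        return tf.keras.metrics.sparse_categorical_accuracy
    # Otherwise assume one-hot labels.
    return tf.keras.metrics.categorical_accuracy

fn = get_metric_function_sketch(
    'accuracy', output_shape=(None, 3),
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
print(fn.__name__)  # sparse_categorical_accuracy
```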
To elaborate, in the current implementation Model.compile either delegates metric processing to Model._compile_eagerly (if executing eagerly) or does it directly. In either case, Model._cache_output_metric_attributes is called, which calls collect_per_output_metric_info for both the unweighted and weighted metrics. This function loops over the provided metrics, calling get_metric_function on each one.
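A quick end-to-end check of this behavior (the toy model and data below are made up purely for illustration): compiling a logits-producing model with a sparse loss and metrics=['accuracy'] trains and reports accuracy with no further configuration, because 'accuracy' is resolved to sparse_categorical_accuracy along the path above.

```python
import numpy as np
import tensorflow as tf

# Toy 3-class model that outputs raw logits (no softmax).
model = tf.keras.Sequential([
    tf.keras.layers.Dense(3, input_shape=(4,)),
])
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],  # resolved behind the scenes to sparse_categorical_accuracy
)

x = np.random.rand(32, 4).astype('float32')
y = np.random.randint(0, 3, size=(32,))  # integer class indices

history = model.fit(x, y, epochs=1, verbose=0)
print(history.history.keys())  # includes 'accuracy' even though the labels are sparse
```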