
TensorFlow Keras 'accuracy' metric under the hood implementation

When building a classifier using TensorFlow Keras, one often monitors model accuracy by specifying metrics=['accuracy'] during the compilation step:

model = tf.keras.Model(...)
model.compile(optimizer=..., loss=..., metrics=['accuracy'])

This behaves correctly whether the model outputs logits or class probabilities, and whether the ground truth labels are one-hot-encoded vectors or integer indices (i.e., integers in the interval [0, n_classes)).

This is not the case for the cross-entropy loss: each of the four combinations of the cases above requires a different loss argument to be passed during the compilation step:

  1. If the model outputs probabilities and the ground truth labels are one-hot-encoded, then loss='categorical_crossentropy' works.

  2. If the model outputs probabilities and the ground truth labels are integer indices, then loss='sparse_categorical_crossentropy' works.

  3. If the model outputs logits and the ground truth labels are one-hot-encoded, then loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True) works.

  4. If the model outputs logits and the ground truth labels are integer indices, then loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) works.

In other words, specifying loss='categorical_crossentropy' alone is not robust enough to handle all four cases, whereas specifying metrics=['accuracy'] is.
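
For concreteness, the four compilation variants look roughly like the following sketch (the toy model builder and its parameters are placeholders, not from the original post):

import tensorflow as tf

# Hypothetical model factory; output_logits controls whether the final
# layer applies a softmax (probabilities) or is left linear (logits).
def build_model(n_classes=10, output_logits=False):
    activation = None if output_logits else "softmax"
    return tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation="relu", input_shape=(20,)),
        tf.keras.layers.Dense(n_classes, activation=activation),
    ])

# 1. Probabilities + one-hot labels
build_model(output_logits=False).compile(
    optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

# 2. Probabilities + integer labels
build_model(output_logits=False).compile(
    optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# 3. Logits + one-hot labels
build_model(output_logits=True).compile(
    optimizer="adam",
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"])

# 4. Logits + integer labels
build_model(output_logits=True).compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"])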


Question: What is happening behind the scenes when the user specifies metrics=['accuracy'] in the model compilation step that allows the accuracy computation to be performed correctly regardless of whether the model outputs logits or probabilities and whether the ground truth labels are one-hot-encoded vectors or integer indices?


I suspect that the logits versus probabilities case is simple since the predicted class can be obtained as an argmax either way, but I would ideally like to be pointed to where in the TensorFlow 2 source code the computation is actually done.
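
To illustrate the argmax point with a small, self-contained example (the tensors here are made up, not from any real model): softmax is monotonic within each row, so the predicted class index is the same whether it is taken over logits or over probabilities, and the built-in accuracy function agrees on either representation.

import numpy as np
import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.1],
                      [0.5, 2.5, 0.3]])
probs = tf.nn.softmax(logits)

# The argmax is unchanged by the softmax, so the predicted classes match.
print(tf.argmax(logits, axis=-1).numpy())  # [0 1]
print(tf.argmax(probs, axis=-1).numpy())   # [0 1]

# Per-sample accuracy against integer labels is therefore identical
# whether computed from logits or from probabilities.
labels = np.array([0, 2])
print(tf.keras.metrics.sparse_categorical_accuracy(labels, logits).numpy())  # [1. 0.]
print(tf.keras.metrics.sparse_categorical_accuracy(labels, probs).numpy())   # [1. 0.]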

Please note that I am currently using TensorFlow 2.0.0-rc1.


Edit: In pure Keras, metrics=['accuracy'] is handled explicitly in the Model.compile method.

Asked Dec 15 '25 by Artem Mavrin


1 Answer

Found it: this is handled in tensorflow.python.keras.engine.training_utils.get_metric_function. In particular, the output shape is inspected to determine which accuracy function to use.

To elaborate, in the current implementation Model.compile either delegates metric processing to Model._compile_eagerly (if executing eagerly) or does it directly. In either case, Model._cache_output_metric_attributes is called, which calls collect_per_output_metric_info for both the unweighted and weighted metrics. This function loops over the provided metrics, calling get_metric_function on each one.
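
As a rough, hand-written paraphrase of that selection logic (not a copy of the TensorFlow source; as far as I can tell, the real code also consults the configured loss to distinguish sparse labels from one-hot labels), the dispatch looks approximately like this:

import tensorflow as tf

def resolve_accuracy(output_shape, loss):
    # Hypothetical simplification of how the string 'accuracy' is resolved
    # to a concrete metric function inside get_metric_function.
    if output_shape[-1] == 1 or loss == "binary_crossentropy":
        return tf.keras.metrics.binary_accuracy
    if loss == "sparse_categorical_crossentropy":
        return tf.keras.metrics.sparse_categorical_accuracy
    return tf.keras.metrics.categorical_accuracy

# Example: a 10-class output trained with sparse labels resolves to
# sparse_categorical_accuracy, which argmaxes the predictions internally.
fn = resolve_accuracy(output_shape=(None, 10), loss="sparse_categorical_crossentropy")
print(fn.__name__)  # sparse_categorical_accuracy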

Answered Dec 16 '25 by Artem Mavrin


