See https://github.com/tensorflow/tensorflow/issues/32875
The fix suggested there was to subclass the metric and apply argmax to the predictions:
class UpdatedMeanIoU(tf.keras.metrics.MeanIoU):
    @tf.function
    def __call__(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.argmax(y_pred, axis=-1) # this is the fix
        return super().__call__(y_true, y_pred, sample_weight=sample_weight)
This worked in TF 2.1 but broke again in TF 2.2. Is there a way to pass tf.argmax(y_pred, axis=-1) as y_pred to this metric other than subclassing?
Overriding update_state instead of __call__ fixes the issue:
import tensorflow as tf

class UpdatedMeanIoU(tf.keras.metrics.MeanIoU):
  def __init__(self,
               y_true=None,  # unused; kept so the original signature still works
               y_pred=None,  # unused; kept so the original signature still works
               num_classes=None,
               name=None,
               dtype=None):
    super().__init__(num_classes=num_classes, name=name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    # Convert per-class scores to class indices, which MeanIoU expects.
    y_pred = tf.math.argmax(y_pred, axis=-1)
    return super().update_state(y_true, y_pred, sample_weight)
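To see why the argmax step matters, here is a minimal pure-Python sketch of a mean IoU computed from class indices. It is an illustration only, not TensorFlow's implementation: the helper names argmax_last_axis and mean_iou and the sample data are made up for this example, and classes that never occur are assumed to be skipped when averaging. The metric compares integer labels, so per-class scores have to be reduced to indices first.

```python
# Illustrative only: a pure-Python mean IoU, not TensorFlow's implementation.
# Function names and sample data are hypothetical.

def argmax_last_axis(scores):
    """Turn each row of per-class scores into the index of its largest score."""
    return [max(range(len(row)), key=row.__getitem__) for row in scores]

def mean_iou(y_true, y_pred, num_classes):
    """Average IoU over the classes that occur in y_true or y_pred."""
    # confusion[t][p] counts samples of true class t predicted as class p.
    confusion = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        confusion[t][p] += 1

    ious = []
    for c in range(num_classes):
        tp = confusion[c][c]
        fp = sum(confusion[r][c] for r in range(num_classes)) - tp
        fn = sum(confusion[c]) - tp
        denom = tp + fp + fn
        if denom:  # assumption: skip classes absent from labels and predictions
            ious.append(tp / denom)
    return sum(ious) / len(ious) if ious else 0.0

y_true = [0, 1, 2, 1]
scores = [
    [0.8, 0.1, 0.1],
    [0.2, 0.7, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
]
y_pred = argmax_last_axis(scores)  # class indices, comparable with y_true
result = mean_iou(y_true, y_pred, num_classes=3)
print(y_pred, round(result, 4))
```

Passing the raw scores straight to mean_iou would fail, because they are not valid class indices; the subclass above does the same conversion with tf.math.argmax before delegating to the parent metric.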