I'm trying to adapt the Deep MNIST for Experts tutorial to detect just one class, for example whether or not an image contains a kitty.
This is the prediction part of my code:
```python
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv), reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
sess.run(tf.initialize_all_variables())
for i in range(20000):
    batch = mnist.train.next_batch(50)
    if i % 100 == 0:
        train_accuracy = accuracy.eval(feed_dict={
            x: batch[0], y_: batch[1], keep_prob: 1.0})
        print("step %d, training accuracy %g" % (i, train_accuracy))
    train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
print("test accuracy %g" % accuracy.eval(feed_dict={
    x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
```
The problem is that with a single class, the softmax always returns that class with a confidence of 1, even for a blank image. I tried modifying the softmax and the cross entropy, but I wasn't able to solve it.
I need to know which approach is recommended for this problem. I want the prediction to be the probability of an image being a kitty.
I know this could be solved by adding a second label trained on random images, but I'd like to know whether there is a better solution.
Thank you very much.
Don't use softmax and multi-class logloss for predicting membership of a single class. With only one output unit, softmax divides exp(z) by itself, so it returns 1 for every input, which is exactly what you are seeing. The more usual setup is a sigmoid activation with binary cross entropy. Unless you are optimising the cost/benefit of correct predictions*, just classify anything with a predicted value > 0.5 as the "positive" class.
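A quick plain-Python check (outside TensorFlow, with made-up logit values) shows the difference: a softmax over a single logit is always 1.0, while a sigmoid maps the same logit to a probability that actually varies with the input.

```python
import math

def softmax(logits):
    # Subtract the max logit for numerical stability, then normalise to sum to 1.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(z):
    # Logistic function: maps any real logit to a probability in (0, 1).
    return 1.0 / (1.0 + math.exp(-z))

for logit in (-5.0, 0.0, 5.0):
    # softmax([logit]) is [1.0] every time; sigmoid(logit) changes with the logit.
    print(logit, softmax([logit]), round(sigmoid(logit), 4))
```

The softmax column never changes, so a single-output softmax network cannot express "not a kitty"; the sigmoid column goes from close to 0 to close to 1.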
In TensorFlow, this changes your code in only a couple of places.
I think the following adjustments apply to the start of your code:
```python
# Single sigmoid output: W_fc2 and b_fc2 need one output column
# (e.g. shapes [1024, 1] and [1]), and y_ holds a 0/1 label per image, shape [None, 1].
y_conv = tf.nn.sigmoid(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

# I've split the individual loss out because the line was too long.
# The +1e-12 is for numerical stability in case of perfect 0.0 or 1.0 predictions.
# Note how this loss penalises incorrect predictions in both directions,
# unlike the multiclass logloss, which only assesses confidence in the
# correct class.
loss = -(y_ * tf.log(y_conv + 1e-12) + (1 - y_) * tf.log(1 - y_conv + 1e-12))
cross_entropy = tf.reduce_mean(tf.reduce_sum(loss, reduction_indices=[1]))

# Classify as "kitty" when the predicted probability exceeds 0.5.
predict_is_kitty = tf.greater(y_conv, 0.5)
correct_prediction = tf.equal(tf.to_float(predict_is_kitty), y_)
```
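To see what the loss and the 0.5 threshold compute, here is a sketch of the same arithmetic in plain Python, outside TensorFlow; the labels and predictions are invented for illustration. (TensorFlow also provides `tf.nn.sigmoid_cross_entropy_with_logits`, which computes this loss directly from the raw logits in a numerically stable way.)

```python
import math

EPS = 1e-12  # same stability constant as in the TensorFlow snippet

def binary_cross_entropy(y_true, y_pred):
    # Mean of -(y*log(p) + (1-y)*log(1-p)) over the examples.
    losses = [-(y * math.log(p + EPS) + (1 - y) * math.log(1 - p + EPS))
              for y, p in zip(y_true, y_pred)]
    return sum(losses) / len(losses)

def thresholded_accuracy(y_true, y_pred, threshold=0.5):
    # Predict 1.0 ("kitty") when p > threshold, mirroring tf.greater(y_conv, 0.5).
    correct = [float(p > threshold) == y for y, p in zip(y_true, y_pred)]
    return sum(correct) / len(correct)

labels = [1.0, 0.0, 1.0, 0.0]
good = [0.9, 0.1, 0.8, 0.2]  # confident and right
bad = [0.1, 0.9, 0.2, 0.8]   # confident and wrong
print(binary_cross_entropy(labels, good), thresholded_accuracy(labels, good))
print(binary_cross_entropy(labels, bad), thresholded_accuracy(labels, bad))
```

The `(1 - y)` term is what makes a confident "kitty" prediction on a non-kitty image (such as a blank one) expensive, which the single-class softmax loss could never do.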
* If you are working on a problem where you care about the confidence of the predictions and need to assess where to set the threshold, the usual metric instead of accuracy is the area under the ROC curve, often called AUROC or just AUC.
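As a rough illustration of what AUC measures, the sketch below computes it in plain Python from its rank interpretation: the probability that a randomly chosen positive example gets a higher score than a randomly chosen negative one. The labels and scores are made up; in practice a library function such as scikit-learn's `sklearn.metrics.roc_auc_score` does the same job more efficiently.

```python
def roc_auc(y_true, scores):
    # AUC = P(score of a random positive > score of a random negative),
    # with ties counted as half a win. O(n_pos * n_neg), fine for small sets.
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))

labels = [1, 0, 1, 0, 1]
scores = [0.9, 0.4, 0.35, 0.2, 0.8]  # e.g. sigmoid outputs on a validation set
print(roc_auc(labels, scores))
```

Because AUC depends only on how the scores are ordered, it evaluates the model independently of any particular threshold; you can then pick the threshold separately based on the costs of false positives and false negatives.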