Tensorflow cond doesn't stop gradient on the false branch

Tags:

tensorflow

I am building an RNN model in which init_state can come from one of two sources: (1) a static init_state that is fed in via feed_dict from the previous time step's output state, or (2) some function of a variable, which I call score.

init_state = cell.zero_state(batch, tf.float32)
with tf.name_scope('hidden1'):
    weights_h1 = tf.Variable(
        tf.truncated_normal([T, cells_dim], stddev=1.0 / np.sqrt(T)),
        name='weights')
    biases_h1 = tf.Variable(tf.zeros([cells_dim]), name='biases')
    hidden1 = tf.nn.relu(tf.matmul(score, weights_h1) + biases_h1)

init_state2 = tf.cond(is_start, lambda: hidden1, lambda: init_state)

init_state2 is then used as the input to static_rnn, whose output is eventually used to compute the loss and train_op. I would expect train_op to have no effect on weights_h1 when is_start is False. However, the weights change after every update. Any help is greatly appreciated.

asked Jan 25 '26 at 15:01 by Will


1 Answer

This should work:

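import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API

# Create the variables once, outside tf.cond; only the ops that use them
# need to be built inside a branch function.
with tf.name_scope('hidden1'):
    weights_h1 = tf.Variable(
        tf.truncated_normal([T, cells_dim], stddev=1.0 / np.sqrt(T)),
        name='weights')
    biases_h1 = tf.Variable(tf.zeros([cells_dim]), name='biases')

def return_init_state():
    # Built inside the false branch of tf.cond.
    return cell.zero_state(batch, tf.float32)

def return_hidden_1():
    # Built inside the true branch; these ops run only when is_start is True.
    return tf.nn.relu(tf.matmul(score, weights_h1) + biases_h1)

# Pass the functions themselves, not lambdas that return them;
# tf.cond calls each one to build its branch.
init_state2 = tf.cond(is_start, return_hidden_1, return_init_state)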

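Notice that the functions are called by tf.cond itself, so every op they create belongs to the corresponding branch of the conditional, and only the ops in the branch selected by is_start are executed. In your version, hidden1 and init_state are created outside tf.cond, so both are computed on every step, whichever way is_start goes.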
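The key distinction is between computing a value before the choice is made and deferring that computation into the branch. The sketch below is a pure-Python analogy, not TensorFlow: `toy_cond`, `make_hidden`, and `make_zero_state` are hypothetical stand-ins, and appending to a list plays the role of ops being executed. One caveat: in TF 1.x graph mode, `tf.cond` calls both branch functions once while building the graph, so the analogy corresponds to which branch's ops actually execute when the graph runs, not to graph construction.

```python
# Pure-Python analogy (hypothetical names, not TensorFlow): toy_cond calls only
# the branch function chosen by the predicate.
calls = []

def toy_cond(pred, true_fn, false_fn):
    """Call exactly one of two zero-argument functions and return its result."""
    return true_fn() if pred else false_fn()

def make_hidden():
    calls.append("hidden")   # stands in for the ops that use weights_h1
    return "hidden1"

def make_zero_state():
    calls.append("zero")     # stands in for cell.zero_state(...)
    return "zero_state"

# Question's pattern: both values are computed before the choice is made.
hidden, zero = make_hidden(), make_zero_state()
eager_result = toy_cond(False, lambda: hidden, lambda: zero)
eager_calls = list(calls)

# Answer's pattern: the work happens inside the chosen branch only.
calls.clear()
lazy_result = toy_cond(False, make_hidden, make_zero_state)
lazy_calls = list(calls)

print(eager_result, eager_calls)  # zero_state ['hidden', 'zero']
print(lazy_result, lazy_calls)    # zero_state ['zero']
```

Passing `lambda: make_hidden` instead of `make_hidden` would make `toy_cond` return the function object itself rather than its result; `tf.cond` likewise expects callables that build and return tensors.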

answered Jan 27 '26 at 23:01 by I. A