Wrong output for restored variable in tensorflow graph

I'm currently toying around with saving and restoring of variables. For this purpose, I created two scripts: one saves a simple graph while the other restores it. Here is the test script for saving the graph:

import tensorflow as tf

a = tf.Variable(3.0, name='a')
b = tf.Variable(5.0, name='b')

b = tf.assign_add(b, a)

n_steps = 5

global_step = tf.Variable(0, name='global_step', trainable=False)

saver = tf.train.Saver()

with tf.Session() as sess:

    sess.run(tf.global_variables_initializer())

    for step in range(n_steps):
        print(sess.run(b))

        global_step.assign_add(1).eval()
        print(global_step.eval())

        saver.save(sess, './my_test_model', global_step=global_step)

Basically, I want to run through a loop 5 times, and every time I do, a is added to b. I also want to keep track of the number of iterations via global_step. This works as intended. The output is:

8.0     # value of b
1       # step
11.0
2
14.0
3
17.0
4
20.0
5

Now when restoring the variables, I try to get all three of them back. The script is:

import tensorflow as tf

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

# List ALL tensors.
print_tensors_in_checkpoint_file(tf.train.latest_checkpoint('./'), all_tensors=True, tensor_name='')

tf.reset_default_graph()

a = tf.get_variable('a', shape=[])
b = tf.get_variable('b', shape=[])
global_step = tf.get_variable('global_step', shape=[])

saver = tf.train.Saver()

with tf.Session() as sess:

    ckpt = tf.train.latest_checkpoint('./')
    if ckpt:
        print(ckpt)

        saver.restore(sess, ckpt)

    else:
        print('Nothing restored')

    print(a.eval())
    print(b.eval())
    print(global_step.eval())

The output of this is

tensor_name:  a
3.0
tensor_name:  b
20.0
tensor_name:  global_step
5
./my_test_model-5
INFO:tensorflow:Restoring parameters from ./my_test_model-5
3.0
20.0
7e-45

How is it possible that the value for global_step is stored correctly in the checkpoint, but upon evaluation I get this tiny 7e-45? Also, upon restoring, I seem to be unable to define any additional variables, as the saver then complains that it cannot find them in the checkpoint. How can I, for example, define a new variable and add it to the b of the restored graph?

Thank you for your help!

asked Jan 18 '26 by DocDriven
1 Answer

This doesn't appear to be well documented in the TF docs, but you need to specify the dtype for the global_step variable when recreating it with tf.get_variable.

Incorrect

global_step = tf.get_variable('global_step', shape=[], dtype=tf.float32) results in global_step=7e-45. The dtype defaults to tf.float32 when unspecified, which is why the code in the question runs into this.

Correct

global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32) results in global_step=5.
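As for where the strange 7e-45 comes from: the checkpoint stores global_step as an int32 with value 5, and restoring it into a float32 variable reinterprets the raw bytes rather than converting the value. This is consistent with what the question observed, since the 32-bit integer 5, read bit-for-bit as an IEEE-754 float32, is a subnormal of roughly 7e-45. A quick stdlib check (no TensorFlow needed):

```python
import struct

# Pack the integer 5 as little-endian int32 bytes, then read the same
# 4 bytes back as a float32. The bit pattern 0x00000005 is a subnormal
# float, approximately 7e-45.
raw = struct.pack('<i', 5)              # b'\x05\x00\x00\x00'
as_float = struct.unpack('<f', raw)[0]
print(as_float)                         # ~7.006e-45
```

So the value in the checkpoint is intact; only the dtype mismatch at restore time makes it look corrupted.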

answered Jan 20 '26 by kww


