TensorFlow: How to initialize variables on another graph?

Tags:

tensorflow

I have a default graph and a newly created graph (G1).

In G1, I have a variable named "a".

I can use tf.import_graph_def to include G1 into the main graph and expose its "a" variable.

How do I initialize this variable and successfully print the value of "a"?

Here is the actual code:

import tensorflow as tf

INT = tf.int32


def graph():
    g = tf.Graph()
    with g.as_default() as g:
        a = tf.get_variable('a', [], INT, tf.constant_initializer(10))
    return g


tf.reset_default_graph()

g = graph()
[g_a] = tf.import_graph_def(g.as_graph_def(), return_elements=['a:0'])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(g_a))

The above doesn't work; it fails with FailedPreconditionError: Attempting to use uninitialized value import/a.

Phizaz asked Mar 12 '26 06:03


1 Answer

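The reason you get the error is that a GraphDef only describes operations. Importing it copies the variable's ops (including its initializer) under the import/ prefix, but it does not bring along the variable's value or the graph's collections. The imported variable is therefore not registered in GLOBAL_VARIABLES, so tf.global_variables_initializer() never initializes it.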
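To see this directly, the following sketch imports a graph like the question's G1 and inspects the result. It assumes TensorFlow 1.14+ or 2.x used through tf.compat.v1 in graph mode, and it passes use_resource=False (not in the question) so that "a" is the classic reference variable the question's code creates. The variable's ops arrive under import/, but tf.global_variables() is empty; running the imported initializer op explicitly (looked up through a.initializer.name) should give the variable its initial value:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: TF >= 1.14 or TF 2.x, run in graph mode

# Source graph (G1) holding a reference variable, as in the question.
src = tf.Graph()
with src.as_default():
    a = tf.get_variable('a', shape=[], dtype=tf.int32,
                        initializer=tf.constant_initializer(10),
                        use_resource=False)
var_name = a.name               # 'a:0'
init_name = a.initializer.name  # name of the variable's Assign op

dst = tf.Graph()
with dst.as_default():
    # The ops (variable, initial value, Assign) are copied under 'import/' ...
    g_a, g_init = tf.import_graph_def(src.as_graph_def(),
                                      return_elements=[var_name, init_name])
    # ... but a GraphDef carries no collections, so no variable is registered
    # and tf.global_variables_initializer() would have nothing to run.
    n_globals = len(tf.global_variables())
    with tf.Session() as sess:
        sess.run(g_init)        # run the imported initializer explicitly
        value = sess.run(g_a)

print(g_a.name, n_globals, value)  # import/a:0 0 10
```

This only recovers the initial value stored in the graph (10 here). A value the variable acquired in some other session still has to be saved and restored with tf.train.Saver.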

You can use variables in another graph if you do the following:

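  • create your variable in a session on the original graph, then run tf.global_variables_initializer()
  • save it with tf.train.Saver
  • after importing the graph_def into the new graph, restore the variable from the checkpoint
  • important: when you import the graph def, pass name='' so the imported variable keeps its original name (a rather than import/a); the saver looks up checkpoint entries by variable name, so otherwise the restore fails because import/a is not in the checkpoint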

A minimal example of how to do this:

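import os

import tensorflow as tf

INT = tf.int32

def graph():
    g = tf.Graph()
    # Entering the session also makes g the default graph.
    with tf.Session(graph=g) as sess:
        a = tf.get_variable("a", shape=[1], dtype=INT, initializer=tf.constant_initializer(10))
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        # Saver.save does not create the parent directory itself.
        os.makedirs('./test_dir', exist_ok=True)
        saver.save(sess, './test_dir/test_save.ckpt')
        return g


g = graph()

tf.reset_default_graph()

# name='' keeps the original op names, so the variable is 'a', not 'import/a'.
g_a = tf.import_graph_def(g.as_graph_def(), return_elements=['a:0'], name='')

with tf.Session() as sess:
    # No global_variables_initializer() needed: the imported variable is not in
    # GLOBAL_VARIABLES, and restore() assigns its value from the checkpoint.
    second_saver = tf.train.Saver(var_list=g_a)
    second_saver.restore(sess, './test_dir/test_save.ckpt')
    a = sess.graph.get_tensor_by_name('a:0')
    print(sess.run(a))  # [10]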
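The name='' step is one way to make the restore find the variable. A possible alternative, not part of the original answer, is to keep the default import/ prefix and give tf.train.Saver a dict whose key is the name stored in the checkpoint ('a') and whose value is the imported tensor. A minimal sketch, assuming TensorFlow 1.14+ or 2.x through tf.compat.v1 in graph mode and a hypothetical checkpoint path in a temporary directory:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: TF >= 1.14 or TF 2.x, run in graph mode

# Hypothetical checkpoint location in a fresh temporary directory.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), 'model.ckpt')

# Create, initialize and save the variable in its own graph.
src = tf.Graph()
with src.as_default():
    tf.get_variable('a', shape=[], dtype=tf.int32,
                    initializer=tf.constant_initializer(10),
                    use_resource=False)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, ckpt_prefix)

# Import with the default 'import/' prefix, then restore through a dict
# that maps the checkpoint key 'a' onto the imported tensor 'import/a:0'.
dst = tf.Graph()
with dst.as_default():
    [g_a] = tf.import_graph_def(src.as_graph_def(), return_elements=['a:0'])
    saver = tf.train.Saver(var_list={'a': g_a})
    with tf.Session() as sess:
        saver.restore(sess, ckpt_prefix)  # assigns the saved value
        restored = sess.run(g_a)

print(g_a.name, restored)  # import/a:0 10
```

The dict form decouples the name in the checkpoint from the name in the importing graph, which may help when several graphs are imported side by side and each needs its own prefix.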
Noomi answered Mar 15 '26 18:03


