 

How to restore a partial graph in tensorflow?

I would like to restore only part of a computation graph in TensorFlow. My architecture contains two networks, and the output of the first network is the input to the second. The first network is pretrained, and I want to restore it from a checkpoint. I also don't want to update the parameters of the first network. Is there an example I can follow to achieve this?

Thanks

Sentient07 asked Oct 28 '25 13:10



1 Answer

I don't have exact code for your task, but here is a short guide that may help:

First, parse your network into a tf.GraphDef. The code should look like this:

import tensorflow as tf
graph_def = tf.GraphDef()
# A serialized GraphDef is binary protobuf, so open the file in "rb" mode
with tf.gfile.GFile("path/to/graphdef", "rb") as f:
    graph_def.ParseFromString(f.read())
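To try the parsing step end to end, here is a minimal sketch that first writes a toy graph to a binary .pb file and then reads it back. The node names (input, first_net_output), the file name first_net.pb, and the temporary directory are hypothetical; it uses tf.compat.v1 so it should also run on TF 2.x.

```python
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Build a toy "first network" (hypothetical) so there is a GraphDef file to parse.
with tf.Graph().as_default() as g:
    x = tf.placeholder(tf.float32, shape=[None, 2], name="input")
    tf.matmul(x, tf.constant([[1.0], [2.0]]), name="first_net_output")

# Serialize it as binary protobuf (as_text=False); write_graph returns the file path.
pb_path = tf.train.write_graph(g.as_graph_def(), tempfile.mkdtemp(),
                               "first_net.pb", as_text=False)

# Parse it back: the file is binary, hence the "rb" mode.
graph_def = tf.GraphDef()
with tf.gfile.GFile(pb_path, "rb") as f:
    graph_def.ParseFromString(f.read())

node_names = [node.name for node in graph_def.node]
```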

Alternatively, restore the graph from a checkpoint (or SavedModel) and then convert it to a GraphDef:

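import tensorflow as tf
# Rebuilds the graph structure stored in the checkpoint's .meta file
tf.train.import_meta_graph("checkpoint.meta")
graph_def = tf.get_default_graph().as_graph_def()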
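Note that tf.train.import_meta_graph only rebuilds the graph structure; the trained weights are loaded separately through the Saver it returns, with restore(). A minimal sketch of that round trip, assuming a hypothetical one-layer "first network" (node names input, w, first_net_output) saved to a temporary directory; it uses tf.compat.v1 so it should also run on TF 2.x:

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

ckpt_prefix = os.path.join(tempfile.mkdtemp(), "first_net")  # hypothetical location

# 1) Build and save a toy first network; save() also writes first_net.meta.
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[None, 2], name="input")
    w = tf.Variable([[1.0], [2.0]], name="w")
    tf.matmul(x, w, name="first_net_output")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_prefix)

# 2) Restore: import_meta_graph rebuilds the structure, restore() loads the values.
graph = tf.Graph()
with graph.as_default():
    restorer = tf.train.import_meta_graph(ckpt_prefix + ".meta")
    graph_def = graph.as_graph_def()
    with tf.Session() as sess:
        restorer.restore(sess, ckpt_prefix)
        result = sess.run(
            "first_net_output:0",
            feed_dict={"input:0": np.array([[1.0, 1.0]], dtype=np.float32)},
        )
```

The GraphDef obtained this way describes the operations only; the variable values live in the session (and the checkpoint), which is why freezing, discussed below, matters if you want the weights inside the GraphDef itself.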

Now you have graph_def.

Second, extract a subgraph of graph_def with tf.graph_util.extract_sub_graph. As dest_nodes, pass the names of the nodes that feed the second network (that is, the first network's outputs); only those nodes and everything they depend on are kept.
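A minimal sketch of the extraction step on a hypothetical toy graph: first_net_output stands for the first network's output, and unused_head stands for an extra branch (for example a training loss) that the second network does not need. It uses tf.compat.v1 so it should also run on TF 2.x.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Toy graph with one branch we need and one we want to drop (names are hypothetical).
with tf.Graph().as_default() as g:
    x = tf.placeholder(tf.float32, shape=[None, 2], name="input")
    w = tf.constant([[1.0], [2.0]], name="w")
    out = tf.matmul(x, w, name="first_net_output")
    tf.reduce_sum(out, name="unused_head")

graph_def = g.as_graph_def()

# Keep only the nodes that "first_net_output" transitively depends on.
sub_graph_def = tf.graph_util.extract_sub_graph(graph_def, ["first_net_output"])
kept = sorted(node.name for node in sub_graph_def.node)
print(kept)  # → ['first_net_output', 'input', 'w']
```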

Finally, import the subgraph from the second step with tf.import_graph_def.
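A sketch of how the import can be wired to the second network, assuming the first network is already a GraphDef (here a hypothetical toy one). input_map plugs a new placeholder into the imported graph's input, and return_elements hands back the first network's output tensor so the second network can be built on top of it. All node names are illustrative; it uses tf.compat.v1 so it should also run on TF 2.x.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Hypothetical pretrained first network, already reduced to a GraphDef.
with tf.Graph().as_default() as first_graph:
    x = tf.placeholder(tf.float32, shape=[None, 2], name="input")
    tf.matmul(x, tf.constant([[1.0], [2.0]]), name="first_net_output")
first_graph_def = first_graph.as_graph_def()

with tf.Graph().as_default():
    # Input of the combined model, mapped onto the imported "input" placeholder.
    model_input = tf.placeholder(tf.float32, shape=[None, 2], name="model_input")
    (features,) = tf.import_graph_def(
        first_graph_def,
        input_map={"input:0": model_input},
        return_elements=["first_net_output:0"],
        name="first_net",
    )
    # Second network (hypothetical single layer) built on the first network's output.
    w2 = tf.Variable([[10.0]], name="second_net_w")
    output = tf.matmul(features, w2, name="second_net_output")

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        result = sess.run(
            output, feed_dict={model_input: np.array([[1.0, 1.0]], dtype=np.float32)}
        )
print(result)  # → [[30.]]
```

Because the imported first network carries its weights as constants, only second_net_w is a variable here, so an optimizer built on this graph can only update the second network.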

Also, since you don't want to update the parameters of the first network, you can freeze them with tf.graph_util.convert_variables_to_constants, which replaces the variables with constants holding their current values.
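A minimal sketch of freezing, assuming a hypothetical toy first network. convert_variables_to_constants takes a session holding the variable values, the GraphDef, and the output node names, and returns a GraphDef in which those variables are baked in as Const nodes. In a real setup the values would come from saver.restore(...) rather than the initializer. It uses tf.compat.v1 so it should also run on TF 2.x.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[None, 2], name="input")
    w = tf.Variable([[1.0], [2.0]], name="w")
    tf.matmul(x, w, name="first_net_output")
    with tf.Session() as sess:
        # Stand-in for restoring the pretrained weights from a checkpoint.
        sess.run(tf.global_variables_initializer())
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ["first_net_output"]
        )

op_types = {node.op for node in frozen_graph_def.node}

# The frozen graph runs without any variable initialization or restore.
with tf.Graph().as_default():
    tf.import_graph_def(frozen_graph_def, name="")
    with tf.Session() as sess:
        result = sess.run("first_net_output:0", feed_dict={"input:0": [[1.0, 1.0]]})
```

Since the weights are now constants, there is nothing for an optimizer to update in the first network, which matches the requirement in the question.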

Jie.Zhou answered Oct 31 '25 03:10



