
Tensorflow Scikit Flow get GraphDef for Android (save *.pb file)

I want to use my TensorFlow algorithm in an Android app. The TensorFlow Android example starts by downloading a GraphDef that contains the model definition and weights (in a *.pb file). In my case, this GraphDef should come from my Scikit Flow model (Scikit Flow is part of TensorFlow).

At first glance it seems easy: you just call classifier.save('model/'). But the files saved to that folder are not *.ckpt or *.def files, and certainly not a *.pb file. Instead, you get a *.pbtxt file and a checkpoint file (without an extension).

I've been stuck on this for quite a while. Here is a code example that exports something:

#imports
import tensorflow as tf
import tensorflow.contrib.learn as skflow
import tensorflow.contrib.learn.python.learn as learn
from sklearn import datasets, metrics

#skflow example
iris = datasets.load_iris()
feature_columns = learn.infer_real_valued_columns_from_input(iris.data)
classifier = learn.LinearClassifier(n_classes=3, feature_columns=feature_columns,model_dir="modeltest")
classifier.fit(iris.data, iris.target, steps=200, batch_size=32)
iris_predictions = list(classifier.predict(iris.data, as_iterable=True))
score = metrics.accuracy_score(iris.target, iris_predictions)
print("Accuracy: %f" % score)

The files you get are:

  • checkpoint
  • graph.pbtxt
  • model.ckpt-1.meta
  • model.ckpt-1-00000-of-00001
  • model.ckpt-200.meta
  • model.ckpt-200-00000-of-00001
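The `checkpoint` file in this list is a small text-format protobuf that records which checkpoint is the newest, and each `model.ckpt-*.meta` file holds a serialized MetaGraphDef, which embeds the model's GraphDef. In TensorFlow itself, `tf.train.latest_checkpoint(model_dir)` returns the newest checkpoint prefix and `tf.train.import_meta_graph` loads a `.meta` file. The pure-Python sketch below only illustrates that file layout: `latest_checkpoint_prefix` and `meta_graph_path` are hypothetical helpers, and the sketch assumes the `checkpoint` file contains a `model_checkpoint_path: "..."` line in the form TensorFlow usually writes.

```python
import os
import re

# Matches the line that names the most recent checkpoint, e.g.
#   model_checkpoint_path: "model.ckpt-200"
# Lines starting with "all_model_checkpoint_paths" do not match,
# because re.match anchors the pattern at the start of the line.
_LATEST_RE = re.compile(r'\s*model_checkpoint_path:\s*"(.*)"\s*$')


def latest_checkpoint_prefix(model_dir):
    """Return the latest checkpoint prefix recorded in model_dir/checkpoint, or None."""
    state_path = os.path.join(model_dir, "checkpoint")
    if not os.path.exists(state_path):
        return None
    with open(state_path) as f:
        for line in f:
            match = _LATEST_RE.match(line)
            if match:
                prefix = match.group(1)
                # Relative entries are resolved against the model directory.
                if not os.path.isabs(prefix):
                    prefix = os.path.join(model_dir, prefix)
                return prefix
    return None


def meta_graph_path(model_dir):
    """Return the path of the .meta file (serialized MetaGraphDef) for the latest checkpoint."""
    prefix = latest_checkpoint_prefix(model_dir)
    return None if prefix is None else prefix + ".meta"
```

With the question's code, `meta_graph_path("modeltest")` would point at the `.meta` file of whichever checkpoint the `checkpoint` file records as newest (for example `modeltest/model.ckpt-200.meta` after 200 steps).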

Many of the workarounds I found require either the GraphDef in a variable (I don't know how to get it with Scikit Flow) or a TensorFlow session, which you don't normally create yourself when using Scikit Flow.

asked Dec 05 '25 03:12 by CodingYourLife


1 Answer

To save a *.pb file, you need to extract the GraphDef from the constructed graph. You can do that as follows:

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile

sess = tf.Session()
# Replace with the name of the final tensor in your graph
final_tensor_name = 'results:0'

######### Build your graph and train ########
## Your TensorFlow code to build the graph
#############################################

output_filename = 'output_graph.pb'
# Serialize the structure (GraphDef) of the session's graph
output_graph_def = sess.graph.as_graph_def()
with gfile.FastGFile(output_filename, 'wb') as f:
  f.write(output_graph_def.SerializeToString())

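If you want to convert your trained variables to constants (so the weights are embedded in the graph and you don't need the ckpt files to load them), build output_graph_def with this line instead, before writing it to the file: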

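# convert_variables_to_constants expects node names ('results'), not tensor names ('results:0')
output_graph_def = graph_util.convert_variables_to_constants(
    sess, sess.graph.as_graph_def(), [final_tensor_name.split(':')[0]])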

Hope this helps!

answered Dec 07 '25 22:12 by dd.ai


