I want to use my TensorFlow algorithm in an Android app. The TensorFlow Android example starts by downloading a GraphDef that contains the model definition and weights (in a *.pb file). Now I need the same kind of file from my Scikit Flow model (Scikit Flow is part of TensorFlow).
At first glance it seems easy: you just call classifier.save('model/'). But the files saved to that folder are not *.ckpt or *.def, and certainly not *.pb. Instead you have to deal with a *.pbtxt file and a checkpoint file (without an extension).
I have been stuck on this for quite a while. Here is a code example that exports something:
#imports
import tensorflow as tf
import tensorflow.contrib.learn as skflow
import tensorflow.contrib.learn.python.learn as learn
from sklearn import datasets, metrics
#skflow example
iris = datasets.load_iris()
feature_columns = learn.infer_real_valued_columns_from_input(iris.data)
classifier = learn.LinearClassifier(n_classes=3, feature_columns=feature_columns, model_dir="modeltest")
classifier.fit(iris.data, iris.target, steps=200, batch_size=32)
iris_predictions = list(classifier.predict(iris.data, as_iterable=True))
score = metrics.accuracy_score(iris.target, iris_predictions)
print("Accuracy: %f" % score)
The files you get are:
Many of the possible workarounds I found require having the GraphDef in a variable (I don't know how to get that with Scikit Flow) or a TensorFlow session, which doesn't seem to be required when using Scikit Flow.
To save the model as a *.pb file, you need to extract the graph_def from the constructed graph. You can do that as follows:
from tensorflow.python.framework import tensor_shape, graph_util
from tensorflow.python.platform import gfile
sess = tf.Session()
final_tensor_name = 'results:0' #Replace final_tensor_name with name of the final tensor in your graph
#########Build your graph and train########
## Your tensorflow code to build the graph
###########################################
outpt_filename = 'output_graph.pb'
output_graph_def = sess.graph.as_graph_def()
with gfile.FastGFile(outpt_filename, 'wb') as f:
    f.write(output_graph_def.SerializeToString())
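If you want to convert your trained variables to constants (to avoid using ckpt files to load the weights), you can use the call below. Note that convert_variables_to_constants expects node names, so the ':0' output index has to be stripped from the tensor name:
output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [final_tensor_name.split(':')[0]])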
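The two steps above (freezing the variables, then writing the serialized GraphDef) can be put together in a small round trip. This is only a minimal sketch, not the Scikit Flow export itself: it assumes the tf.compat.v1 API is available, and the toy linear model, the node names "input" and "results", and the helper names freeze_to_pb and run_frozen are all made up for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API (assumed to be available)


def freeze_to_pb(pb_path, output_node="results"):
    """Build a toy linear model, freeze its variables, write a *.pb file."""
    graph = tf.Graph()
    with graph.as_default():
        # Hypothetical stand-in for a real model: 4 features -> 3 classes.
        x = tf1.placeholder(tf.float32, shape=[None, 4], name="input")
        w = tf1.get_variable("w", shape=[4, 3],
                             initializer=tf1.zeros_initializer())
        b = tf1.get_variable("b", shape=[3],
                             initializer=tf1.zeros_initializer())
        tf.add(tf.matmul(x, w), b, name=output_node)

        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            # Takes node names (no ":0" suffix) and returns a GraphDef
            # in which the variables are replaced by constants.
            frozen = tf1.graph_util.convert_variables_to_constants(
                sess, graph.as_graph_def(), [output_node])

    with tf.io.gfile.GFile(pb_path, "wb") as f:
        f.write(frozen.SerializeToString())


def run_frozen(pb_path, inputs, output_tensor="results:0"):
    """Load a frozen *.pb file and evaluate one output tensor."""
    graph_def = tf1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        # name="" keeps the original node names instead of prefixing them.
        tf.import_graph_def(graph_def, name="")
        with tf1.Session(graph=graph) as sess:
            return sess.run(output_tensor, feed_dict={"input:0": inputs})


pb_path = os.path.join(tempfile.mkdtemp(), "output_graph.pb")
freeze_to_pb(pb_path)
outputs = run_frozen(pb_path, np.ones((2, 4), dtype=np.float32))
print(outputs.shape)
```

Because the weights are baked into the file as constants, loading it needs neither a checkpoint nor a variable initializer, which is what a single-file consumer such as the Android example needs.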
Hope this helps!