This question has been asked here before; the difference is that my problem is specific to Estimator.
Some context: we trained a model using Estimator, and some variables are defined inside the Estimator's input_fn, the function that preprocesses the data into batches. Now we are moving to prediction, where we use the same input_fn to read in and process the data. But we get an error saying the variable (word_embeddings) does not exist, even though the variables do exist in the checkpoint graph. Here's the relevant bit of code in input_fn:
with tf.variable_scope('vocabulary', reuse=tf.AUTO_REUSE):
    if mode == tf.estimator.ModeKeys.TRAIN:
        word_to_index, word_to_vec = load_embedding(graph_params["word_to_vec"])
        word_embeddings = tf.get_variable(initializer=tf.constant(word_to_vec, dtype=tf.float32),
                                          trainable=False,
                                          name="word_to_vec",
                                          dtype=tf.float32)
    else:
        word_embeddings = tf.get_variable("word_to_vec", dtype=tf.float32)
Basically, in prediction mode the else branch is invoked to load the variable from the checkpoint. Failing to recognize this variable indicates either a) inappropriate usage of scope, or b) the graph is not restored. I don't think scope matters much here as long as reuse is set properly.
I suspect it is because the graph is not yet restored at the input_fn phase. Usually, the graph is restored by calling saver.restore(sess, "/tmp/model.ckpt"). Investigating the Estimator source code didn't turn up anything related to restore; the closest thing is MonitoredSession, a wrapper for training. This has already stretched a long way from the original problem and I'm not confident I'm on the right path, so I'm looking for help if anyone has any insights.
One-line summary of my question: how does the graph get restored within tf.estimator, via input_fn or model_fn?
As far as I know, Variable is the default operation for making a variable, and get_variable is mainly used for weight sharing. On the one hand, there are some people suggesting using get_variable instead of the primitive Variable operation whenever you need a variable.
The function tf.get_variable() returns the existing variable with the same name if it exists, and creates the variable with the specified shape and initializer if it does not exist.
The Estimator object wraps a model which is specified by a model_fn, which, given inputs and a number of other parameters, returns the ops necessary to perform training, evaluation, or predictions. All outputs (checkpoints, event files, etc.)
Hi, I think your error comes simply from not specifying the shape in tf.get_variable (at predict). It seems you need to specify the shape even if the variable is going to be restored.
I made the following test with a simple linear regressor estimator that only needs to predict x + 5:
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

def input_fn(mode):
    def _input_fn():
        with tf.variable_scope('all_input_fn', reuse=tf.AUTO_REUSE):
            if mode == tf.estimator.ModeKeys.TRAIN:
                # Created here during training, so it is saved in the checkpoint.
                var_to_follow = tf.get_variable('var_to_follow', initializer=tf.constant(20))
                x_data = np.random.randn(1000)
                labels = x_data + 5
                return {'x': x_data}, labels
            elif mode == tf.estimator.ModeKeys.PREDICT:
                # The shape must be given even though the value comes from the checkpoint.
                var_to_follow = tf.get_variable("var_to_follow", dtype=tf.int32, shape=[])
                return {'x': [0, 10, 100, var_to_follow]}
    return _input_fn

featcols = [tf.feature_column.numeric_column('x')]
model = tf.estimator.LinearRegressor(featcols, './outdir')
This code works perfectly fine: the restored value of the variable is 20, and for fun I also use it in my test set to confirm :p
However, if you remove shape=[], it breaks. You can also give another initializer, such as tf.constant(500), and everything will still work and 20 will be used.
By running
model.train(input_fn(tf.estimator.ModeKeys.TRAIN), max_steps=10000)
and
preds = model.predict(input_fn(tf.estimator.ModeKeys.PREDICT))
print(next(preds))
you can visualize the graph and see that a) the scoping is normal and b) the graph is restored.
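My reading of why the shape is needed and why the other initializer is ignored: at predict time a new graph is built by calling input_fn and model_fn again, and only afterwards are the checkpoint values loaded into whatever variables that graph declares. Here is a toy model of that ordering in plain Python. It is not TensorFlow code; ToyGraph, its methods and the checkpoint dict are names I made up for illustration only.

```python
# Toy model only: this is NOT TensorFlow code, just the order of operations.

class ToyGraph:
    """Hypothetical stand-in for a freshly built graph plus its variables."""

    def __init__(self):
        self.variables = {}  # name -> {"shape": ..., "value": ...}

    def get_variable(self, name, shape=None, initializer=None):
        # Reuse the variable if this graph already defines it.
        if name in self.variables:
            return self.variables[name]
        # Otherwise it has to be created, which needs a shape from somewhere.
        if shape is None and initializer is None:
            raise ValueError("shape of new variable %r is unknown" % name)
        self.variables[name] = {"shape": shape, "value": initializer}
        return self.variables[name]

    def restore(self, checkpoint):
        # Runs after graph construction: overwrite values by variable name.
        for name, var in self.variables.items():
            if name in checkpoint:
                var["value"] = checkpoint[name]


checkpoint = {"all_input_fn/var_to_follow": 20}  # what training saved

# Predict starts from a new, empty graph, so the variable does not exist yet.
graph = ToyGraph()
try:
    graph.get_variable("all_input_fn/var_to_follow")
    failed_without_shape = False
except ValueError:
    failed_without_shape = True

# With a shape (or any initializer) the variable can be declared ...
var = graph.get_variable("all_input_fn/var_to_follow", initializer=500)
# ... and the checkpoint value replaces the initial one afterwards.
graph.restore(checkpoint)
print(failed_without_shape, var["value"])  # prints: True 20
```

So in this picture neither input_fn nor model_fn restores anything; they only declare the variables, and the restore step fills in the values once the graph is complete.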
Hope this will help you.