Is it possible to create an input_fn that generates random data infinitely to use with the Estimator API in Tensorflow?
This is basically what I would like:
def create_input_fn(function_to_generate_one_sample_with_label):
    def _input_fn():
        ### some code ###
        return feature_cols, labels
    return _input_fn
I would then like to use the function with an Estimator instance like this:
def data_generator():
    features = ... generate a (random) feature vector ...
    label = ... create a suitable label ...
    return features, label
input_fn = create_input_fn(data_generator)
estimator.train(input_fn=input_fn, steps=ANY_NUMBER_OF_STEPS)
The point is to be able to train for as many steps as needed, generating the required training data on the fly. This is for model tuning purposes, to be able to experiment with different training data of varying complexity so that I can get an idea of the capability of the model to fit the training data.
Edit: As jkm suggested, I tried using an actual generator, like this:
def create_input_fn(function, batch_size=100):
    def create_generator():
        while True:
            features = ... generate <batch_size> feature vectors ...
            labels = ... create <batch_size> labels ...
            yield features, labels
    g = create_generator()
    def _input_fn():
        return next(g)
    return _input_fn
I had to add a batch size to get it to run. It now runs, but input_fn is only called once, so it does not generate any new data. It just trains on the first <batch_size> samples that were generated. Is there some way to tell the estimator to refresh the data using the provided input_fn?
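The underlying issue is that the Estimator calls input_fn only once, to build the graph; whatever tensors it returns must then yield a fresh batch each time they are evaluated, whereas next(g) only draws from the generator at graph-construction time. A minimal sketch of wrapping the generator so that each training step pulls fresh data, assuming tf.data.Dataset.from_generator is available (TensorFlow >= 1.4); the dtypes, shapes, and the feature key "x" below are placeholders, not taken from the original code:

import tensorflow as tf

def create_input_fn(generate_one_sample_with_label, batch_size=100):
    def sample_generator():
        # Yield one (features, label) pair per iteration, forever
        while True:
            yield generate_one_sample_with_label()
    def _input_fn():
        dataset = tf.data.Dataset.from_generator(
            sample_generator,
            output_types=(tf.float32, tf.int64),   # placeholder dtypes
            output_shapes=([10], []))              # placeholder shapes
        dataset = dataset.batch(batch_size)
        features, labels = dataset.make_one_shot_iterator().get_next()
        return {"x": features}, labels             # "x" is a placeholder feature name
    return _input_fn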
I think you can get the desired behavior using the recent TF Dataset API; you need tensorflow >= 1.2.0.
import tensorflow as tf

# Define the number of samples and the input shape for each iteration
# you can set minval or maxval as per your data distribution and label distribution requirements
num_samples = 20000
input_shape = [32, 32, 3]
dataset = tf.contrib.data.Dataset.from_tensor_slices((
    tf.random_normal([num_samples] + input_shape),
    tf.random_uniform([num_samples], minval=0, maxval=5)))
# Define batch_size
batch_size = 128
dataset = dataset.batch(batch_size)
# Define iterator
iterator = dataset.make_initializable_iterator()
# Get one batch
next_example, next_label = iterator.get_next()
# calculate loss from the estimator function you are using
estimator_loss = some_estimator(next_example, next_label)
# Set number of epochs here
num_epochs = 100
with tf.Session() as sess:
    for _ in range(num_epochs):
        # Re-initialize the iterator at the start of each epoch
        sess.run(iterator.initializer)
        while True:
            try:
                _loss = sess.run(estimator_loss)
            except tf.errors.OutOfRangeError:
                break
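To hook this back into the Estimator API from the question, the same dataset construction can go inside an input_fn. A rough sketch, with the random pool pre-generated in numpy for simplicity and "x" as a placeholder feature name; dataset.repeat() keeps batches coming for as many steps as estimator.train requests:

import numpy as np
import tensorflow as tf

def random_data_input_fn(batch_size=128, num_samples=20000, input_shape=(32, 32, 3)):
    # Pre-generate a fixed pool of random samples and labels with numpy
    features = np.random.normal(size=(num_samples,) + input_shape).astype(np.float32)
    labels = np.random.uniform(low=0, high=5, size=num_samples).astype(np.float32)
    def _input_fn():
        dataset = tf.contrib.data.Dataset.from_tensor_slices((features, labels))
        dataset = dataset.repeat().batch(batch_size)  # cycle through the pool indefinitely
        next_example, next_label = dataset.make_one_shot_iterator().get_next()
        return {"x": next_example}, next_label        # "x" is a placeholder feature name
    return _input_fn

estimator.train(input_fn=random_data_input_fn(), steps=ANY_NUMBER_OF_STEPS)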