
Using height, width information stored in a TFRecords file to set shape of a Tensor

Tags:

tensorflow

I have converted a directory of images and their labels into a TFRecords file. Each record stores the features image_raw, label, height, width and depth. The conversion function is as follows:

import cv2
import tensorflow as tf

def convert_to_tfrecords(data_samples, filename):
    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    writer = tf.python_io.TFRecordWriter(filename)
    for fname, lb in data_samples:
        # im is a NumPy array; im.shape[2] assumes a multi-channel image
        im = cv2.imread(fname, cv2.IMREAD_UNCHANGED)
        # Flatten the pixel data to raw bytes (tobytes replaces the deprecated tostring)
        image_raw = im.tobytes()
        feats = tf.train.Features(
            feature =
            {
                'image_raw': _bytes_feature(image_raw),
                'label': _int64_feature(int(lb)),
                'height': _int64_feature(im.shape[0]),
                'width': _int64_feature(im.shape[1]),
                'depth': _int64_feature(im.shape[2])
            }
        )
        example = tf.train.Example(features=feats)
        writer.write(example.SerializeToString())
    writer.close()

Now I would like to read this TFRecords file to feed an input pipeline. However, since image_raw has been flattened, it needs to be reshaped to the original [height, width, depth] size. How can I get the values of height, width and depth from the TFRecords file? The following code does not seem to work, because height is a Tensor without a concrete value.

def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    feats = {
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'depth': tf.FixedLenFeature([], tf.int64)
    }
    features = tf.parse_single_example(serialized_example, features=feats)
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    label = tf.cast(features['label'], tf.int32)
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    depth = tf.cast(features['depth'], tf.int32)
    image = tf.reshape(image, [height, width, depth]) # <== does not work
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
    return image, label

When I read TensorFlow's official documentation, I found that the examples usually pass in a known size, such as [224, 224, 3]. I don't like this approach: the size is already stored in the TFRecords file, and passing a fixed size by hand does not guarantee that it matches the data in the file.

So any ideas?

C. Wang asked Jan 27 '26 13:01


1 Answer

The height returned by tf.parse_single_example is a Tensor, so the only way to get its concrete value is to evaluate it, for example by calling session.run() on it. However, I think that's overkill here.

Since a TensorFlow Example is just a protocol buffer (see the documentation), you don't necessarily have to use tf.parse_single_example to read it. You could instead parse each record yourself and read the shape values out directly.
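A minimal sketch of that idea, under some assumptions: the records use the feature keys from the question (image_raw, label, height, width, depth), the pixel data was written as uint8, and read_examples is a hypothetical helper name, not part of TensorFlow. It uses tf.compat.v1.io.tf_record_iterator, which is marked deprecated in newer releases (older 1.x releases expose it as tf.python_io.tf_record_iterator), so treat it as an illustration rather than the recommended API.

```python
import numpy as np
import tensorflow as tf


def read_examples(filename):
    """Yield (image, label) pairs, reshaping with the stored dimensions.

    Hypothetical helper: assumes the feature keys from the question
    and that image_raw holds uint8 pixel data.
    """
    # Each record comes back as raw bytes; no session is needed.
    for record in tf.compat.v1.io.tf_record_iterator(filename):
        example = tf.train.Example()
        example.ParseFromString(record)
        feature = example.features.feature

        # These are plain Python ints, not Tensors.
        height = feature['height'].int64_list.value[0]
        width = feature['width'].int64_list.value[0]
        depth = feature['depth'].int64_list.value[0]
        label = feature['label'].int64_list.value[0]

        raw = feature['image_raw'].bytes_list.value[0]
        image = np.frombuffer(raw, dtype=np.uint8)
        yield image.reshape(height, width, depth), label
```

The images come back as NumPy arrays with their true per-record shape, so they would still need to be fed into the graph separately (for example through a placeholder) if the rest of the pipeline runs in TensorFlow.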

You might also consider filing a feature request on TensorFlow's GitHub issue tracker. I agree this API seems a bit awkward for this use case.

Peter Hawkins answered Jan 30 '26 09:01


