 

TensorFlow: Batch TFRecord Dataset with tensors of arbitrary dimensions

How can I batch tensors of arbitrary shapes using the TFRecordDataset?

I am currently working on the input pipeline of an object detection network and am struggling with batching my labels. The labels consist of the bounding box coordinates and the classes of the objects in an image. Since there may be multiple objects in an image, the label dimensions are arbitrary: the number of boxes varies from image to image.


When working with tf.train.batch, there is the option to set dynamic_pad=True to pad the shapes to the same dimensions. However, there is no such option in tf.data.TFRecordDataset.batch().

The shapes I would like after batching are [batch_size, arbitrary, 4] for my boxes and [batch_size, arbitrary, 1] for the classes.

import tensorflow as tf  # TF 1.x API (tf.parse_single_example, tf.decode_raw)


def decode(serialized_example):
    """
    Decodes the information of the TFRecords to image, label_coord, label_classes
    Later on will also contain the Image Sequence!

    :param serialized_example: Serialized Example read from the TFRecords
    :return: image, label_coordinates list, label_classes list
    """
    features = {'image/shape': tf.FixedLenFeature([], tf.string),
                'train/image': tf.FixedLenFeature([], tf.string),
                'label/coordinates': tf.VarLenFeature(tf.float32),
                'label/classes': tf.VarLenFeature(tf.string)}

    features = tf.parse_single_example(serialized_example, features=features)

    # The image and its shape are stored as raw bytes
    image_shape = tf.decode_raw(features['image/shape'], tf.int64)
    image = tf.decode_raw(features['train/image'], tf.float32)
    image = tf.reshape(image, image_shape)

    # Contains the Bounding Box coordinates in a flattened tensor;
    # reshaped to [1, num_boxes, 4], where num_boxes differs per image
    label_coord = features['label/coordinates']
    label_coord = label_coord.values
    label_coord = tf.reshape(label_coord, [1, -1, 4])

    # Contains the Classes of the BBox in a flattened Tensor;
    # reshaped to [1, num_boxes, 1]
    label_classes = features['label/classes']
    label_classes = label_classes.values
    label_classes = tf.reshape(label_classes, [1, -1, 1])

    return image, label_coord, label_classes
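The reshape step in decode() turns the flat VarLenFeature values back into one row per box. A framework-free sketch of the same idea follows, assuming each box is stored as 4 consecutive floats (the coordinate order is not specified in the question); unflatten_boxes is a hypothetical helper, not a TensorFlow function:

```python
def unflatten_boxes(values, width=4):
    """Group a flat list of numbers into rows of `width` values.

    Mirrors tf.reshape(values, [-1, width]): the number of rows is
    inferred, and the total length must be divisible by `width`.
    """
    if len(values) % width != 0:
        raise ValueError(
            f"cannot reshape {len(values)} values into rows of {width}")
    return [list(values[i:i + width]) for i in range(0, len(values), width)]


# Two boxes stored back to back in one flat feature (hypothetical values).
flat = [0.1, 0.2, 0.5, 0.6, 0.3, 0.3, 0.9, 0.8]
boxes = unflatten_boxes(flat)
print(len(boxes), len(boxes[0]))  # 2 4
```

Because the row count is inferred per record, an image with 1 box and an image with 7 boxes produce labels of different lengths (shapes [1, 1, 4] and [1, 7, 4] with the extra leading 1 in the reshape above), which is exactly what plain batching cannot stack.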

dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.map(decode)
dataset = dataset.map(augment)
dataset = dataset.map(normalize)
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size)

The following error is thrown:

Cannot batch tensors with different shapes in component 1. First element had shape [1,1,4] and element 1 had shape [1,7,4].
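batch() stacks the elements of a batch into one dense tensor, so each component must have the same shape in every element; a label with 1 box and a label with 7 boxes cannot be stacked. The simplified, framework-free sketch below shows that check. shape_of and strict_batch are hypothetical helpers, and the message only imitates the TensorFlow error rather than reproducing its implementation:

```python
def shape_of(nested):
    """Return the shape of a rectangular nested list, e.g. [1, 7, 4].

    Only the first entry at each level is inspected, like reading the
    dimensions of a dense tensor.
    """
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        if not nested:
            break
        nested = nested[0]
    return shape


def strict_batch(elements, component):
    """Collect one component from every element, requiring equal shapes.

    Raises ValueError when shapes differ, imitating the batch() error.
    """
    first = shape_of(elements[0][component])
    for i, element in enumerate(elements[1:], start=1):
        other = shape_of(element[component])
        if other != first:
            raise ValueError(
                "Cannot batch tensors with different shapes in component "
                f"{component}. First element had shape {first} and "
                f"element {i} had shape {other}.")
    return [element[component] for element in elements]


# (image, label_coord, label_classes) shaped like the decode() output above;
# the image strings are placeholders.
one_box = [[[0.0, 0.0, 1.0, 1.0]]]          # shape [1, 1, 4]
seven_boxes = [[[0.0, 0.0, 1.0, 1.0]] * 7]  # shape [1, 7, 4]
batch = [("image0", one_box, [[["car"]]]),
         ("image1", seven_boxes, [[["car"]] * 7])]

try:
    strict_batch(batch, component=1)
except ValueError as err:
    print(err)
# Cannot batch tensors with different shapes in component 1. First element
# had shape [1, 1, 4] and element 1 had shape [1, 7, 4].
```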

Currently, the augment and normalize functions are just placeholders.

asked Dec 02 '25 23:12 by Twald


1 Answer

It turns out that tf.data.TFRecordDataset has another method, padded_batch (inherited from tf.data.Dataset), which does essentially what tf.train.batch(dynamic_pad=True) does. This solves the problem rather easily:

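dataset = tf.data.TFRecordDataset(filename)

dataset = dataset.map(decode)
dataset = dataset.map(augment)
dataset = dataset.map(normalize)

dataset = dataset.shuffle(1000+3*batch_size)
dataset = dataset.repeat(num_epochs)
# Dimensions given as None are padded to the largest size in each batch.
# padded_shapes must match the rank of each component returned by decode():
# the label entries [None, 4] and [None, 1] assume decode() reshapes the
# labels to [-1, 4] and [-1, 1] (without the leading 1).
dataset = dataset.padded_batch(batch_size,
                               drop_remainder=False,
                               padded_shapes=([None, None, None],
                                              [None, 4],
                                              [None, 1])
                              )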
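With None in padded_shapes, padded_batch pads that dimension to the largest size found in each particular batch, so different batches can end up with different padded lengths; with drop_remainder=False the final, smaller batch is kept. The framework-free sketch below shows that behaviour for the two label components only (the image, padded via [None, None, None], is left out for brevity). It assumes labels already have shapes [num_boxes, 4] and [num_boxes, 1] and pads with 0.0 and "", the defaults padded_batch uses for numeric and string tensors. padded_batches and pad_rows are hypothetical helpers, not the TensorFlow implementation:

```python
def pad_rows(rows, target_len, pad_row):
    """Return rows extended with copies of pad_row up to target_len entries."""
    return rows + [list(pad_row) for _ in range(target_len - len(rows))]


def padded_batches(examples, batch_size, drop_remainder=False):
    """Yield (coords, classes) batches padded to the longest label in each batch.

    Each example is (coords, classes) with coords shaped [num_boxes, 4]
    and classes shaped [num_boxes, 1].
    """
    for start in range(0, len(examples), batch_size):
        chunk = examples[start:start + batch_size]
        if drop_remainder and len(chunk) < batch_size:
            break
        max_boxes = max(len(coords) for coords, _ in chunk)
        coords_batch = [pad_rows(c, max_boxes, [0.0] * 4) for c, _ in chunk]
        classes_batch = [pad_rows(k, max_boxes, [""]) for _, k in chunk]
        yield coords_batch, classes_batch


# Hypothetical labels for three images with 1, 3 and 2 boxes.
examples = [
    ([[0.1, 0.1, 0.4, 0.4]], [["car"]]),
    ([[0.2, 0.2, 0.5, 0.5]] * 3, [["car"]] * 3),
    ([[0.3, 0.3, 0.6, 0.6]] * 2, [["person"]] * 2),
]
for coords, classes in padded_batches(examples, batch_size=2):
    print(len(coords), len(coords[0]), len(coords[0][0]))
# 2 3 4   (first batch padded to 3 boxes)
# 1 2 4   (last, partial batch padded to 2 boxes)
```

Downstream, the padded rows are indistinguishable from real boxes unless they are masked, for example by keeping the true box count per image or by checking for the padding value.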
answered Dec 05 '25 11:12 by Twald


