
Normalizing a BatchDataset in TensorFlow 2.3

I'm using tf.keras.preprocessing.image_dataset_from_directory from TF 2.3 to load images from directories (train/test split). What I get is a tf.data.Dataset object (actually a tensorflow.python.data.ops.dataset_ops.BatchDataset) with these shapes:

train_ds.take(1)
# <TakeDataset shapes: ((None, 256, 256, 3), (None, 6)), types: (tf.float32, tf.float32)>
for images, labels in train_ds.take(1):
    print(images.shape)
    print(images[0])
# (32, 256, 256, 3)
# tf.Tensor(
# [[[225.75  225.75  225.75 ]
#   [225.75  225.75  225.75 ]
#   [225.75  225.75  225.75 ]
#   ...
#   [215.    214.    209.   ]
#   [215.    214.    209.   ]
#   [215.    214.    209.   ]]
#
#  ...], shape=(256, 256, 3), dtype=float32)

I cannot figure out how to normalize the images (divide by 255) with that Dataset object. I tried the /= operator itself, the map and apply methods, and even converting the object to a list as mentioned here. Nothing seems to work, and I would really like to solve this at the Dataset level instead of adding a normalization layer to my network.

Any ideas?

asked Sep 08 '25 11:09 by mtszkw


1 Answer

Try mapping a normalization function over the dataset:

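import tensorflow as tf

IMAGE_DIR = "path/to/images"  # placeholder: directory with one subfolder per class

def process(image, label):
    # Cast first, then scale pixel values from [0, 255] to [0, 1]
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

ds = tf.keras.preprocessing.image_dataset_from_directory(IMAGE_DIR)
ds = ds.map(process)  # map returns a new dataset; it does not modify ds in place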
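To check the mapping without a directory of images, here is a minimal sketch that builds a small synthetic batched dataset with the same (images, one-hot labels) structure shown in the question. The sizes, the random data, and the names `train_ds` and `normalized_ds` are assumptions for illustration only. Because image_dataset_from_directory already yields batches, map applies process to each batch, and the division broadcasts over the whole (batch, height, width, channels) tensor.

```python
import numpy as np
import tensorflow as tf

def process(image, label):
    # Scale pixel values from [0, 255] to [0, 1]; labels pass through unchanged
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

# Hypothetical stand-in for image_dataset_from_directory output:
# 8 random 32x32 RGB "images" with one-hot labels over 6 classes, batched by 4
images = np.random.randint(0, 256, size=(8, 32, 32, 3)).astype(np.float32)
labels = tf.one_hot(np.random.randint(0, 6, size=8), depth=6)
train_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

normalized_ds = train_ds.map(process)

for batch_images, batch_labels in normalized_ds.take(1):
    print(batch_images.shape, batch_labels.shape)  # (4, 32, 32, 3) (4, 6)
    # Smallest and largest pixel values in the first batch, both within [0, 1]
    print(float(tf.reduce_min(batch_images)), float(tf.reduce_max(batch_images)))
```

Since map returns a new dataset rather than changing train_ds in place, the result has to be reassigned or stored, which is also why the answer writes `ds = ds.map(process)`.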

answered Sep 10 '25 04:09 by Shubham Shaswat