Resize MNIST in TensorFlow

I have been working with the MNIST dataset to learn how to use TensorFlow and Python for my deep learning course.

I want to resize the MNIST images to 22x22 using TensorFlow and then train on them, but I do not know how to do it.

Could you help me?

asked Oct 29 '25 16:10 by Đông lv


2 Answers

Updated: TensorFlow 2.4.1


Short Answer

Use tf.image.resize (instead of tf.image.resize_images, which is no longer available under tf.image in TensorFlow 2.x). The link others provided no longer exists. Updated link.
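By default, tf.image.resize uses bilinear interpolation. To make that concrete, here is a minimal NumPy sketch of bilinear resizing using half-pixel-center sampling (a common convention). It illustrates the technique only; it is not TensorFlow's implementation, and its values may differ slightly from tf.image.resize (for example at the borders, or when antialias=True). The function name bilinear_resize is hypothetical.

```python
import numpy as np


def bilinear_resize(images, new_h, new_w):
    """Bilinearly resize a batch of images shaped (N, H, W, C) to (N, new_h, new_w, C)."""
    images = images.astype(np.float32)
    _, h, w, _ = images.shape

    def sample_coords(in_size, out_size):
        # Half-pixel centers: output index j maps to input coordinate
        # (j + 0.5) * in_size / out_size - 0.5, clamped to the valid range.
        src = (np.arange(out_size) + 0.5) * (in_size / out_size) - 0.5
        src = np.clip(src, 0, in_size - 1)
        lo = np.floor(src).astype(int)        # nearest input index below
        hi = np.minimum(lo + 1, in_size - 1)  # nearest input index above
        frac = (src - lo).astype(np.float32)  # weight given to `hi`
        return lo, hi, frac

    y0, y1, fy = sample_coords(h, new_h)
    x0, x1, fx = sample_coords(w, new_w)

    wx = fx[None, None, :, None]  # broadcast over batch, rows, channels
    wy = fy[None, :, None, None]  # broadcast over batch, columns, channels

    rows0 = images[:, y0]  # (N, new_h, W, C): upper neighbour rows
    rows1 = images[:, y1]  # (N, new_h, W, C): lower neighbour rows

    # Interpolate horizontally within each pair of rows, then vertically.
    top = rows0[:, :, x0] * (1 - wx) + rows0[:, :, x1] * wx
    bottom = rows1[:, :, x0] * (1 - wx) + rows1[:, :, x1] * wx
    return top * (1 - wy) + bottom * wy


batch = np.random.rand(4, 28, 28, 1).astype(np.float32)  # stand-in for MNIST images
print(bilinear_resize(batch, 22, 22).shape)  # (4, 22, 22, 1)
```

Each output pixel is a weighted average of its four nearest input pixels, so a 28x28 image shrinks to 22x22 without simply dropping rows and columns.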


Long Answer

The MNIST images loaded from tf.keras.datasets.mnist have the following shape (num_samples is 60000 for the training split):

(num_samples, 28, 28)
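tf.image.resize expects either a 4-D tensor (batch, height, width, channels) or a 3-D tensor (height, width, channels). Passed as-is, a (60000, 28, 28) array would therefore be read as one image that is 60000 pixels tall and 28 wide with 28 channels, which is why a channel axis is added before resizing. A minimal NumPy sketch of that shape bookkeeping, using a small random array as a hypothetical stand-in for x_train:

```python
import numpy as np

# Hypothetical stand-in for x_train: 5 grayscale 28x28 images (uint8, like MNIST).
x = np.random.randint(0, 256, size=(5, 28, 28), dtype=np.uint8)

x = np.expand_dims(x, axis=-1)   # (5, 28, 28, 1): batch, height, width, channels
x3 = np.repeat(x, 3, axis=-1)    # (5, 28, 28, 3): grayscale copied into 3 channels
x3 = x3.astype("float32") / 255  # values now in [0.0, 1.0]

print(x.shape, x3.shape)  # (5, 28, 28, 1) (5, 28, 28, 3)
```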

Here is the full implementation. Please read the comments in the code.

import numpy as np
import tensorflow as tf

(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()

# add a new channel axis: (60000, 28, 28) -> (60000, 28, 28, 1)
x_train = np.expand_dims(x_train, axis=-1)

# [optional]: some models expect 3 channels instead of 1
x_train = np.repeat(x_train, 3, axis=-1)

# it's generally better to normalize: scale pixel values to [0, 1]
x_train = x_train.astype('float32') / 255

# resize the images, i.e. old size: 28x28, new size: 32x32
# (use [22, 22] for the size asked in the question)
x_train = tf.image.resize(x_train, [32, 32])

print(x_train.shape)
# (60000, 32, 32, 3)
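Because this converts the whole training set at once, it is worth checking how much memory the result needs. A quick back-of-the-envelope calculation, assuming float32 storage (4 bytes per value) and ignoring any TensorFlow overhead; float32_bytes and NUM_IMAGES are hypothetical names used only for this sketch:

```python
# Memory needed for the resized training set, in bytes (float32 = 4 bytes per value).
NUM_IMAGES = 60000  # size of the MNIST training split


def float32_bytes(height, width, channels, n=NUM_IMAGES):
    return n * height * width * channels * 4


rgb_32 = float32_bytes(32, 32, 3)   # the 32x32, 3-channel version above
gray_22 = float32_bytes(22, 22, 1)  # 22x22 grayscale, as asked in the question

print(rgb_32, gray_22)  # 737280000 116160000
```

That is roughly 737 MB versus 116 MB, so keeping a single channel at 22x22 uses more than six times less memory. If memory is still tight, one option is to resize per batch instead, for example inside a tf.data pipeline.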
answered Oct 31 '25 10:10 by M.Innat


TheRevanchist's answer is correct. However, for the MNIST dataset you first need to reshape the flattened MNIST array (784 values per image) into a 4-D tensor before you pass it to tf.image.resize_images():

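# TensorFlow 1.x only: tf.contrib and tf.Session were removed in TensorFlow 2.x
import tensorflow as tf
import numpy as np
import cv2

mnist = tf.contrib.learn.datasets.load_dataset("mnist")

batch = mnist.train.next_batch(10)  # (images, labels) for 10 examples
X_batch = batch[0]                  # shape (10, 784): flattened images
batch_tensor = tf.reshape(X_batch, [10, 28, 28, 1])  # (batch, height, width, channels)
resized_images = tf.image.resize_images(batch_tensor, [22, 22])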

The code above takes a batch of 10 MNIST images, reshapes them into a 4-D tensor of 28x28 images, and resizes them to 22x22.
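The reshape works because each flattened MNIST vector stores the 28x28 pixels row by row, and both NumPy's default reshape and tf.reshape use row-major order. A minimal NumPy sketch, with a random array as a hypothetical stand-in for batch[0]:

```python
import numpy as np

# Hypothetical stand-in for batch[0]: 10 flattened 28x28 images, 784 values each.
X_batch = np.random.rand(10, 784).astype(np.float32)

# Row-major reshape: the first 28 values become row 0, the next 28 row 1, and so on.
images = X_batch.reshape(10, 28, 28, 1)

# Pixel (row r, column c) of image i sits at flat index r * 28 + c.
i, r, c = 3, 5, 17
assert images[i, r, c, 0] == X_batch[i, r * 28 + c]
print(images.shape)  # (10, 28, 28, 1)
```

Using -1 for the batch dimension, as in tf.reshape(X_batch, [-1, 28, 28, 1]), avoids hard-coding the batch size of 10.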

If you want to display the images, you can use OpenCV with the code below. Calling resized_images.eval() evaluates the tensor and returns the images as a NumPy array.

with tf.Session() as sess:
    numpy_imgs = resized_images.eval(session=sess)  # MNIST images converted to a NumPy array
    for i in range(10):
        cv2.namedWindow('Resized image #%d' % i, cv2.WINDOW_NORMAL)
        cv2.imshow('Resized image #%d' % i, numpy_imgs[i])
        cv2.waitKey(0)  # wait for a key press before showing the next image
cv2.destroyAllWindows()
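cv2.imshow multiplies floating-point images by 255, so float values in [0, 1] display correctly. If you instead want to save the images or pass them to tools that expect 8-bit pixels, converting them explicitly is safer. A minimal NumPy sketch; to_uint8 is a hypothetical helper, and random data stands in for numpy_imgs:

```python
import numpy as np


def to_uint8(batch):
    """Convert float images in [0, 1] shaped (N, H, W, 1) to uint8 images shaped (N, H, W)."""
    scaled = np.clip(batch, 0.0, 1.0) * 255.0
    return np.rint(scaled).astype(np.uint8)[..., 0]  # round, cast, drop the channel axis


# Stand-in for numpy_imgs: 10 resized 22x22 single-channel float images.
numpy_imgs = np.random.rand(10, 22, 22, 1).astype(np.float32)
imgs_u8 = to_uint8(numpy_imgs)
print(imgs_u8.shape, imgs_u8.dtype)  # (10, 22, 22) uint8
```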
answered Oct 31 '25 08:10 by Andreas Forslöw