Extract patches from a list of images using TensorFlow

How can we extract patches when we have a list of images?

Example:

def get_train_images():
    image_list = glob(FLAGS.train_path + '/*.jpg')
    #extract_patch

I want to do something like the following. I have done this for a single image, but I want to do the same for 100 images.

Sample image(s): [sample image]

Output for the sample image(s): [output images]

I have a list of images and want to extract patches from each image and save them in another list. That list can be overwritten.


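Here is an example with two images in a list. Patches are extracted from each image, and the end result is an array of 4 patches per image, hence the shape (2, 4, 4, 4, 3) of patched_images, where 2 is the number of samples, 4 is the number of patches kept per image, and (4, 4, 3) is the shape of each patch. Note that a 16x16 image cut into non-overlapping 4x4 patches gives a 4x4 grid of 16 patches; indexing with patches[0, i, i] keeps only the 4 patches on the diagonal of that grid.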

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

images = [tf.random.normal((16, 16, 3)), tf.random.normal((16, 16, 3))]

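patched_images = []
for img in images:
  # extract_patches expects a batch, so add a leading batch dimension
  image = tf.expand_dims(img, 0)
  # Non-overlapping 4x4 patches; the result has shape (1, 4, 4, 48)
  patches = tf.image.extract_patches(images=image,
                          sizes=[1, 4, 4, 1],
                          strides=[1, 4, 4, 1],
                          rates=[1, 1, 1, 1],
                          padding='VALID')
  # Keep the 4 diagonal patches and restore each one to (4, 4, 3)
  patches = [tf.reshape(patches[0, i, i], (4, 4, 3)) for i in range(4)]
  patched_images.append(np.asarray(patches))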

patched_images = np.asarray(patched_images)
print(patched_images.shape)

axes=[]
fig=plt.figure()
patched_image = patched_images[0] # plot patches of first image
for i in range(4):
    axes.append( fig.add_subplot(2, 2, i + 1) )
    subplot_title=("Patch "+str(i + 1))
    axes[-1].set_title(subplot_title)  
    plt.imshow(patched_image[i, :, :, :])
fig.tight_layout()    
plt.show()

Output:

(2, 4, 4, 4, 3)
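
If you want every patch of the grid rather than only the diagonal ones, a single reshape of the extract_patches output should be enough. The following is a minimal sketch under that assumption; the helper name extract_all_patches and its patch_size parameter are made up for illustration and are not part of TensorFlow.

```python
import tensorflow as tf

def extract_all_patches(image, patch_size=4):
    """Return every non-overlapping patch as (num_patches, p, p, channels).

    Hypothetical helper for illustration; assumes an (H, W, C) image.
    """
    image = tf.convert_to_tensor(image)
    channels = image.shape[-1]
    # extract_patches expects a batch, so add a leading batch dimension
    batched = tf.expand_dims(image, 0)
    patches = tf.image.extract_patches(
        images=batched,
        sizes=[1, patch_size, patch_size, 1],
        strides=[1, patch_size, patch_size, 1],
        rates=[1, 1, 1, 1],
        padding='VALID')
    # patches has shape (1, rows, cols, p * p * channels); flatten the grid
    # row by row and restore each patch to (p, p, channels)
    return tf.reshape(patches, (-1, patch_size, patch_size, channels))

images = [tf.random.normal((16, 16, 3)), tf.random.normal((16, 16, 3))]
# All images share one size here, so the per-image results can be stacked
patched_images = tf.stack([extract_all_patches(img) for img in images])
print(patched_images.shape)  # (2, 16, 4, 4, 3)
```

Patches come out in row-major order, so index 0 is the top-left patch and index 5 is the patch in the second row and second column of the grid.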


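If your images have different sizes and you still want 4x4 patches from each of them, the same call works: the patch size is fixed by sizes, and only the number of patches in the grid changes with the image size. Try this: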

import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

images = [tf.random.normal((16, 16, 3)), tf.random.normal((24, 24, 3)), tf.random.normal((180, 180, 3))]

patched_images = []
for img in images:
  image = tf.expand_dims(img, 0)
  patches = tf.image.extract_patches(images=image,
                          sizes=[1, 4, 4, 1],
                          strides=[1, 4, 4, 1],
                          rates=[1, 1, 1, 1],
                          padding='VALID')
  patches = [tf.reshape(patches[0, i, i], (4, 4, 3)) for i in range(4)]
  patched_images.append(np.asarray(patches))
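
To tie this back to the glob in the question, here is a rough sketch that reads JPEG files from disk and returns one patch tensor per image. It assumes RGB JPEG files; the function name patches_from_files and the patch_size parameter are hypothetical. Because images of different sizes produce different numbers of patches, the results are kept in a plain Python list instead of being stacked into one array.

```python
import glob

import tensorflow as tf

def patches_from_files(pattern, patch_size=4):
    """Load every JPEG matching `pattern` and return a list of patch tensors.

    Hypothetical helper for illustration; assumes 3-channel JPEG files.
    Each list entry has shape (num_patches, patch_size, patch_size, 3).
    """
    patch_list = []
    for path in sorted(glob.glob(pattern)):
        # Decode the file into a uint8 tensor of shape (H, W, 3)
        image = tf.io.decode_jpeg(tf.io.read_file(path), channels=3)
        patches = tf.image.extract_patches(
            images=tf.expand_dims(image, 0),
            sizes=[1, patch_size, patch_size, 1],
            strides=[1, patch_size, patch_size, 1],
            rates=[1, 1, 1, 1],
            padding='VALID')
        # Flatten the patch grid; with 'VALID' padding, any border pixels
        # that do not fill a whole patch are dropped
        patch_list.append(
            tf.reshape(patches, (-1, patch_size, patch_size, 3)))
    return patch_list
```

With the question's setup, this could be called as patches_from_files(FLAGS.train_path + '/*.jpg'). A 16x16 image then yields 16 patches and a 24x24 image yields 36.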