How to extract data/labels back from TensorFlow dataset
Solution 1:
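If your tf.data.Dataset is batched and yields (x, y) tuples, the following code retrieves all the y labels as a single NumPy array (for an unbatched dataset, use np.stack instead, since np.concatenate cannot join zero-dimensional labels):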
import numpy as np
y = np.concatenate([y for x, y in ds], axis=0)
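To see the batching being undone, here is a minimal sketch on a small in-memory dataset. The arrays features and targets and the batch size of 4 are hypothetical stand-ins for your own data. It collects x and y in the same loop so the two stay aligned.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: 10 samples with 3 features each, plus integer labels.
features = np.arange(30, dtype=np.float32).reshape(10, 3)
targets = np.arange(10, dtype=np.int64)

# Batched dataset of (x, y) tuples: batches of 4, 4 and 2 samples.
ds = tf.data.Dataset.from_tensor_slices((features, targets)).batch(4)

xs, ys = [], []
for x, y in ds:  # a single pass keeps x and y in the same order
    xs.append(x.numpy())
    ys.append(y.numpy())

# Concatenate along the batch axis to recover the unbatched arrays.
x_all = np.concatenate(xs, axis=0)
y_all = np.concatenate(ys, axis=0)
print(x_all.shape, y_all.shape)  # (10, 3) (10,)
```

If ds shuffles with reshuffle_each_iteration=True (the default for Dataset.shuffle), two separate passes such as [x for x, y in ds] and [y for x, y in ds] would return the samples in different orders, which is why the sketch collects both in one loop.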
Solution 2:
Supposing our tf.data.Dataset is called train_dataset, with eager execution on (the default in TF 2.x), you can retrieve images and labels like this:
for images, labels in train_dataset.take(1):  # only take the first element of the dataset
    numpy_images = images.numpy()
    numpy_labels = labels.numpy()
- The .numpy() method converts tf.Tensors into NumPy arrays.
- If you want to retrieve more elements of the dataset, increase the number passed to the take method (and collect each iteration's arrays, e.g. in a list, since the loop above overwrites them). If you want all elements, pass -1.
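The points above can be sketched as follows, assuming a batched train_dataset of (image, label) pairs. The toy 2x2 "images", the batch size of 3 and the helper name collect are hypothetical. The helper accumulates every batch returned by take(count) instead of overwriting it.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for train_dataset: 8 "images" of shape 2x2, batched by 3.
images = np.arange(32, dtype=np.float32).reshape(8, 2, 2)
labels = np.arange(8, dtype=np.int32)
train_dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(3)

def collect(dataset, count):
    """Hypothetical helper: gather `count` batches (-1 means all) into NumPy arrays."""
    image_batches, label_batches = [], []
    for batch_images, batch_labels in dataset.take(count):
        image_batches.append(batch_images.numpy())  # tf.Tensor -> np.ndarray
        label_batches.append(batch_labels.numpy())
    return np.concatenate(image_batches), np.concatenate(label_batches)

first_imgs, first_lbls = collect(train_dataset, 2)  # 2 batches -> 6 samples
all_imgs, all_lbls = collect(train_dataset, -1)     # every batch -> 8 samples
print(first_imgs.shape, all_imgs.shape)  # (6, 2, 2) (8, 2, 2)
```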
Solution 3:
I think there is a good example here:
https://colab.research.google.com/github/tensorflow/datasets/blob/master/docs/overview.ipynb#scrollTo=BC4pEXtkp4K-
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
# mnist_train is a tf.data.Dataset whose elements are feature dictionaries
mnist_train = tfds.load(name="mnist", split=tfds.Split.TRAIN)
assert isinstance(mnist_train, tf.data.Dataset)
mnist_example, = mnist_train.take(1)
image, label = mnist_example["image"], mnist_example["label"]
plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap("gray"))
print("Label: %d" % label.numpy())
So each element of the dataset can be accessed like a dictionary. Presumably different datasets have different field names (Boston housing won't have 'image' and 'label', but might have 'features' and 'target' or 'price'). For example:
cnn = tfds.load(name="cnn_dailymail", split=tfds.Split.TRAIN)
assert isinstance(cnn, tf.data.Dataset)
cnn_ex, = cnn.take(1)
print(cnn_ex)
prints a dict with keys such as 'article' and 'highlights', whose values are scalar string tensors (call .numpy() on them to get Python bytes).
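Rather than guessing field names, you can read them from the dataset's element_spec and then map each dictionary to an (input, label) tuple. The sketch below builds a small in-memory dictionary dataset (hypothetical, mimicking the structure tfds.load returns) so it runs without a download. With tfds itself, passing as_supervised=True to tfds.load yields (input, label) tuples directly for datasets that define a supervised pair.

```python
import numpy as np
import tensorflow as tf

# Hypothetical dictionary-structured dataset, shaped like MNIST from tfds.
ds = tf.data.Dataset.from_tensor_slices({
    "image": np.zeros((5, 28, 28, 1), dtype=np.uint8),
    "label": np.array([3, 1, 4, 1, 5], dtype=np.int64),
})

# element_spec describes each element; for a dict dataset it is a dict of TensorSpecs.
field_names = sorted(ds.element_spec.keys())
print(field_names)  # ['image', 'label']

# Turn each dict into an (image, label) tuple, then extract the labels as NumPy.
pairs = ds.map(lambda ex: (ex["image"], ex["label"]))
label_array = np.array([label.numpy() for _, label in pairs])
print(label_array)  # [3 1 4 1 5]
```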
Solution 4:
If you are OK with keeping the images and labels as tf.Tensors, you can do
images, labels = tuple(zip(*dataset))
Think of the dataset as zip(images, labels). When we want to get the images and labels back, we can simply unzip it.
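If you need NumPy arrays, convert them using np.array() (this assumes an unbatched dataset, where each element is a single image/label pair; for a batched one, use np.concatenate as in Solution 1):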
images = np.array(images)
labels = np.array(labels)
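A minimal, runnable sketch of the unzip approach, assuming an unbatched dataset of (image, label) pairs; the 4 toy images of shape 2x2 are hypothetical.

```python
import numpy as np
import tensorflow as tf

# Hypothetical unbatched dataset: 4 images of shape 2x2 with binary labels.
images_in = np.arange(16, dtype=np.float32).reshape(4, 2, 2)
labels_in = np.array([0, 1, 0, 1], dtype=np.int32)
dataset = tf.data.Dataset.from_tensor_slices((images_in, labels_in))

# Iterating yields (image, label) tuples; zip(*...) transposes them into
# one tuple of image tensors and one tuple of label tensors.
images, labels = tuple(zip(*dataset))

# np.array stacks the per-element tensors along a new leading axis.
images = np.array(images)
labels = np.array(labels)
print(images.shape, labels.shape)  # (4, 2, 2) (4,)
```

Note that zip(*dataset) materializes the whole dataset in memory, so it suits small datasets.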
Solution 5:
Here is my own solution to the problem:
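import tensorflow as tf  # graph-mode (TF 1.x style) code; under TF 2.x call tf.compat.v1.disable_eager_execution() first

def dataset2numpy(dataset, steps=1):
    """Helper function to get data/labels back from a TF dataset."""
    iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
    next_val = iterator.get_next()
    with tf.compat.v1.Session() as sess:
        for _ in range(steps):
            inputs, labels = sess.run(next_val)  # evaluate one batch to NumPy
            yield inputs, labels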
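Note that this function yields the inputs and labels one batch at a time; the steps argument controls how many batches are taken from the dataset. It relies on the graph/session API, which in TF 2.x is only available through tf.compat.v1 with eager execution disabled.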
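Under TF 2.x with eager execution, no session is needed. The sketch below is a hypothetical eager counterpart (the name dataset2numpy_eager is illustrative, not from the original) that uses Dataset.take and Dataset.as_numpy_iterator to yield the same batch-by-batch NumPy output; the toy dataset of 6 samples in batches of 2 is also hypothetical.

```python
import numpy as np
import tensorflow as tf

def dataset2numpy_eager(dataset, steps=1):
    """Hypothetical TF 2.x counterpart of dataset2numpy: yield `steps` batches as NumPy."""
    # as_numpy_iterator() converts every tensor in each element to a NumPy value.
    for inputs, labels in dataset.take(steps).as_numpy_iterator():
        yield inputs, labels

# Toy batched dataset: 6 samples, batch size 2 -> 3 batches.
ds = tf.data.Dataset.from_tensor_slices(
    (np.arange(12, dtype=np.float32).reshape(6, 2), np.arange(6))
).batch(2)

batches = list(dataset2numpy_eager(ds, steps=2))
print(len(batches), batches[0][1])  # 2 [0 1]
```

Like the original, the generator stops after steps batches; pass steps=-1 to take every batch.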