How to get samples per class for TensorFlow Dataset
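With `np.fromiter` you can create a 1-D array from an iterable object, such as the labels of a `tf.data.Dataset`, and then count each class with `np.unique(..., return_counts=True)`: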
```python
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow_datasets as tfds

dataset = tfds.load('cifar10', split='train', as_supervised=True)

# Keep only the labels, pull them out as NumPy scalars, and count each class
labels, counts = np.unique(
    np.fromiter(dataset.map(lambda x, y: y).as_numpy_iterator(), dtype=np.int32),
    return_counts=True)

sns.barplot(x=labels, y=counts)
plt.xlabel('Labels')
plt.ylabel('Counts')
plt.show()
```
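To see what `np.fromiter` plus `np.unique(..., return_counts=True)` returns without downloading CIFAR-10, here is a minimal sketch on a toy iterable. The list `toy_labels` is made up and only stands in for `dataset.map(lambda x, y: y)`.

```python
import numpy as np

# Hypothetical stand-in for dataset.map(lambda x, y: y): any iterable of integer labels
toy_labels = iter([3, 1, 3, 0, 1, 3])

# np.fromiter builds a 1-D array from the iterable; np.unique then returns the
# sorted distinct labels together with how often each one occurs
labels, counts = np.unique(np.fromiter(toy_labels, dtype=np.int32), return_counts=True)
print(labels)  # [0 1 3]
print(counts)  # [1 2 3]
```

Note that `np.unique` only reports labels that actually occur, so a class with zero samples (label 2 here) is simply absent from the result.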
Update: You can also collect the labels with a plain Python loop:
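```python
labels = []
for x, y in dataset:
    # Labels are integer class indices (not one-hot encoded)
    labels.append(y.numpy())
    # If the labels are one-hot encoded, take the argmax instead:
    # labels.append(np.argmax(y, axis=-1))

# np.hstack joins both scalar labels (unbatched dataset) and
# per-batch 1-D label arrays (batched dataset) into one flat array
labels = np.hstack(labels)
```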
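To check the one-hot and batching logic without TensorFlow, here is a toy sketch in which plain NumPy arrays stand in for the `y` batches a batched dataset would yield. The `batches` list is hypothetical data, not CIFAR-10.

```python
import numpy as np

# Hypothetical batches of one-hot labels (batch size 2, then a final batch of 1),
# standing in for the y tensors a batched dataset would yield
batches = [
    np.array([[0, 1, 0], [1, 0, 0]]),
    np.array([[0, 0, 1]]),
]

labels = []
for y in batches:
    # One-hot rows -> integer class index per example
    labels.append(np.argmax(y, axis=-1))

# Each batch gives a 1-D array; join them into one flat array of labels
labels = np.concatenate(labels, axis=0)
print(labels)  # [1 0 2]
```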
Then you can count and plot the classes from the `labels` array in the same way as above.
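One way to count the `labels` array, sketched here as an alternative to `np.unique`, is `np.bincount` with `minlength`, which also reports classes that have zero samples. The helper `class_counts` is hypothetical, and the example assumes labels are non-negative integer class indices.

```python
import numpy as np

def class_counts(labels, num_classes):
    """Count samples per class, including classes that never occur."""
    # bincount needs non-negative integers; minlength pads missing classes with 0
    return np.bincount(np.asarray(labels, dtype=np.int64), minlength=num_classes)

# Hypothetical labels array; for CIFAR-10 num_classes would be 10
labels = np.array([2, 0, 2, 4])
counts = class_counts(labels, num_classes=5)
print(counts)  # [1 0 2 0 1]
```

The result can be plotted as before with `sns.barplot(x=np.arange(len(counts)), y=counts)`.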