How to extract samples from a TensorFlow Dataset that have the same label?
Solution 1:
Let's assume you have some kind of data with multiple labels like this:
x = tf.random.uniform((10, 2), maxval=20, dtype=tf.int32)
y = tf.random.uniform((10, ), maxval=4, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((x, y))
To achieve your goal ("each batch would consist of two samples which have the same 'label' feature"), you could use the filter function to create a separate dataset for each label, batch each of them, and then use the concatenate function of the tf.data.Dataset API to merge these datasets into one:
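import tensorflow as tf

x = tf.random.uniform((10, 2), maxval=20, dtype=tf.int32)
y = tf.random.uniform((10, ), maxval=4, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((x, y))
batch_size = 2
# Build one batched dataset per label, so every batch holds a single label.
# Note: a label with an odd number of samples yields a final batch of one sample;
# pass drop_remainder=True to batch() if every batch must contain exactly two.
dataset0 = dataset.filter(lambda x, y: tf.equal(y, 0)).batch(batch_size)
dataset1 = dataset.filter(lambda x, y: tf.equal(y, 1)).batch(batch_size)
dataset2 = dataset.filter(lambda x, y: tf.equal(y, 2)).batch(batch_size)
dataset3 = dataset.filter(lambda x, y: tf.equal(y, 3)).batch(batch_size)
# Merge the per-label datasets; shuffle() here reorders whole batches, not individual samples.
dataset = dataset0.concatenate(dataset1).concatenate(dataset2).concatenate(dataset3).shuffle(buffer_size=20)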
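If you have more than a handful of labels, writing one filter line per label gets tedious. Here is a minimal sketch of the same filter-then-concatenate idea wrapped in a loop. The helper name batches_by_label and its parameters are my own and not part of the TensorFlow API, and I am assuming you know the set of label values up front. I also pass drop_remainder=True, which is a choice on my part so that every batch has exactly batch_size samples; leftover samples of a label are discarded.

```python
import tensorflow as tf


def batches_by_label(dataset, labels, batch_size):
    """Hypothetical helper: return a dataset whose batches each hold a single label."""

    def one_label(label):
        # Keep only the samples with this label, then batch them.
        # drop_remainder=True discards a final short batch (an assumption here).
        return dataset.filter(lambda x, y: tf.equal(y, label)).batch(
            batch_size, drop_remainder=True)

    # Chain the per-label datasets one after another.
    merged = one_label(labels[0])
    for label in labels[1:]:
        merged = merged.concatenate(one_label(label))
    return merged


x = tf.random.uniform((10, 2), maxval=20, dtype=tf.int32)
y = tf.random.uniform((10,), maxval=4, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((x, y))

# Shuffle after merging so that the order of the batches is mixed across labels.
grouped = batches_by_label(dataset, labels=[0, 1, 2, 3], batch_size=2).shuffle(buffer_size=20)

for xb, yb in grouped:
    # Each printed label vector should contain one repeated value.
    print(yb.numpy())
```

Keep in mind that every filter call iterates over the full source dataset, so with many labels or a large dataset this can become slow; in that case it may be worth looking at the window-grouping transformations in tf.data instead.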