tf.data.Dataset: how to get the dataset size (number of elements in an epoch)?
Let's say I have defined a dataset in this way:
filename_dataset = tf.data.Dataset.list_files("{}/*.png".format(dataset))
How can I get the number of elements in the dataset (that is, the number of individual elements that make up an epoch)?
I know that tf.data.Dataset already knows the size of the dataset, because the repeat() method allows repeating the input pipeline for a specified number of epochs. So there must be a way to get this information.
Solution 1:
len(list(dataset))
works in eager mode, but it is not a good general solution: it iterates over the whole dataset and keeps every element in memory at the same time.
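A minimal sketch of a lower-memory alternative, assuming TensorFlow 2.x in eager mode: Dataset.reduce() folds over the dataset while carrying only a running counter, so elements are not kept around. It still reads every element once, so it is no faster than iterating. The helper name count_elements is hypothetical.

```python
import tensorflow as tf


def count_elements(dataset):
    """Count elements in one pass without materializing them in a list.

    Hypothetical helper: reduce() only carries an int64 counter; the
    element itself (bound to `_`) is ignored.
    """
    count = dataset.reduce(tf.constant(0, tf.int64), lambda count, _: count + 1)
    return int(count)


ds = tf.data.Dataset.range(10).filter(lambda x: x % 3 == 0)
print(count_elements(ds))  # 4 (the elements 0, 3, 6, 9)
```

Because it needs a full pass, this only terminates for finite datasets.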
Solution 2:
Take a look here: https://github.com/tensorflow/tensorflow/issues/26966
It doesn't work for TFRecord datasets (their size cannot be determined without reading the files), but it works fine for many other dataset types.
TL;DR:
num_elements = tf.data.experimental.cardinality(dataset).numpy()
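cardinality() does not always return a real count: it returns the sentinel -1 (tf.data.experimental.INFINITE_CARDINALITY) for endless datasets and -2 (tf.data.experimental.UNKNOWN_CARDINALITY) when the size cannot be inferred without iterating, which as far as I know is what TFRecord datasets report. A small sketch, assuming TensorFlow 2.x in eager mode, that maps both sentinels to None; static_size is a hypothetical helper name.

```python
import tensorflow as tf


def static_size(dataset):
    """Return the element count if tf.data can infer it, else None (hypothetical helper)."""
    n = int(tf.data.experimental.cardinality(dataset).numpy())
    if n == tf.data.experimental.INFINITE_CARDINALITY:  # -1, e.g. after repeat()
        return None
    if n == tf.data.experimental.UNKNOWN_CARDINALITY:  # -2, e.g. after filter()
        return None
    return n


print(static_size(tf.data.Dataset.range(5)))                          # 5
print(static_size(tf.data.Dataset.range(5).repeat()))                 # None
print(static_size(tf.data.Dataset.range(5).filter(lambda x: x > 2)))  # None
```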
Solution 3:
UPDATE: Use tf.data.experimental.cardinality(dataset) (see here).
For TensorFlow Datasets (tfds), you can use _, info = tfds.load(name, with_info=True), where name is the name of the dataset. Then you can call info.splits['train'].num_examples. However, even this does not work properly if you define your own split.
So you can either count your files or iterate over the dataset (as described in other answers):
num_training_examples = 0
num_validation_examples = 0
for example in training_set:
    num_training_examples += 1
for example in validation_set:
    num_validation_examples += 1
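Both options can be sketched for the list_files dataset from the question, assuming TensorFlow 2.x in eager mode. Counting files with tf.io.gfile.glob avoids building a dataset at all; iterating works for any finite dataset. The helper names and the temporary directory of empty .png files are hypothetical, for illustration only.

```python
import os
import tempfile

import tensorflow as tf


def count_matching_files(pattern):
    """Count files matching a glob pattern, without building a dataset."""
    return len(tf.io.gfile.glob(pattern))


def count_by_iterating(dataset):
    """Count elements with one full pass; only works for finite datasets."""
    num_examples = 0
    for _ in dataset:
        num_examples += 1
    return num_examples


# Illustration only: a temporary directory with three empty .png files.
data_dir = tempfile.mkdtemp()
for i in range(3):
    open(os.path.join(data_dir, "img_{}.png".format(i)), "wb").close()

pattern = "{}/*.png".format(data_dir)
filename_dataset = tf.data.Dataset.list_files(pattern)
print(count_matching_files(pattern))         # 3
print(count_by_iterating(filename_dataset))  # 3
```

Counting files is only valid when each file becomes exactly one element, as with list_files; it does not work once files are parsed into multiple records.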
Solution 4:
As of TensorFlow 2.3, you can use:
dataset.cardinality().numpy()
Note that .cardinality() is now a method of the Dataset class in the main API (previously it was only available in the experimental package).
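The count is derived from the pipeline definition, so it follows the transformations you apply. A sketch, assuming TensorFlow 2.3 or later, of how the method-form cardinality() changes across a few common transformations of a 10-element dataset; the dictionary of examples is just for illustration.

```python
import tensorflow as tf  # assumes TF >= 2.3 for Dataset.cardinality()

base = tf.data.Dataset.range(10)

# None of these calls iterates over the data.
examples = {
    "range(10)": base,                                                    # 10
    "batch(3)": base.batch(3),                                            # 4, last batch is partial
    "batch(3, drop_remainder=True)": base.batch(3, drop_remainder=True),  # 3
    "shuffle(100)": base.shuffle(100),                                    # 10
    "repeat(2)": base.repeat(2),                                          # 20
    "repeat()": base.repeat(),                                            # -1, infinite
    "take(4)": base.take(4),                                              # 4
    "skip(7)": base.skip(7),                                              # 3
}
sizes = {name: int(ds.cardinality().numpy()) for name, ds in examples.items()}
for name, size in sizes.items():
    print(name, size)
```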
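Note that after a filter() operation, cardinality() can return -2 (unknown cardinality), because the number of elements that pass the filter is not known until the dataset is iterated.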
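If a filtered dataset reports -2 but you know (or count once) the true size, one option is to attach it with tf.data.experimental.assert_cardinality, which makes cardinality() return that value and raises an error during iteration if the real count differs. A sketch, assuming TensorFlow 2.3 or later in eager mode:

```python
import tensorflow as tf

# filter() hides the size: tf.data cannot know how many elements pass the predicate.
evens = tf.data.Dataset.range(10).filter(lambda x: x % 2 == 0)
print(evens.cardinality().numpy())  # -2

# Count once with a full pass, then record the count on the pipeline.
num_evens = sum(1 for _ in evens)
evens = evens.apply(tf.data.experimental.assert_cardinality(num_evens))
print(evens.cardinality().numpy())  # 5
```

The counting pass is only needed once; downstream code (for example, code that computes steps per epoch) can then read cardinality() directly.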