Split a dataset created by the TensorFlow Dataset API into train and test?
Does anyone know how to split a dataset created by the Dataset API (tf.data.Dataset) in TensorFlow into test and train sets?
Solution 1:
Assuming you have an all_dataset variable of type tf.data.Dataset:
test_dataset = all_dataset.take(1000)
train_dataset = all_dataset.skip(1000)
The test dataset now holds the first 1000 elements, and the rest go to training.
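As a quick sanity check, here is a minimal sketch (assuming TF 2.x, where eager execution and Dataset.cardinality() are available; the range dataset is just a stand-in for your data) that verifies the split sizes:
import tensorflow as tf

all_dataset = tf.data.Dataset.range(5000)   # hypothetical stand-in dataset
test_dataset = all_dataset.take(1000)
train_dataset = all_dataset.skip(1000)

print(test_dataset.cardinality().numpy())   # 1000
print(train_dataset.cardinality().numpy())  # 4000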
Solution 2:
You may use Dataset.take() and Dataset.skip():
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
# shuffle() requires a buffer size; reshuffle_each_iteration=False keeps the
# shuffle order fixed across epochs so the three splits don't bleed into each other
full_dataset = full_dataset.shuffle(DATASET_SIZE, reshuffle_each_iteration=False)
train_dataset = full_dataset.take(train_size)   # first 70%
test_dataset = full_dataset.skip(train_size)    # remaining 30%
val_dataset = test_dataset.skip(val_size)       # last 15%
test_dataset = test_dataset.take(test_size)     # middle 15%
For generality, I gave an example using a 70/15/15 train/val/test split, but if you don't need a test or a val set, just ignore the last two lines.
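Note that this assumes DATASET_SIZE is known up front. If it isn't, one option (a sketch, assuming eager execution in TF 2.x; it costs a full pass over the file) is to count the records first:
import tensorflow as tf

# one full pass over the TFRecord file just to count records
DATASET_SIZE = sum(1 for _ in tf.data.TFRecordDataset(FLAGS.input_file))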
Take:
Creates a Dataset with at most count elements from this dataset.
Skip:
Creates a Dataset that skips count elements from this dataset.
You may also want to look into Dataset.shard():
Creates a Dataset that includes only 1/num_shards of this dataset.
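For instance, a minimal sketch of shard() (assuming TF 2.x for as_numpy_iterator()), where each shard keeps every num_shards-th element:
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
shard0 = dataset.shard(num_shards=2, index=0)  # elements at even positions
shard1 = dataset.shard(num_shards=2, index=1)  # elements at odd positions
print(list(shard0.as_numpy_iterator()))  # [0, 2, 4, 6, 8]
print(list(shard1.as_numpy_iterator()))  # [1, 3, 5, 7, 9]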
Disclaimer: I stumbled upon this question after answering a similar one, so I thought I'd spread the love.
Solution 3:
Most of the answers here use take() and skip(), which requires knowing the size of your dataset beforehand. This isn't always possible, or is difficult/intensive to ascertain.
Instead, what you can do is essentially slice the dataset up so that one of every N records becomes a validation record.
To accomplish this, let's start with a simple dataset of 0-9:
dataset = tf.data.Dataset.range(10)
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Now for our example, we're going to slice it so that we have a 3/1 train/validation split. Meaning 3 records will go to training, then 1 record to validation, then repeat.
split = 3
dataset_train = dataset.window(split, split + 1).flat_map(lambda ds: ds)
# [0, 1, 2, 4, 5, 6, 8, 9]
dataset_validation = dataset.skip(split).window(1, split + 1).flat_map(lambda ds: ds)
# [3, 7]
So the first dataset.window(split, split + 1) says to grab split (3) elements, then advance split + 1 elements, and repeat. That + 1 effectively skips the 1 element we're going to use in our validation dataset.
The flat_map(lambda ds: ds) is there because window() returns its results as a dataset of window sub-datasets, which we don't want, so we flatten it back out.
Then for the validation data, we first skip(split), which skips over the first split (3) elements that were grabbed by the first training window, so our iteration starts at the 4th element. The window(1, split + 1) then grabs 1 element, advances split + 1 (4), and repeats.
Note on nested datasets:
The above example works well for simple datasets, but flat_map() will generate an error if the dataset is nested. To address this, you can swap out the flat_map() with a more complicated version that can handle both simple and nested datasets:
.flat_map(lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
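For instance, a sketch with a nested (feature, label) dataset (the names and values here are purely illustrative): the windows of a nested dataset arrive as a tuple of window datasets, and the lambda zips them back together:
import tensorflow as tf

features = tf.data.Dataset.range(10)
labels = tf.data.Dataset.range(10, 20)
dataset = tf.data.Dataset.zip((features, labels))  # nested (feature, label) pairs

split = 3
dataset_train = dataset.window(split, split + 1).flat_map(
    lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
print(list(dataset_train.as_numpy_iterator()))
# [(0, 10), (1, 11), (2, 12), (4, 14), (5, 15), (6, 16), (8, 18), (9, 19)]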
Solution 4:
@ted's answer (Solution 2 above) will cause some overlap. Try this instead:
train_ds_size = int(0.64 * full_ds_size)   # 64% train
valid_ds_size = int(0.16 * full_ds_size)   # 16% validation
train_ds = full_ds.take(train_ds_size)
remaining = full_ds.skip(train_ds_size)
valid_ds = remaining.take(valid_ds_size)
test_ds = remaining.skip(valid_ds_size)    # the remaining ~20% goes to test
Use the code below to test it:
import tensorflow as tf

tf.enable_eager_execution()  # TF 1.x only; eager execution is the default in TF 2.x
dataset = tf.data.Dataset.range(100)
train_size = 20
valid_size = 30
test_size = 50
train = dataset.take(train_size)
remaining = dataset.skip(train_size)
valid = remaining.take(valid_size)
test = remaining.skip(valid_size)
for i in train:
    print(i)
for i in valid:
    print(i)
for i in test:
    print(i)