Fit tf.data.Dataset with from_generator
Solution 1:
If you check the resulting input_batch and output_batch from your train_dataset and their shapes match what your model expects, it should work.
Here is my simple reproducible example:
import tensorflow as tf
import tensorflow.keras as keras
x = tf.random.normal((100, 4, 480, 480, 1)) # Say I have 100 samples.
y = tf.random.normal((100, 4, 480, 480, 1))
train_ds = tf.data.Dataset.from_tensor_slices((x, y))
train_ds = train_ds.batch(batch_size=32, drop_remainder=True)
# check whether this dataset really produces the things I want
sample_input_batch, sample_output_batch = next(iter(train_ds))
print(sample_input_batch.shape) # (32, 4, 480, 480, 1)
print(sample_output_batch.shape) # (32, 4, 480, 480, 1)
simple_model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(4, 480, 480, 1)),
    keras.layers.Dense(10, activation='relu'),
    keras.layers.Dense(1, activation='relu')
])
simple_model.compile(loss=keras.losses.MeanSquaredError(),
                     optimizer=keras.optimizers.Adam(),
                     metrics=['mse'])
# check whether the model really produces the thing I expect
sample_predicted_batch = simple_model(sample_input_batch)
print(sample_predicted_batch.shape) # (32, 4, 480, 480, 1)
simple_model.fit(train_ds)
# 3/3 [==============================] - 0s 24ms/step - loss: 1.1743 - mse: 1.1743
# Then it should work!
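Since the question is about from_generator, here is a minimal sketch of the same pipeline built with tf.data.Dataset.from_generator instead of from_tensor_slices. The data_generator function and its shapes are assumptions for illustration, and output_signature requires TF 2.4+:

import tensorflow as tf

# Hypothetical generator yielding one (input, output) pair at a time;
# the shapes mirror the from_tensor_slices example above.
def data_generator():
    for _ in range(100):
        yield (tf.random.normal((4, 480, 480, 1)),
               tf.random.normal((4, 480, 480, 1)))

train_ds = tf.data.Dataset.from_generator(
    data_generator,
    output_signature=(
        tf.TensorSpec(shape=(4, 480, 480, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(4, 480, 480, 1), dtype=tf.float32)))
train_ds = train_ds.batch(batch_size=32, drop_remainder=True)
# The batches have the same shapes as before, so simple_model.fit(train_ds) works unchanged.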
Besides, if you have a huge dataset, you don't really need to use the repeat() method of the tf.data API.
Moreover, if you do use repeat(), you have to set the steps_per_epoch argument to a concrete number; otherwise training runs forever. You can read more about steps_per_epoch here.
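For instance, here is a minimal sketch of repeat() together with steps_per_epoch, using the toy dataset from the example above:

# Repeat the dataset indefinitely; steps_per_epoch tells fit() where an epoch ends.
train_ds = train_ds.repeat()
# With 100 samples, batch_size=32, and drop_remainder=True, one pass is 3 full batches.
simple_model.fit(train_ds, epochs=5, steps_per_epoch=3)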