Fit tf.data.Dataset with from_generator

Solution 1:

If the input_batch and output_batch you get from your train_dataset have the shapes you expect, it should work.
Here is a simple reproducible example:

import tensorflow as tf
import tensorflow.keras as keras

x = tf.random.normal((100, 4, 480, 480, 1)) # Say I have 100 samples.
y = tf.random.normal((100, 4, 480, 480, 1))

train_ds = tf.data.Dataset.from_tensor_slices((x, y))
train_ds = train_ds.batch(batch_size=32, drop_remainder=True)

# check whether this dataset really produces the things I want
sample_input_batch, sample_output_batch = next(iter(train_ds))
print(sample_input_batch.shape) # (32, 4, 480, 480, 1)
print(sample_output_batch.shape) # (32, 4, 480, 480, 1)

simple_model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(4, 480, 480, 1)),
    keras.layers.Dense(10, activation='relu'),
    keras.layers.Dense(1, activation='relu')
])

simple_model.compile(loss=keras.losses.MeanSquaredError(),
                    optimizer=keras.optimizers.Adam(),
                    metrics=['mse'])

# check whether the model really produces the thing I expect
sample_predicted_batch = simple_model(sample_input_batch)
print(sample_predicted_batch.shape) # (32, 4, 480, 480, 1)

simple_model.fit(train_ds)
# 3/3 [==============================] - 0s 24ms/step - loss: 1.1743 - mse: 1.1743
# Then it should work!
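Since the question is about from_generator, the same sanity check applies there too. Below is a hedged sketch of the equivalent pipeline built from a hypothetical generator (data_generator is an assumed name, and the shapes mirror the example above); the key point is that output_signature must declare the per-sample shapes so batching works:

```python
import tensorflow as tf

# Hypothetical generator yielding one (input, target) pair at a time.
def data_generator():
    for _ in range(100):
        x = tf.random.normal((4, 480, 480, 1))
        y = tf.random.normal((4, 480, 480, 1))
        yield x, y

train_ds = tf.data.Dataset.from_generator(
    data_generator,
    output_signature=(
        tf.TensorSpec(shape=(4, 480, 480, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(4, 480, 480, 1), dtype=tf.float32),
    ),
)
train_ds = train_ds.batch(batch_size=32, drop_remainder=True)

# Same check as before: inspect one batch before calling fit().
sample_input_batch, sample_output_batch = next(iter(train_ds))
print(sample_input_batch.shape) # (32, 4, 480, 480, 1)
```

Note that from_generator produces samples lazily, so the full dataset never needs to fit in memory; only the shapes declared in output_signature have to match what the generator actually yields.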

Besides, if you have a huge dataset, you don't really need to use the repeat() method of the tf.data API.

Moreover, if you do use the repeat() method, you've got to set the steps_per_epoch argument to some number; otherwise the dataset is infinite and an epoch would never end. You can read more about steps_per_epoch here.
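To make the repeat() point concrete, here is a minimal sketch (with small, made-up shapes to keep it light): the dataset is repeated indefinitely, so fit() only terminates because steps_per_epoch tells Keras how many batches count as one epoch:

```python
import tensorflow as tf
import tensorflow.keras as keras

# Small hypothetical dataset, just to illustrate repeat() + steps_per_epoch.
x = tf.random.normal((100, 8))
y = tf.random.normal((100, 1))

# repeat() with no argument makes the dataset infinite.
train_ds = tf.data.Dataset.from_tensor_slices((x, y)) \
                          .batch(32, drop_remainder=True) \
                          .repeat()

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(loss='mse', optimizer='adam')

# Without steps_per_epoch, fit() on this repeated dataset would run forever.
history = model.fit(train_ds, epochs=2, steps_per_epoch=3, verbose=0)
print(len(history.history['loss'])) # 2 epochs recorded
```

With a finite (non-repeated) dataset, Keras can infer the epoch length itself, which is why steps_per_epoch is only mandatory when repeat() makes the dataset unbounded.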