Feeding a 2D image to a TensorFlow CNN for image classification
I think you need to change a few things. First, the input shape of your model should not include the batch size; Keras infers it during training. Change it to `(400, 400, 3)`.
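A minimal sketch of what "inferred" means, assuming TensorFlow 2.x with `tf.keras`. The one-layer model here is a hypothetical stand-in, not your architecture. It declares only `(400, 400, 3)` and still accepts batches of any size:

```python
import tensorflow as tf

# The declared shape is (height, width, channels) only; there is no batch size.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(400, 400, 3)),
    tf.keras.layers.Conv2D(8, 3, activation="relu"),  # 8 filters, 3x3 kernel, 'valid' padding
])

# Keras leaves the batch dimension open, so any batch size works at call time.
out_small = model(tf.zeros((2, 400, 400, 3)))
out_large = model(tf.zeros((5, 400, 400, 3)))
print(tuple(out_small.shape))  # (2, 398, 398, 8)
print(tuple(out_large.shape))  # (5, 398, 398, 8)
```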
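Second, if you are working with binary labels, change your loss function to `tf.keras.losses.BinaryCrossentropy` and your metric to `tf.keras.metrics.BinaryAccuracy`. Furthermore, your output layer should have one output node instead of two: `tf.keras.layers.Dense(1)`. Because that layer has no activation, it outputs a raw logit rather than a probability. So pass `from_logits=True` to the loss and use `tf.keras.metrics.BinaryAccuracy(threshold=0.0)`, because the default threshold of 0.5 (which plain `'accuracy'` also uses) assumes probabilities. Alternatively, use `tf.keras.layers.Dense(1, activation="sigmoid")` with `from_logits=False`, and then plain `'accuracy'` works.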
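Why `from_logits=True` matters with `Dense(1)`: the layer emits a raw logit. The NumPy sketch below is illustrative; the helper names such as `bce_from_logits` are my own, not Keras APIs. It computes binary cross-entropy two ways: from logits, using the numerically stable formula documented for `tf.nn.sigmoid_cross_entropy_with_logits`, and from sigmoid probabilities. It also shows that thresholding a logit at 0 gives the same class as thresholding its probability at 0.5.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def bce_from_probs(y, p, eps=1e-7):
    # Binary cross-entropy on probabilities (what from_logits=False expects)
    p = np.clip(p, eps, 1 - eps)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

def bce_from_logits(y, x):
    # Numerically stable form for raw logits:
    # max(x, 0) - x * y + log(1 + exp(-|x|))
    return np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))

labels = np.array([1.0, 0.0, 1.0, 0.0])
logits = np.array([2.0, -1.5, -0.3, 0.8])  # hypothetical outputs of a Dense(1) with no activation

loss_logits = bce_from_logits(labels, logits).mean()
loss_probs = bce_from_probs(labels, sigmoid(logits)).mean()

# A logit above 0 is equivalent to a probability above 0.5
predicted = (logits > 0).astype(int)
print(predicted)  # [1 0 0 1]
```

The two losses agree. The logit form simply avoids computing `log(sigmoid(x))` for large `|x|`, where the probability saturates to 0 or 1.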
Here is a runnable example based on your code:
```python
import numpy as np
import tensorflow as tf

no_of_samples = 250
BATCH_SIZE = 16
SHUFFLE_BUFFER_SIZE = 50

# Random stand-in data: 250 RGB images of 400x400 and binary labels (0 or 1)
data = np.random.random((no_of_samples, 400, 400, 3)).astype("float32")
labels = np.random.randint(2, size=no_of_samples)

dataset = tf.data.Dataset.from_tensor_slices((data, labels))
# Split before shuffling: shuffle() reshuffles every epoch by default,
# so take()/skip() after it would give a different train/test split each epoch
test_dataset = dataset.take(50).batch(BATCH_SIZE)
train_dataset = dataset.skip(50).shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)

model = tf.keras.Sequential([
    # input_shape omits the batch dimension
    tf.keras.layers.Conv2D(200, 5, strides=3, activation="relu", input_shape=(400, 400, 3)),
    tf.keras.layers.Conv2D(100, 5, strides=2, activation="relu"),
    tf.keras.layers.Conv2D(50, 5, activation="relu"),
    tf.keras.layers.Conv2D(25, 3, activation="relu"),
    tf.keras.layers.MaxPooling2D(3),
    tf.keras.layers.Conv2D(50, 3, activation="relu"),
    tf.keras.layers.Conv2D(25, 3, activation="relu"),
    tf.keras.layers.MaxPooling2D(3),
    tf.keras.layers.Conv2D(50, 2, activation="relu"),
    tf.keras.layers.Conv2D(25, 2, activation="relu"),
    tf.keras.layers.GlobalMaxPooling2D(),
    # Finally, we add a classification layer: one node, no activation (raw logit)
    tf.keras.layers.Dense(1),
])

model.compile(optimizer=tf.keras.optimizers.RMSprop(),
              # from_logits=True because Dense(1) outputs logits, not probabilities
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              # threshold=0.0 on a logit is equivalent to 0.5 on a probability
              metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.0)])
model.fit(train_dataset, epochs=10, validation_data=test_dataset)

print('Labels shape -->', labels.shape)
print('Labels -->', labels)
```
```
Labels shape --> (250,)
Labels --> [1 0 0 0 0 1 0 1 0 1 1 1 0 1 1 0 0 1 0 0 1 0 0 1 0 1 1 0 0 1 1 1 0 1 1 0 0
 0 1 0 1 1 1 0 1 1 1 1 0 0 0 0 0 1 1 1 0 1 1 0 0 0 1 1 0 1 1 0 0 1 1 1 0 0
 1 0 1 1 1 1 1 1 1 1 0 0 1 1 0 1 1 1 1 0 1 1 0 0 0 1 0 1 1 1 0 1 0 1 1 0 1
 1 1 1 1 0 0 0 1 0 0 0 1 1 1 0 1 1 1 0 0 0 1 1 1 0 1 0 1 0 1 1 0 1 0 0 1 0
 1 1 1 1 0 0 1 0 0 0 0 0 0 0 0 1 1 0 1 1 1 0 0 0 0 1 1 0 0 1 1 1 0 0 0 0 1
 0 0 0 0 1 1 0 1 1 1 0 0 0 0 1 0 1 1 1 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 1 0 1
 0 0 1 1 1 1 1 0 0 0 1 0 0 1 0 1 1 1 1 1 0 1 1 1 0 1 1 0]
```
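One caveat when splitting a `tf.data` pipeline: `Dataset.shuffle` reshuffles on every iteration by default (`reshuffle_each_iteration=True`). As a result, `take`/`skip` applied after `shuffle` can produce a different train/test partition each epoch, which leaks test samples into training. Below is a minimal sketch of the safer order: split first, then shuffle only the training part. Toy integer IDs stand in for the `(image, label)` pairs here; that substitution is an assumption made to keep the check cheap.

```python
import tensorflow as tf

# Ten toy sample IDs stand in for (image, label) pairs
ids = tf.data.Dataset.range(10)

# Split first, then shuffle only the training part
test_ds = ids.take(3)
train_ds = ids.skip(3).shuffle(buffer_size=7)

test_ids = {int(x) for x in test_ds}
for epoch in range(3):
    # The order changes every epoch, but membership does not
    train_ids = {int(x) for x in train_ds}
    assert not (train_ids & test_ids)

print(sorted(test_ids), sorted(train_ids))  # [0, 1, 2] [3, 4, 5, 6, 7, 8, 9]
```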