ResNet: 100% accuracy during training, but 33% prediction accuracy with the same data
I am new to machine learning and deep learning, and for learning purposes I tried to play with ResNet. I tried to overfit on a tiny dataset (3 different images) and see if I could get almost 0 loss and 1.0 accuracy - and I did.
The problem is that predictions on the training images (i.e. the same 3 images used for training) are not correct.
Training images: (three distinct images, omitted here)
Image labels: [1,0,0], [0,1,0], [0,0,1]
My Python code:
import os
import numpy as np
from PIL import Image
from keras.applications.resnet50 import ResNet50

# load the 3 grayscale images and resize them to 197x197
imgs = np.array([np.array(Image.open("./Images/train/" + fname)
                          .resize((197, 197), Image.ANTIALIAS))
                 for fname in os.listdir("./Images/train/")]).reshape(-1, 197, 197, 1)

# create one-hot labels
y = np.array([[1,0,0],[0,1,0],[0,0,1]])

# create a ResNet50 model with random weights
model = ResNet50(input_shape=(197, 197, 1), classes=3, weights=None)

# compile & fit the model
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
model.fit(imgs, y, epochs=5, shuffle=True)

# predict on the training data
print(model.predict(imgs))
The model does overfit the data:
Epoch 1/5
3/3 [==============================] - 22s - loss: 1.3229 - acc: 0.0000e+00
Epoch 2/5
3/3 [==============================] - 0s - loss: 0.1474 - acc: 1.0000
Epoch 3/5
3/3 [==============================] - 0s - loss: 0.0057 - acc: 1.0000
Epoch 4/5
3/3 [==============================] - 0s - loss: 0.0107 - acc: 1.0000
Epoch 5/5
3/3 [==============================] - 0s - loss: 1.3815e-04 - acc: 1.0000
but the predictions are:
[[ 1.05677405e-08 9.99999642e-01 3.95520459e-07]
[ 1.11955103e-08 9.99999642e-01 4.14905685e-07]
[ 1.02637095e-07 9.99997497e-01 2.43751242e-06]]
which means that all images were predicted as label [0,1,0].
Why? And how can that happen?
Solution 1:
It's because of the batch normalization layers.
In the training phase, a batch is normalized w.r.t. its own mean and variance. In the testing phase, however, it is normalized w.r.t. the moving average of the previously observed means and variances.
Now this is a problem when the number of observed batches is small (e.g., 5 in your example), because in the BatchNormalization layer, by default moving_mean is initialized to 0 and moving_variance is initialized to 1. Given also that the default momentum is 0.99, you'll need to update the moving averages quite a lot of times before they converge to the "real" mean and variance. That's why the prediction is wrong in the early stage, but is correct after 1000 epochs.
You can verify it by forcing the BatchNormalization layers to operate in "training mode".
During training, the accuracy is 1 and the loss is close to zero:
model.fit(imgs,y,epochs=5,shuffle=True)
Epoch 1/5
3/3 [==============================] - 19s 6s/step - loss: 1.4624 - acc: 0.3333
Epoch 2/5
3/3 [==============================] - 0s 63ms/step - loss: 0.6051 - acc: 0.6667
Epoch 3/5
3/3 [==============================] - 0s 57ms/step - loss: 0.2168 - acc: 1.0000
Epoch 4/5
3/3 [==============================] - 0s 56ms/step - loss: 1.1921e-07 - acc: 1.0000
Epoch 5/5
3/3 [==============================] - 0s 53ms/step - loss: 1.1921e-07 - acc: 1.0000
Now if we evaluate the model, we'll observe high loss and low accuracy because after 5 updates, the moving averages are still pretty close to the initial values:
model.evaluate(imgs,y)
3/3 [==============================] - 3s 890ms/step
[10.745396614074707, 0.3333333432674408]
However, if we manually set the "learning phase" variable and let the BatchNormalization layers use the "real" batch mean and variance, the result becomes the same as what's observed in fit().
sample_weights = np.ones(3)
learning_phase = 1  # 1 means "training mode"
ins = [imgs, y, sample_weights, learning_phase]
model.test_function(ins)  # returns [loss, acc]
[1.192093e-07, 1.0]
It's also possible to verify it by changing the momentum to a smaller value. For example, by adding momentum=0.01 to all the batch norm layers in ResNet50 (a sketch of how to do this follows the output below), the prediction after 20 epochs is:
model.predict(imgs)
array([[ 1.00000000e+00, 1.34882026e-08, 3.92139575e-22],
[ 0.00000000e+00, 1.00000000e+00, 0.00000000e+00],
[ 8.70998792e-06, 5.31159838e-10, 9.99991298e-01]], dtype=float32)