ResNet: 100% accuracy during training, but 33% prediction accuracy with the same data

I am new to machine learning and deep learning, and for learning purposes I tried to play with ResNet. I tried to overfit a small dataset (3 different images) to see whether I could get almost 0 loss and 1.0 accuracy, and I did.

The problem is that the predictions on the training images (i.e., the same 3 images used for training) are not correct.

Training images: [image 1] [image 2] [image 3]

Image labels

[1,0,0], [0,1,0], [0,0,1]

My Python code

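import os

import numpy as np
from PIL import Image
from keras.applications import ResNet50

# load the 3 images as grayscale and resize them
# (sorted() makes the file order, and hence the label order, deterministic)
img_dir = "./Images/train/"
imgs = np.array([np.array(Image.open(os.path.join(img_dir, fname))
                          .convert("L")
                          .resize((197, 197), Image.LANCZOS))  # ANTIALIAS was an alias of LANCZOS, removed in Pillow 10
                 for fname in sorted(os.listdir(img_dir))]).reshape(-1, 197, 197, 1)
# create one-hot labels, one per image
y = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
# create a ResNet50 model from scratch (no pretrained weights) for 1-channel input
model = ResNet50(input_shape=(197, 197, 1), classes=3, weights=None)

# compile & fit model
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])

model.fit(imgs, y, epochs=5, shuffle=True)

# predict on the training data
print(model.predict(imgs))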

The model does overfit the data:

3/3 [==============================] - 22s - loss: 1.3229 - acc: 0.0000e+00
Epoch 2/5
3/3 [==============================] - 0s - loss: 0.1474 - acc: 1.0000
Epoch 3/5
3/3 [==============================] - 0s - loss: 0.0057 - acc: 1.0000
Epoch 4/5
3/3 [==============================] - 0s - loss: 0.0107 - acc: 1.0000
Epoch 5/5
3/3 [==============================] - 0s - loss: 1.3815e-04 - acc: 1.0000

but predictions are:

 [[  1.05677405e-08   9.99999642e-01   3.95520459e-07]
 [  1.11955103e-08   9.99999642e-01   4.14905685e-07]
 [  1.02637095e-07   9.99997497e-01   2.43751242e-06]]

which means that all three images are assigned the label [0,1,0].

Why, and how can that happen?


Solution 1:

It's because of the batch normalization layers.

In the training phase, each batch is normalized with respect to its own mean and variance. In the testing phase, however, it is normalized with respect to the moving averages of the previously observed means and variances.
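To make the difference concrete, here is a minimal NumPy sketch of what a batch-normalization layer computes in each mode. It is an illustration of the idea, not Keras' implementation: the function batch_norm and the toy data are hypothetical, eps=1e-3 assumes Keras' default epsilon, and the moving statistics are set to the Keras defaults for a freshly created layer (mean 0, variance 1).

```python
import numpy as np


def batch_norm(x, gamma, beta, moving_mean, moving_var, training, eps=1e-3):
    """Normalize x (shape: batch x features) like a BatchNormalization layer."""
    if training:
        # Training mode: use the statistics of the current batch.
        mean = x.mean(axis=0)
        var = x.var(axis=0)
    else:
        # Inference mode: use the running (moving) statistics.
        mean = moving_mean
        var = moving_var
    return gamma * (x - mean) / np.sqrt(var + eps) + beta


# Three samples with two features whose scale is far from mean 0 / variance 1.
x = np.array([[10.0, 200.0],
              [12.0, 220.0],
              [14.0, 240.0]])
gamma, beta = np.ones(2), np.zeros(2)
# Default initial values of the moving statistics.
moving_mean, moving_var = np.zeros(2), np.ones(2)

train_out = batch_norm(x, gamma, beta, moving_mean, moving_var, training=True)
infer_out = batch_norm(x, gamma, beta, moving_mean, moving_var, training=False)
print(train_out.round(3))
print(infer_out.round(3))
```

With untouched moving statistics, inference mode leaves the input almost unchanged, whereas training mode rescales each feature to zero mean and roughly unit variance. The layers after it were fit to the training-mode output, so at inference time they receive inputs on a very different scale.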

This is a problem when the number of observed batches is small (e.g., 5 in your example), because by default the BatchNormalization layer initializes moving_mean to 0 and moving_variance to 1.

Given also that the default momentum is 0.99, the moving averages have to be updated many times before they converge to the "real" mean and variance.

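That's why the predictions are wrong at this early stage but become correct after 1000 epochs. With only 3 samples and the default batch size of 32, each epoch is a single batch, i.e. a single update of the moving averages.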
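The arithmetic can be checked with a few lines of plain Python. This sketch applies the moving-average update moving = momentum * moving + (1 - momentum) * batch_stat repeatedly to a constant batch statistic. The function name moving_average_after and the "true" mean of 50.0 are hypothetical, and the constant batch statistic is a simplification: in a real network the batch statistics also drift as the weights change.

```python
def moving_average_after(n_updates, batch_stat, init, momentum=0.99):
    """Apply the moving-average update n_updates times to a constant batch statistic."""
    value = init
    for _ in range(n_updates):
        value = momentum * value + (1.0 - momentum) * batch_stat
    return value


# Hypothetical channel whose true batch mean is 50.0, while
# moving_mean starts at its default of 0.0.
true_mean = 50.0
for updates in (5, 20, 100, 1000):
    est = moving_average_after(updates, true_mean, init=0.0)
    print(f"{updates:5d} updates: moving_mean = {est:8.3f} "
          f"({est / true_mean:.1%} of the true mean)")
```

Starting from 0, the moving mean after n updates equals true_mean * (1 - 0.99**n). After 5 updates it has covered only about 4.9% of the distance to the true value, while after 1000 updates the remaining gap is a factor of 0.99**1000, roughly 4e-5. The same reasoning applies to moving_variance, which starts at 1.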


You can verify this by forcing the BatchNormalization layers to operate in "training mode".

During training, the accuracy is 1 and the loss is close to zero:

model.fit(imgs,y,epochs=5,shuffle=True)
Epoch 1/5
3/3 [==============================] - 19s 6s/step - loss: 1.4624 - acc: 0.3333
Epoch 2/5
3/3 [==============================] - 0s 63ms/step - loss: 0.6051 - acc: 0.6667
Epoch 3/5
3/3 [==============================] - 0s 57ms/step - loss: 0.2168 - acc: 1.0000
Epoch 4/5
3/3 [==============================] - 0s 56ms/step - loss: 1.1921e-07 - acc: 1.0000
Epoch 5/5
3/3 [==============================] - 0s 53ms/step - loss: 1.1921e-07 - acc: 1.0000

If we now evaluate the model, we observe a high loss and low accuracy, because after only 5 updates the moving averages are still close to their initial values:

model.evaluate(imgs,y)
3/3 [==============================] - 3s 890ms/step
[10.745396614074707, 0.3333333432674408]

However, if we manually set the "learning phase" to training and thus let the BatchNormalization layers use the "real" batch mean and variance, the result matches what was observed during fit():

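# Keras 2.0.x internals: test_function is created by the evaluate() call above
# and takes inputs, targets, sample weights and the learning phase as one list
sample_weights = np.ones(3)
learning_phase = 1  # 1 means "training"
ins = [imgs, y, sample_weights, learning_phase]
model.test_function(ins)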
[1.192093e-07, 1.0]
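The learning_phase / test_function trick above relies on internals of the Keras 2.0 era. In TensorFlow 2.x (tf.keras) and Keras 3, a similar check can be made by calling the model directly with training=True. Below is a minimal sketch, assuming the TensorFlow backend; the small Dense + BatchNormalization model and the random data are hypothetical stand-ins for ResNet50 and the three images, so the two predictions may or may not differ on a given run.

```python
import numpy as np
import keras

# Hypothetical toy setup mirroring the question: 3 samples, 3 classes.
rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=2.0, size=(3, 16)).astype("float32")
y = np.eye(3, dtype="float32")

model = keras.Sequential([
    keras.Input(shape=(16,)),
    keras.layers.Dense(32),
    keras.layers.BatchNormalization(),  # default momentum=0.99
    keras.layers.Activation("relu"),
    keras.layers.Dense(3, activation="softmax"),
])
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["acc"])
model.fit(x, y, epochs=5, verbose=0)

# Inference mode: BatchNormalization uses moving_mean / moving_variance.
p_inference = model.predict(x, verbose=0)

# Training mode: BatchNormalization uses the statistics of this batch,
# like learning_phase = 1 above. In TF2 / Keras 3 this call also updates
# the moving statistics as a side effect.
p_training = np.asarray(model(x, training=True))

print("inference mode:", p_inference.argmax(axis=1))
print("training mode: ", p_training.argmax(axis=1))
```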

You can also verify this by setting the momentum to a smaller value.

For example, after adding momentum=0.01 to all the batch norm layers in ResNet50, the predictions after 20 epochs are:

model.predict(imgs)
array([[  1.00000000e+00,   1.34882026e-08,   3.92139575e-22],
       [  0.00000000e+00,   1.00000000e+00,   0.00000000e+00],
       [  8.70998792e-06,   5.31159838e-10,   9.99991298e-01]], dtype=float32)
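The answer does not say how momentum=0.01 was added to the layers. One way to do it without editing the Keras source is sketched below. It assumes (an assumption, not checked against every Keras version) that model.get_config() lists the layers under "layers", each with "class_name" and "config" keys, as in tf.keras 2.x and Keras 3. with_bn_momentum is a hypothetical helper. Rebuilding from the config re-initializes the weights, which is harmless here because the model is trained from scratch (weights=None).

```python
import keras
from keras.applications import ResNet50


def with_bn_momentum(model, momentum):
    """Rebuild `model` with every BatchNormalization layer's momentum replaced.

    The rebuilt model has freshly initialized weights.
    """
    config = model.get_config()
    for layer_cfg in config["layers"]:
        if layer_cfg["class_name"] == "BatchNormalization":
            layer_cfg["config"]["momentum"] = momentum
    # type(model) is the concrete model class (e.g. Functional or Sequential),
    # so its from_config understands the config produced above.
    return type(model).from_config(config)


# Same model as in the question, rebuilt with momentum=0.01 in all BN layers.
model = ResNet50(input_shape=(197, 197, 1), classes=3, weights=None)
model = with_bn_momentum(model, momentum=0.01)
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["acc"])
```

If the model has already been trained, its weights can be carried over with new_model.set_weights(model.get_weights()), since the architecture is unchanged.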