tf keras的SparseCategoricalCrossentropy和sparse_categorical_accuracy在训练期间报告错误的值

问题描述

这是tf 2.3.0。在训练期间,SparseCategoricalCrossentropy损失和sparse_categorical_accuracy的报告值似乎相去甚远。我查看了我的代码,但找不到任何错误。这是要复制的代码

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

x = np.random.randint(0,255,size=(64,224,3)).astype('float32')
y = np.random.randint(0,3,(64,1)).astype('int32')

ds = tf.data.Dataset.from_tensor_slices((x,y)).batch(32)

def create_model():
  input_layer = tf.keras.layers.Input(shape=(224,3),name='img_input')
  x = tf.keras.layers.experimental.preprocessing.Rescaling(1./255,name='rescale_1_over_255')(input_layer)

  base_model = tf.keras.applications.resnet50(input_tensor=x,weights='imagenet',include_top=False)

  x = tf.keras.layers.GlobalAveragePooling2D(name='global_avg_pool_2d')(base_model.output)

  output = Dense(3,activation='softmax',name='predictions')(x)

  return tf.keras.models.Model(inputs=input_layer,outputs=output)

model = create_model()

model.compile(
  optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),loss=tf.keras.losses.SparseCategoricalCrossentropy(),metrics=['sparse_categorical_accuracy']
)

model.fit(ds,steps_per_epoch=2,epochs=5)

这是打印的内容

Epoch 1/5
2/2 [==============================] - 0s 91ms/step - loss: 1.5160 - sparse_categorical_accuracy: 0.2969
Epoch 2/5
2/2 [==============================] - 0s 85ms/step - loss: 0.0892 - sparse_categorical_accuracy: 1.0000
Epoch 3/5
2/2 [==============================] - 0s 84ms/step - loss: 0.0230 - sparse_categorical_accuracy: 1.0000
Epoch 4/5
2/2 [==============================] - 0s 82ms/step - loss: 0.0109 - sparse_categorical_accuracy: 1.0000
Epoch 5/5
2/2 [==============================] - 0s 82ms/step - loss: 0.0065 - sparse_categorical_accuracy: 1.0000

但是,如果我再次检查model.evaluate并“手动”检查准确性,则:

model.evaluate(ds)

2/2 [==============================] - 0s 25ms/step - loss: 1.2681 - sparse_categorical_accuracy: 0.2188
[1.268101453781128,0.21875]

y_pred = model.predict(ds)
y_pred = np.argmax(y_pred,axis=-1)
y_pred = y_pred.reshape(-1,1)
np.sum(y == y_pred)/len(y)

0.21875

model.evaluate(...)的结果通过“手动”检查在指标上达成一致。但是,如果您凝视着训练带来的损失/指标,它们看起来就遥不可及。由于从未抛出任何错误或异常,因此很难发现出了什么问题。

此外,我创建了一个非常简单的案例来尝试重现此问题,但实际上在这里无法重现。请注意,batch_size ==数据的长度,因此这不是迷你批次GD,而是完整批次GD(以消除与迷你批次损耗/指标的混淆:

x = np.random.randn(1024,1).astype('float32')
y = np.random.randint(0,(1024,1)).astype('int32')
ds = tf.data.Dataset.from_tensor_slices((x,y)).batch(1024)
model = Sequential()
model.add(Dense(3,activation='softmax'))
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),metrics=['sparse_categorical_accuracy']
)
model.fit(ds,epochs=5)
model.evaluate(ds)

正如我的评论中所述,一个可疑的对象是批处理规范层,对于无法复制的情况,我没有。

解决方法

您将获得不同的结果,因为fit()将训练损失显示为当前时期内每批训练数据损失的平均值。这可以降低时代平均值。并且所计算的损失被进一步用于更新模型。鉴于在训练结束时使用模型直接计算了validate(),导致了不同的损失。您可以查看官方Keras FAQ和相关的StackOverflow post

此外,尝试提高学习率。

,

可以通过模型中batch norm的存在来解释(或至少部分如此)来解释度量标准中的巨大差异。如果引入批处理规范,将呈现2种情况,一种情况是不可复制的,另一种情况是可复制的。在这两种情况下,batch_size都等于数据的全长(也就是没有“随机”的全梯度下降),以最大程度地减少对小批量统计数据的混淆。

不可复制:

  x = np.random.randn(1024,1).astype('float32')
  y = np.random.randint(0,3,(1024,1)).astype('int32')
  ds = tf.data.Dataset.from_tensor_slices((x,y)).batch(1024)

  model = Sequential()
  model.add(Dense(10,activation='relu'))
  model.add(Dense(10,activation='relu'))
  model.add(Dense(3,activation='softmax'))

可重现:

  model = Sequential()
  model.add(Dense(10))
  model.add(BatchNormalization())
  model.add(Activation('relu'))
  model.add(Dense(10))
  model.add(BatchNormalization())
  model.add(Activation('relu'))
  model.add(Dense(10))
  model.add(BatchNormalization())
  model.add(Activation('relu'))

  model.add(Dense(3,activation='softmax'))

实际上,您可以尝试使用model.predict(x),model(x,training = True),并且y_pred会有很大的不同。此外,根据keras文档,此结果还取决于批处理中的内容。因此,针对x [0]的预测模型(x [0:1],training = True)与模型(x [0:2],training = True)会包含额外的样本。

可能最好转到Keras文档和原始论文,但我确实认为您将不得不接受这一点并相应地解释在进度栏中看到的内容。如果您尝试使用训练损失/准确性来查看是否存在偏差(而不是方差)问题,那么它看起来就很混乱。如有疑问,我认为我们可以对火车进行评估,以确保您的模型何时“收敛”到一个极小的最小值。我在先前的工作中一并忽略了这个细节,因为深网很少出现cos拟合不足(偏差),因此我通过验证损失/指标来确定何时停止训练。但是我可能会回到相同的模型并在训练集上进行评估(只是看模型是否具有能力(没有偏差)。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...