问题描述
您好,我正在尝试使用tf.keras训练一个小模型。在tf 2.2.0中,我正在使用一个生成器,该生成器返回[5,120,32,64,9]和标签[5,1]的序列,并且我正在从tf.keras导入
model.compile(
loss="mse",optimizer=Adam(learning_rate=self.learning_rate),metrics=[Recall(),Precision()],sample_weight_mode="temporal",)
if callbacks is None:
callbacks = []
model.fit(
data.training(),callbacks=callbacks,steps_per_epoch=epoch_size,epochs=epochs,validation_data=data.training(),validation_steps=validation_size,verbose=0,)
另外,我将它们添加到“编译和匹配”部分
Layer (type) Output Shape Param #
=================================================================
input (InputLayer) [(None,None,9)] 0
_________________________________________________________________
conv_lst_m2d_1 (ConvLSTM2D) (None,30,62,20) 20960
_________________________________________________________________
time_distributed_MP_1 (TimeD (None,15,31,20) 0
_________________________________________________________________
time_distributed_BN_1 (TimeD (None,20) 80
_________________________________________________________________
time_distributed_F (Timedist (None,9300) 0
_________________________________________________________________
time_distributed_D1 (Timedis (None,32) 297632
_________________________________________________________________
time_distributed (Timedistri (None,32) 0
_________________________________________________________________
time_distributed_D2 (Timedis (None,24) 792
_________________________________________________________________
time_distributed_1 (Timedist (None,24) 0
_________________________________________________________________
time_distributed_D3 (Timedis (None,16) 400
_________________________________________________________________
time_distributed_2 (Timedist (None,16) 0
_________________________________________________________________
output (Timedistributed) (None,1) 17
=================================================================
(我意识到我正在使用培训作为培训数据和验证数据。我试图在我的代码或TF中查找错误,因为我们在召回和精确wrt验证的结果中得到了奇怪而强烈的变化。它从不收敛,不会产生极端的变化,例如从0-0.8-0.2-0.9-0.4-0.8 ...)
另外,我使用的是生成器,它生成输入和输出的元组,因为它可以“纠正问题”
但是我仍然可以得到精确的结果并召回0.00000
100/100 [==============================)-224秒2秒/步-损耗:0.0371 -召回率:0.0000e + 00-精度:0.0000e + 00-val_loss:0.0331-val_recall:0.0000e + 00-val_precision:0.0000e + 00
有人知道我可以在tf 2.2中使用任何其他技巧来解决该问题吗?
我的神经网络摘要如下:
{{1}}
解决方法
Precision和Recall是二进制分类的固有指标。在TensorFlow和Keras中,当用于多类分类问题时,它们可能会显示这些值(0.0000e + 00)。话虽如此,它们仅在二进制分类而不是多分类的情况下才能正常工作。
您必须清楚地定义问题,无论是分类问题还是回归问题(由于您将mse用作损失,通常在回归问题中使用,但将其用于分类也没有问题)。
看看here如何对这些指标的验证集实施正确的评估:How to get other metrics in Tensorflow 2.0 (not only accuracy)?
这种情况的另一个原因是没有True Positive,在这种情况下Precision = TP /(TP + FP)和Recall = TP /(TP + FN)将得到0(前提是FP和FN与0)