问题描述
我使用 Functional API 创建了一个具有三个不同输出层的模型,以测试不同的激活函数。问题是每个 epoch 的输出行太长了。我只想看准确率,而不是损失。
Epoch 1/5
1875/1875 - 4s - loss: 3.7070 - Sigmoid_loss: 1.1836 - softmax_loss: 1.2291 - Softplus_loss: 1.2943 - Sigmoid_accuracy: 0.9021 - softmax_accuracy: 0.9020 - Softplus_accuracy: 0.5787
我不希望 .fit()
函数打印每一层的损失,只打印精度。我搜索了所有 Google 和 Tensorflow 文档,但找不到如何操作。
这是模型的摘要:
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
InputLayer (InputLayer) [(32,784)] 0
__________________________________________________________________________________________________
FirstHidden (Dense) (32,512) 401920 InputLayer[0][0]
__________________________________________________________________________________________________
SecondHidden (Dense) (32,256) 131328 FirstHidden[0][0]
__________________________________________________________________________________________________
Sigmoid (Dense) (32,10) 2570 SecondHidden[0][0]
__________________________________________________________________________________________________
softmax (Dense) (32,10) 2570 SecondHidden[0][0]
__________________________________________________________________________________________________
Softplus (Dense) (32,10) 2570 SecondHidden[0][0]
==================================================================================================
Total params: 540,958
Trainable params: 540,958
Non-trainable params: 0
__________________________________________________________________________________________________
None
谢谢你,祝你有美好的一天。
解决方法
这是我对自定义回调的拍摄。注意我假设 Sigmoid_accuracy、Softmax_accuracy 和 Softplus_accuracy 之前在 model.compile 中定义为度量。 这是自定义回调的代码
class Print_Acc(keras.callbacks.Callback):
def __init__(self):
super(Print_Acc,self).__init__()
def on_epoch_end(self,epoch,logs=None): # method runs on the end of each epoch
sig_acc=logs.get('Sigmoid_accuracy')
softmax_acc =logs.get('Softmax_accuracy')
softplus_acc =logs.get('Softplus_accuracy')
print('For epoch ',' sig acc= ',sig_acc,' softmac acc= ',softmax_acc,' softplus acc= ',softplus_acc)
在 model.fit 中包含 callbacks=[Print_Acc]