问题描述
我正在尝试为决策树绘制ROC曲线。但是,在计算混淆矩阵时,为每个阈值计算将花费太多时间。因此,有一种更好的方法来计算矩阵。
解决方法
在您的confusion_matrix方法中, 您可以缓存预测以优化性能
def confusion_matrix(predictions):
# Calculate the elements of the confusion matrix
predictions.cache()
TN = predictions.filter('prediction = 0 AND label = 0').count()
TP = predictions.filter('prediction = 1 AND label = 1').count()
FN = predictions.filter('prediction = 0 AND label = 1').count()
FP = predictions.filter('prediction = 1 AND label = 0').count()
predictions.unpersist()
return TP,TN,FP,FN
祝你好运!
,我建议您使用scikit-learn实现,您只需要提供一个包含真实值的数组,另一个提供预测值:sklearn.metrics.roc_curve