SGD 分类器 Precision-Recall 曲线

问题描述

我正在研究一个二元分类问题,我有一个像这样的 sgd 分类器:

sgd = SGDClassifier(
    max_iter            = 1000,tol                 = 1e-3,validation_fraction = 0.2,class_weight = {0:0.5,1:8.99}
)

我将它安装在我的训练集上并绘制了精确召回曲线:

from sklearn.metrics import plot_precision_recall_curve
disp = plot_precision_recall_curve(sgd,X_test,y_test)

enter image description here

鉴于 scikit-learn 中的 sgd 分类器默认使用 loss="hinge",如何绘制这条曲线?我的理解是 sgd 的输出不是概率性的——它要么是 1/0。因此没有“阈值”,但 sklearn 精确召回曲线绘制了具有不同类型阈值的锯齿形图。这是怎么回事?

解决方法

您描述的情况实际上与在 documentation example 中找到的情况相同,使用虹膜数据的前 2 类和 LinearSVC 分类器(该算法使用平方铰链损失,就像您的铰链损失在这里使用,导致分类器只产生二元结果而不是概率结果)。结果图是:

enter image description here

即在性质上与您这里的相似。

尽管如此,您的问题是合理的,而且确实是一个不错的问题;当我们的分类器确实不产生概率预测(因此阈值的任何概念听起来无关紧要)时,我们怎么会得到类似于概率分类器产生的行为?

要了解为什么会这样,我们需要深入研究 scikit-learn 源代码,从此处使用的 plot_precision_recall_curve 函数开始,然后沿着线程向下进入兔子洞...

plot_precision_recall_curvesource code 开始,我们发现:

y_pred,pos_label = _get_response(
    X,estimator,response_method,pos_label=pos_label)

因此,为了绘制 PR 曲线,预测 y_pred 不是直接由我们的分类器的 predict 方法产生,而是由 {{1 }} scikit-learn 的内部函数。

_get_response() 依次包含以下行:

_get_response()

最终将我们引向 prediction_method = _check_classifier_response_method( estimator,response_method) y_pred = prediction_method(X) 内部函数;您可以查看完整的 source code - 此处有趣的是 _check_classifier_response_method() 语句之后的以下 3 lines

else

现在,您可能已经开始明白了:在幕后,predict_proba = getattr(estimator,'predict_proba',None) decision_function = getattr(estimator,'decision_function',None) prediction_method = predict_proba or decision_function 检查所使用的分类器是否可以使用 plot_precision_recall_curvepredict_proba() 方法;并且如果 decision_function() not 可用,就像您这里的具有铰链损失的 SGDClassifier(或具有平方铰链损失的 LinearSVC 分类器的 documentation example)一样,它会恢复为predict_proba() 方法,以便计算 decision_function(),随后将用于绘制 PR(和 ROC)曲线。


上述内容可以说已经回答了您的编程问题,即在这种情况下 scikit-learn 如何准确生成绘图和基础计算;关于是否以及为什么使用非概率分类器的 y_pred 确实是获得 PR(或 ROC)曲线的正确和合法方法的进一步理论探究超出了 SO 的范围,它们应该被发送至 { {3}},如有必要。

相关问答

错误1:Request method ‘DELETE‘ not supported 错误还原:...
错误1:启动docker镜像时报错:Error response from daemon:...
错误1:private field ‘xxx‘ is never assigned 按Alt...
报错如下,通过源不能下载,最后警告pip需升级版本 Requirem...