如何在 sparkml 分类中指定“正类”?

问题描述

如何在sparkml(二进制)分类中指定“正类”? (或者:MulticlassClassificationEvaluator 如何确定哪个类是“正”类?)

假设我们正在训练一个模型以在二元分类问题中以精度为目标,例如......

label_idxer = StringIndexer(inputCol="response",outputCol="label").fit(df_spark)
# we fit so we can get the "labels" attribute to inform reconversion stage

feature_idxer = StringIndexer(inputCols=cat_features,outputCols=[f"{f}_IDX" for f in cat_features],handleInvalid="keep")

onehotencoder = OneHotEncoder(inputCols=feature_idxer.getoutputCols(),outputCols=[f"{f}_OHE" for f in feature_idxer.getoutputCols()])

assembler = VectorAssembler(inputCols=(num_features + onehotencoder.getoutputCols()),outputCol="features")

rf = RandomForestClassifier(labelCol=label_idxer.getoutputCol(),featuresCol=assembler.getoutputCol(),seed=123456789)

label_converter = IndexToString(inputCol=rf.getPredictionCol(),outputCol="prediction_label",labels=label_idxer.labels)

pipeline = Pipeline(stages=[label_idxer,feature_idxer,onehotencoder,assembler,rf,label_converter])  # type: pyspark.ml.pipeline.PipelineModel

crossval = CrossValidator(estimator=pipeline,evaluator=MulticlassClassificationEvaluator(
                              labelCol=rf.getLabelCol(),predictionCol=rf.getPredictionCol(),metricName="weightedPrecision"),numFolds=3)

(train_u,test_u) = dff.randomSplit([0.8,0.2])
model = crossval.fit(train_u)

我知道...

Precision = TP / (TP + FP) 

...但是您如何将特定类标签指定为 Precision 的目标“正类”? (就目前而言,IDK 在训练中实际使用了哪个响应值,也不知道如何判断)。

解决方法

来自关于 spark 邮件列表的讨论...

按照约定,正类为“1”,负类为“0”;我认为你不能改变它(尽管你可以在需要时翻译你的数据)。 F1 仅在多类评估中以一对一的方式定义。您可以设置“metricLabel”来定义多类中哪个类是“正” - 其他一切都是“负”。

请注意,这意味着(没有在 MulticlassEvaluator 中设置 metricLabel)StringIndexer(特别是 stringOrderType 参数 https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.StringIndexer.html?highlight=stringindexer#pyspark.ml.feature.StringIndexer.stringOrderType ) 将是用户了解他们所说的是他们的正面/负面类别的地方。 (请注意,根据文档,默认值是 frequencyDesc。如果在 frequencyDesc/Asc 下的频率相同,字符串将按字母顺序进一步排序(即,在少数正类的情况下,您会没事的需要命名遵循 0=neg 1=pos 约定))。

在多班级中,没有“正面”班级,它们都只是班级。它在那里默认为 0,但 0 没有任何特殊含义。 您可以将其应用于二进制类设置。在这种情况下,您可以简单地为标签 0 请求 F1,这将计算“0-vs-rest”的 F1,这就像将 0 视为 F1 的“正”类一样。

关于这种解释的一个问题是,BinaryClassificationEvaluator 似乎没有能力评估 Fbeta、Recall、Precision 等(https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.evaluation.BinaryClassificationEvaluator.html?highlight=binaryclassificationevaluator#pyspark.ml.evaluation.BinaryClassificationEvaluator.metricName),而 MulticlassClassificationEvaluator 有(https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.evaluation.MulticlassClassificationEvaluator.html?highlight=classificationevaluator#pyspark.ml.evaluation.MulticlassClassificationEvaluator.metricName ),这意味着如果用户想尝试训练模型以针对 AreaUnderROC 或 F1,则需要在两者之间切换,在二元分类的情况下,这意味着他们需要将正类的索引值从 1 切换(在二元分类中,因为你说 1 是传统的正类)到 0(对于多类评估器,因为文档说默认的 metricLabel 是 0)。