SparkML 交叉验证是否仅适用于“标签”列?

问题描述

当我使用一个数据集运行交叉验证 example 时,该数据集在名为“label”的列中具有标签not,我观察到 Spark 3.1.1 上的 IllegalArgumentException。为什么?

下面的代码修改为将“label”列重命名为“target”,并且回归模型的labelCol已设置为“target”。此代码导致异常,而将所有内容保留在“标签”处工作正常。

from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.feature import HashingTF,Tokenizer
from pyspark.ml.tuning import CrossValidator,ParamGridBuilder

training = spark.createDataFrame([
    (0,"a b c d e spark",1.0),(1,"b d",0.0),(2,"spark f g h",(3,"hadoop mapreduce",(4,"b spark who",(5,"g d a y",(6,"spark fly",(7,"was mapreduce",(8,"e spark program",(9,"a e c l",(10,"spark compile",(11,"hadoop software",0.0)
],["id","text","target"]) # try switching between "target" and "label"

tokenizer = Tokenizer(inputCol="text",outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getoutputCol(),outputCol="features")

lr = LogisticRegression(maxIter=10,labelCol="target") #try switching between "target" and "label"

pipeline = Pipeline(stages=[tokenizer,hashingTF,lr])

paramGrid = ParamGridBuilder() \
    .addGrid(hashingTF.numFeatures,[10,100,1000]) \
    .addGrid(lr.regParam,[0.1,0.01]) \
    .build()

crossval = CrossValidator(estimator=pipeline,estimatorParamMaps=paramGrid,evaluator=BinaryClassificationEvaluator(),numFolds=2)  


cvModel = crossval.fit(training)

这是否是预期的行为?

解决方法

您还需要向 BinaryClassificationEvaluator 提供标签列。所以如果你更换线

evaluator=BinaryClassificationEvaluator(),

evaluator=BinaryClassificationEvaluator(labelCol="target"),

它应该可以正常工作。

您可以在 docs 中找到用法。