Scala / Spark - Converting Word2vec output to Dataset[_]

Problem Description

I believe the case class types should match the DataFrame. However, I am confused about what the case class type for the text column should be.

My code is below:

import org.apache.spark.ml.feature.Word2Vec

case class vectorData(value: Array[String], vectors: Array[Float])
def main(args: Array[String]) {
    val word2vec = new Word2Vec()
        .setInputCol("value").setOutputCol("vectors")
        .setVectorSize(5).setMinCount(0).setWindowSize(5)
    val dataset = spark.createDataset(data)

    val model = word2vec.fit(dataset)


    val encoder = org.apache.spark.sql.Encoders.product[vectorData]
    val result = model.transform(dataset)

    result.foreach(row => println(row.get(0)))
    println("###################################")
    result.foreach(row => println(row.get(1)))


    val output  = result.as(encoder)
}

As shown, when I print the first column, I get:

WrappedArray(@marykatherine_q,kNow!,I,heard,afternoon,wondered,thing.,Moscow,times)
WrappedArray(laying,bed,voice..)
WrappedArray(I'm,sooo,sad!!!,killed,Kutner,House,whyyyyyyyy)

And when I print the second column, I get:

[-0.0495405454809467,0.03403271486361821,0.011959535030958552,-0.008446224654714266,0.0014322120696306229]
[-0.06924172700382769,0.02562551060691476,0.01857258938252926,-0.0269106051127892,-0.011274430900812149]
[-0.06266747579416808,0.007715661790879334,0.047578315007472956,-0.02747830021989477,-0.015755867421188775]

The error I get:

Exception in thread "main" org.apache.spark.sql.AnalysisException: cannot resolve '`text`' given input columns: [result,value];

It seems that the types of my case class do not match the actual result. What should the correct types be? I want val output to be a Dataset[_].

Thanks

Edit:

I have changed the case class field names to match the word2vec output columns. Now I get this error:

Exception in thread "main" org.apache.spark.sql.AnalysisException: need an array field but got struct<type:tinyint,size:int,indices:array<int>,values:array<double>>;

Solution

As far as I can tell, this is just a matter of attribute naming. What Spark is telling you is that it cannot find the attribute text in the dataframe result.

You did not say how you created the data object, but it must have an attribute called value, since Word2vec is able to find it. model.transform simply adds a result column to that dataset, turning it into a dataframe of the following type:

root
 |-- value: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- vector: array (nullable = true)
 |    |-- element: float (containsNull = false)
 |-- result: vector (nullable = true)
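
If in doubt, you can check the actual columns directly before defining the encoder. A small sanity check, using only standard DataFrame methods:

// Print the schema and a few rows of the transformed dataframe
// to see exactly which column names and types the encoder must match.
result.printSchema()
result.show(3, truncate = false)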

So when you try to convert it into a dataset, Spark looks for a text column and throws that exception. Simply renaming the value column makes it work:

val output = result.withColumnRenamed("value","text").as(encoder)
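
Note that the rename only fixes the name mismatch; the case class field types still have to line up with the schema above, which is what the second fix below addresses.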

After checking the word2vec source code, I realized that the output of transform is actually not Array[Float]; it is in fact Vector (from org.apache.spark.ml.linalg).

It worked after changing the case class as follows:

import org.apache.spark.ml.linalg.Vector

case class vectorData(value: Array[String], vectors: Vector)
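
For completeness, here is a minimal end-to-end sketch combining both fixes. The toy input data and the local SparkSession setup are assumptions for illustration and were not part of the original question:

import org.apache.spark.ml.feature.Word2Vec
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.{Encoders, SparkSession}

// Field names and types match the columns produced by Word2Vec below
case class vectorData(value: Array[String], vectors: Vector)

object Word2VecExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("w2v").master("local[*]").getOrCreate()
    import spark.implicits._

    // Toy input: each row is a tokenized sentence; createDataset of
    // Array[String] yields a single column named "value" by default
    val data = Seq(
      "laying bed voice".split(" "),
      "I heard afternoon wondered".split(" ")
    )
    val dataset = spark.createDataset(data)

    val word2vec = new Word2Vec()
      .setInputCol("value").setOutputCol("vectors")
      .setVectorSize(5).setMinCount(0).setWindowSize(5)

    val model = word2vec.fit(dataset)
    val result = model.transform(dataset)

    // "vectors" holds a Spark ML Vector, so the case class field must be
    // org.apache.spark.ml.linalg.Vector rather than Array[Float]
    val output = result.as(Encoders.product[vectorData])
    output.show(truncate = false)

    spark.stop()
  }
}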