在 spark 中为 minHashLSH 转换数据帧

问题描述

我有这个数据框:

val df = (
    spark
    .createDataFrame(
        Seq((1L,2L),(1L,5L),8L),(2L,4L),6L),8L))
    )
    .toDF("A","B")
    .groupBy("A")
    .agg(collect_list("B").alias("B"))
)

我想将其转换为以下形式:

val dfTransformed = 
(
    spark
    .createDataFrame(
        Seq(
            (1,Vectors.sparse(9,Seq((2,1.0),(5,(8,1.0)))),(2,Seq((4,(6,1.0))))
        )
    ).toDF("A","B")
)

我想这样做以便我可以使用 MinHashLSH 转换 (https://spark.apache.org/docs/2.2.3/api/scala/index.html#org.apache.spark.ml.feature.MinHashLSH)。

我尝试使用 UDF 如下但没有成功:

def f(x:Array[Long]) = Vectors.sparse(9,x.map(p => (p.toInt,1.0)).toSeq)

val udff = udf((x:Array[Long]) => f(x))

val dfTransformed = df.withColumn("transformed",udff(col("B"))).show()

有人可以帮我吗?

解决方法

对 UDF 使用 Seq,而不是 Array

def f(x: Seq[Long]) = Vectors.sparse(9,x.map(p => (p.toInt,1.0)))

val udff = udf((x: Seq[Long]) => f(x))

val dfTransformed = df.withColumn("transformed",udff(col("B")))

dfTransformed.show(false)
+---+---------+-------------------------+
|A  |B        |transformed              |
+---+---------+-------------------------+
|1  |[2,5,8]|(9,[2,8],[1.0,1.0,1.0])|
|2  |[4,6,[4,1.0])|
+---+---------+-------------------------+