如何找到网格邻居x,y 作为整数将它们分组并计算它们在 spark 中的平均值

问题描述

我正在努力寻找一种方法来从如下所示的数据集中计算邻居平均值:

+------+------+---------+
|     X|     Y|  value  |
+------+------+---------+
|     1|     5|   1     |
|     1|     8|   1     |
|     1|     6|   6     |
|     2|     8|   5     |
|     2|     6|   3     |
+------+------+---------+

例如

(1,5) 邻居是 (1,6),(2,6) 所以我需要找到它们所有值的平均值,这里的答案是 (1 + 6 + 3) / 3 = 3.33

(1,8) 个邻居是 (2,8) 并且它们的平均值是 (1 + 5) / 2 = 3

我希望我的解决方案看起来像这样(我只是在此处将坐标连接为键的字符串):

+--------------------------+
|  neighbour_values | mean |
+--------------------------+
| (1,5)_(1,6)_(2,6) | 3.33 |
| (1,8)_(2,8)       | 3    |
+--------------------------+

我已经尝试了列连接,但似乎并没有走多远。 我正在考虑的解决方案之一是迭代 throw table 两次,一次用于元素,一次用于其他值,并检查它是否是邻居。不幸的是,我对 spark 还很陌生,我似乎找不到任何有关如何操作的信息。

非常感谢任何帮助! 谢谢!:))

解决方法

答案取决于您是否只关心相邻邻居的分组。这种情况可能会导致歧义,如果说,有一个大于两个项目的宽度或高度的连续块。因此,下面的方法假设一组连续坐标中的所有项目都被聚到一个组中,并且每个原始记录都属于一个组。

这种将集合划分为不相交坐标的假设适用于 union-find 算法。

由于 union-find 是递归的,这种方法将原始元素收集到内存中并基于这些值创建一个 UDF。请注意,对于大型数据集,这可能会很慢和/或需要大量内存。

// create example DF
val df = Seq((1,5,1),(1,8,6,6),(2,5),3)).toDF("x","y","value")

// collect all coordinates into in-memory collections
val coordinates = df.select("x","y").collect().map(r => (r.getInt(0),r.getInt(1)))
val coordSet = coordinates.toSet

type K = (Int,Int)
val directParent:Map[K,Option[K]] = coordinates.map { case (x: Int,y: Int) =>
  val possibleParents = coordSet.intersect(Set((x - 1,y - 1),(x,(x - 1,y)))
  val parent = if (possibleParents.isEmpty) None else Some(possibleParents.min)
  ((x,y),parent)
}.toMap

// skip unionFind if only concerned with direct neighbors
def unionFind(key: K,map:Map[K,Option[K]]): K = {
  val mapValue = map.get(key)
  mapValue.map(parentOpt => parentOpt match {
    case None => key
    case Some(parent) => unionFind(parent,map)
  }).getOrElse(key)
}

val canonicalUDF = udf((x: Int,y: Int) => unionFind((x,directParent))

// group using the canonical element
// create column "neighbors" based on x,y values in each group
val avgDF = df.groupBy(canonicalUDF($"x",$"y").alias("canonical")).agg(
  concat_ws("_",collect_list(concat(lit("("),$"x",lit(","),$"y",lit(")")))).alias("neighbors"),avg($"value")).drop("canonical")

结果:

avgDF.show(10,false)
+-----------------+------------------+
|neighbors        |avg(value)        |
+-----------------+------------------+
|(1,8)_(2,8)      |3.0               |
|(1,5)_(1,6)_(2,6)|3.3333333333333335|
+-----------------+------------------+