问题描述
我创建了以下图表:
spark = SparkSession.builder.appName('aggregate').getorCreate()
vertices = spark.createDataFrame([('1','foo',99),('2','bar',10),('3','baz',25),('4','spam',7)],['id','name','value'])
edges = spark.createDataFrame([('1','2'),('1','3'),'4')],['src','dst'])
g = GraphFrame(vertices,edges)
我想聚合消息,这样对于任何给定的顶点,我们都有一个包含其子顶点一直到边缘的所有值的列表。例如,从顶点 1
我们有一个子边到顶点 3
,它有一个子边到顶点 4
。我们还有一个到 2
的子边。即:
(1) --> (3) --> (4)
\
\--> (2)
从 1
我想从这个路径收集所有值:[99,10,25,7]
。其中99
是顶点1
的值,10
是子顶点2
的值,25
是顶点3
的值7
是顶点 4
处的值。
从 3
我们将有值 [25,7]
等
我可以用 aggregateMessages
来近似:
agg = g.aggregateMessages(collect_list(AM.msg).alias('allValues'),sendToSrc=AM.dst['value'],sendToDst=None)
agg.show()
产生:
+---+---------+
| id|allValues|
+---+---------+
| 3| [7]|
| 1| [25,10]|
+---+---------+
在 1
处,我们有 [25,10]
,它们是直接子值,但我们缺少 7
和 99
的“self”值。
同样,我缺少顶点 25
的 3
。
我如何“递归地”聚合消息,使得来自子顶点的 allValues
在父顶点聚合?
解决方法
根据您的问题调整 this answer,并整理该答案的结果以获得您想要的输出。我承认这是一个非常丑陋的解决方案,但我希望它对您有所帮助,作为努力实现更高效和优雅实施的起点。
from graphframes import GraphFrame
from graphframes.lib import Pregel
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import *
vertices = spark.createDataFrame([('1','foo',99),('2','bar',10),('3','baz',25),('4','spam',7)],['id','name','value'])
edges = spark.createDataFrame([('1','2'),('1','3'),'4')],['src','dst'])
g = GraphFrame(vertices,edges)
### Adapted from previous answer
vertColSchema = StructType()\
.add("dist",DoubleType())\
.add("node",StringType())\
.add("path",ArrayType(StringType(),True))
def vertexProgram(vd,msg):
if msg == None or vd.__getitem__(0) < msg.__getitem__(0):
return (vd.__getitem__(0),vd.__getitem__(1),vd.__getitem__(2))
else:
return (msg.__getitem__(0),msg.__getitem__(2))
vertexProgramUdf = F.udf(vertexProgram,vertColSchema)
def sendMsgToDst(src,dst):
srcDist = src.__getitem__(0)
dstDist = dst.__getitem__(0)
if srcDist < (dstDist - 1):
return (srcDist + 1,src.__getitem__(1),src.__getitem__(2) + [dst.__getitem__(1)])
else:
return None
sendMsgToDstUdf = F.udf(sendMsgToDst,vertColSchema)
def aggMsgs(agg):
shortest_dist = sorted(agg,key=lambda tup: tup[1])[0]
return (shortest_dist.__getitem__(0),shortest_dist.__getitem__(1),shortest_dist.__getitem__(2))
aggMsgsUdf = F.udf(aggMsgs,vertColSchema)
result = (
g.pregel.withVertexColumn(
colName = "vertCol",initialExpr = F.when(
F.col("id") == 1,F.struct(F.lit(0.0),F.col("id"),F.array(F.col("id")))
).otherwise(
F.struct(F.lit(float("inf")),F.array(F.lit("")))
).cast(vertColSchema),updateAfterAggMsgsExpr = vertexProgramUdf(F.col("vertCol"),Pregel.msg())
)
.sendMsgToDst(sendMsgToDstUdf(F.col("src.vertCol"),Pregel.dst("vertCol")))
.aggMsgs(aggMsgsUdf(F.collect_list(Pregel.msg())))
.setMaxIter(3) ## This should be greater than the max depth of the graph
.setCheckpointInterval(1)
.run()
)
df = result.select("vertCol.node","vertCol.path").repartition(1)
df.show()
+----+---------+
|node| path|
+----+---------+
| 1| [1]|
| 2| [1,2]|
| 3| [1,3]|
| 4|[1,3,4]|
+----+---------+
### Wrangling the dataframe to get desired output
final = df.select(
'node',F.posexplode_outer('path')
).withColumn(
'children',F.collect_list('col').over(Window.partitionBy('node').orderBy(F.desc('pos')))
).groupBy('col').agg(
F.array_distinct(F.flatten(F.collect_list('children'))).alias('children')
).alias('t1').repartition(1).join(
vertices,F.array_contains(F.col('t1.children'),vertices.id)
).groupBy('col').agg(
F.collect_list('value').alias('values')
).withColumnRenamed('col','id').orderBy('id')
final.show()
+---+---------------+
| id| values|
+---+---------------+
| 1|[99,10,25,7]|
| 2| [10]|
| 3| [25,7]|
| 4| [7]|
+---+---------------+