Why do I see duplicate materializations of my DataFrame in my build?

Problem description

I am running the following code:

from pyspark.sql import types as T, functions as F, SparkSession
spark = SparkSession.builder.getOrCreate()

schema = T.StructType([
  T.StructField("col_1", T.IntegerType(), False),
  T.StructField("col_2", T.IntegerType(), False),
  T.StructField("measure_1", T.FloatType(), False),
  T.StructField("measure_2", T.FloatType(), False),
])
data = [
  {"col_1": 1,"col_2": 2,"measure_1": 0.5,"measure_2": 1.5},{"col_1": 2,"col_2": 3,"measure_1": 2.5,"measure_2": 3.5}
]

df = spark.createDataFrame(data,schema)
df.show()

"""
+-----+-----+---------+---------+
|col_1|col_2|measure_1|measure_2|
+-----+-----+---------+---------+
|    1|    2|      0.5|      1.5|
|    2|    3|      2.5|      3.5|
+-----+-----+---------+---------+
"""

group_cols = ["col_1","col_2"]
measure_cols = ["measure_1","measure_2"]
for col in measure_cols:
  stats = df.groupBy(group_cols).agg(
    F.max(col).alias("max_" + col),F.avg(col).alias("avg_" + col),)
  df = df.join(stats,group_cols)
df.show()

"""
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
|col_1|col_2|measure_1|measure_2|max_measure_1|avg_measure_1|max_measure_2|avg_measure_2|
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
|    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
|    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
"""

The problem arises when my initial df is not this simple, but is instead itself the result of a series of joins or other operations. When I look at my job, I notice that df appears to be derived several times as my groupBy operations execute. The query plan for this simple case is:


df.explain()
"""
>>> df.explain()
== Physical Plan ==
*(11) Project [col_1#26,col_2#27,measure_1#28,measure_2#29,max_measure_1#56,avg_measure_1#58,max_measure_2#80,avg_measure_2#82]
+- *(11) SortMergeJoin [col_1#26,col_2#27],[col_1#87,col_2#88],Inner
   :- *(5) Project [col_1#26,avg_measure_1#58]
   :  +- *(5) SortMergeJoin [col_1#26,[col_1#63,col_2#64],Inner
   :     :- *(2) Sort [col_1#26 ASC NULLS FIRST,col_2#27 ASC NULLS FIRST],false,0
   :     :  +- Exchange hashpartitioning(col_1#26,200),ENSURE_REQUIREMENTS,[id=#276]
   :     :     +- *(1) Scan ExistingRDD[col_1#26,measure_2#29]
   :     +- *(4) Sort [col_1#63 ASC NULLS FIRST,col_2#64 ASC NULLS FIRST],0
   :        +- *(4) HashAggregate(keys=[col_1#63,functions=[max(measure_1#65),avg(cast(measure_1#65 as double))])
   :           +- Exchange hashpartitioning(col_1#63,col_2#64,[id=#282]
   :              +- *(3) HashAggregate(keys=[col_1#63,functions=[partial_max(measure_1#65),partial_avg(cast(measure_1#65 as double))])
   :                 +- *(3) Project [col_1#63,measure_1#65]
   :                    +- *(3) Scan ExistingRDD[col_1#63,measure_1#65,measure_2#66]
   +- *(10) Sort [col_1#87 ASC NULLS FIRST,col_2#88 ASC NULLS FIRST],0
      +- *(10) HashAggregate(keys=[col_1#87,functions=[max(measure_2#90),avg(cast(measure_2#90 as double))])
         +- *(10) HashAggregate(keys=[col_1#87,functions=[partial_max(measure_2#90),partial_avg(cast(measure_2#90 as double))])
            +- *(10) Project [col_1#87,col_2#88,measure_2#90]
               +- *(10) SortMergeJoin [col_1#87,Inner
                  :- *(7) Sort [col_1#87 ASC NULLS FIRST,0
                  :  +- Exchange hashpartitioning(col_1#87,[id=#293]
                  :     +- *(6) Project [col_1#87,measure_2#90]
                  :        +- *(6) Scan ExistingRDD[col_1#87,measure_1#89,measure_2#90]
                  +- *(9) Sort [col_1#63 ASC NULLS FIRST,0
                     +- *(9) HashAggregate(keys=[col_1#63,functions=[])
                        +- Exchange hashpartitioning(col_1#63,[id=#299]
                           +- *(8) HashAggregate(keys=[col_1#63,functions=[])
                              +- *(8) Project [col_1#63,col_2#64]
                                 +- *(8) Scan ExistingRDD[col_1#63,measure_2#66]
"""

However, if, for example, I change the code above so that the initial df is the result of a join and a union:

# re-create the simple df from the schema and data above
df = spark.createDataFrame(data, schema)

right_schema = T.StructType([
  T.StructField("col_1", T.IntegerType(), False)
])
right_data = [
  {"col_1": 1},{"col_1": 1},{"col_1": 2},{"col_1": 2}
]
right_df = spark.createDataFrame(right_data,right_schema)

df = df.unionByName(df)
df = df.join(right_df,on="col_1")
df.show()

"""
+-----+-----+---------+---------+
|col_1|col_2|measure_1|measure_2|
+-----+-----+---------+---------+
|    1|    2|      0.5|      1.5|
|    1|    2|      0.5|      1.5|
|    1|    2|      0.5|      1.5|
|    1|    2|      0.5|      1.5|
|    2|    3|      2.5|      3.5|
|    2|    3|      2.5|      3.5|
|    2|    3|      2.5|      3.5|
|    2|    3|      2.5|      3.5|
+-----+-----+---------+---------+
"""

df.explain()

"""
== Physical Plan ==
*(7) Project [col_1#299,col_2#300,measure_1#301,measure_2#302,col_2#354,measure_1#355,measure_2#356]
+- *(7) SortMergeJoin [col_1#299],[col_1#353],Inner
   :- *(3) Sort [col_1#299 ASC NULLS FIRST],0
   :  +- Exchange hashpartitioning(col_1#299,[id=#595]
   :     +- Union
   :        :- *(1) Scan ExistingRDD[col_1#299,measure_2#302]
   :        +- *(2) Scan ExistingRDD[col_1#299,measure_2#302]
   +- *(6) Sort [col_1#353 ASC NULLS FIRST],0
      +- ReusedExchange [col_1#353,measure_2#356],Exchange hashpartitioning(col_1#299,[id=#595]
"""

group_cols = ["col_1",group_cols)
df.show()

"""
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
|col_1|col_2|measure_1|measure_2|max_measure_1|avg_measure_1|max_measure_2|avg_measure_2|
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
|    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
|    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
|    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
|    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
|    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
|    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
|    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
|    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
"""

df.explain()

"""
== Physical Plan ==
*(31) Project [col_1#404,col_2#405,measure_1#406,measure_2#407,max_measure_1#465,avg_measure_1#467,max_measure_2#489,avg_measure_2#491]
+- *(31) SortMergeJoin [col_1#404,col_2#405],[col_1#496,col_2#497],Inner
   :- *(15) Project [col_1#404,avg_measure_1#467]
   :  +- *(15) SortMergeJoin [col_1#404,[col_1#472,col_2#473],Inner
   :     :- *(7) Sort [col_1#404 ASC NULLS FIRST,col_2#405 ASC NULLS FIRST],0
   :     :  +- Exchange hashpartitioning(col_1#404,[id=#1508]
   :     :     +- *(6) Project [col_1#404,measure_2#407]
   :     :        +- *(6) SortMergeJoin [col_1#404],[col_1#412],Inner
   :     :           :- *(3) Sort [col_1#404 ASC NULLS FIRST],0
   :     :           :  +- Exchange hashpartitioning(col_1#404,[id=#1494]
   :     :           :     +- Union
   :     :           :        :- *(1) Scan ExistingRDD[col_1#404,measure_2#407]
   :     :           :        +- *(2) Scan ExistingRDD[col_1#404,measure_2#407]
   :     :           +- *(5) Sort [col_1#412 ASC NULLS FIRST],0
   :     :              +- Exchange hashpartitioning(col_1#412,[id=#1500]
   :     :                 +- *(4) Scan ExistingRDD[col_1#412]
   :     +- *(14) Sort [col_1#472 ASC NULLS FIRST,col_2#473 ASC NULLS FIRST],0
   :        +- Exchange hashpartitioning(col_1#472,col_2#473,[id=#1639]
   :           +- *(13) HashAggregate(keys=[col_1#472,functions=[max(measure_1#474),avg(cast(measure_1#474 as double))])
   :              +- *(13) HashAggregate(keys=[col_1#472,functions=[partial_max(measure_1#474),partial_avg(cast(measure_1#474 as double))])
   :                 +- *(13) Project [col_1#472,measure_1#474]
   :                    +- *(13) SortMergeJoin [col_1#472],Inner
   :                       :- *(10) Sort [col_1#472 ASC NULLS FIRST],0
   :                       :  +- Exchange hashpartitioning(col_1#472,[id=#1516]
   :                       :     +- Union
   :                       :        :- *(8) Project [col_1#472,measure_1#474]
   :                       :        :  +- *(8) Scan ExistingRDD[col_1#472,measure_1#474,measure_2#475]
   :                       :        +- *(9) Project [col_1#472,measure_1#474]
   :                       :           +- *(9) Scan ExistingRDD[col_1#472,measure_2#475]
   :                       +- *(12) Sort [col_1#412 ASC NULLS FIRST],0
   :                          +- ReusedExchange [col_1#412],Exchange hashpartitioning(col_1#412,[id=#1500]
   +- *(30) Sort [col_1#496 ASC NULLS FIRST,col_2#497 ASC NULLS FIRST],0
      +- *(30) HashAggregate(keys=[col_1#496,functions=[max(measure_2#499),avg(cast(measure_2#499 as double))])
         +- *(30) HashAggregate(keys=[col_1#496,functions=[partial_max(measure_2#499),partial_avg(cast(measure_2#499 as double))])
            +- *(30) Project [col_1#496,col_2#497,measure_2#499]
               +- *(30) SortMergeJoin [col_1#496,Inner
                  :- *(22) Sort [col_1#496 ASC NULLS FIRST,0
                  :  +- Exchange hashpartitioning(col_1#496,[id=#1660]
                  :     +- *(21) Project [col_1#496,measure_2#499]
                  :        +- *(21) SortMergeJoin [col_1#496],Inner
                  :           :- *(18) Sort [col_1#496 ASC NULLS FIRST],0
                  :           :  +- Exchange hashpartitioning(col_1#496,[id=#1544]
                  :           :     +- Union
                  :           :        :- *(16) Project [col_1#496,measure_2#499]
                  :           :        :  +- *(16) Scan ExistingRDD[col_1#496,measure_1#498,measure_2#499]
                  :           :        +- *(17) Project [col_1#496,measure_2#499]
                  :           :           +- *(17) Scan ExistingRDD[col_1#496,measure_2#499]
                  :           +- *(20) Sort [col_1#412 ASC NULLS FIRST],0
                  :              +- ReusedExchange [col_1#412],[id=#1500]
                  +- *(29) Sort [col_1#472 ASC NULLS FIRST,0
                     +- Exchange hashpartitioning(col_1#472,[id=#1707]
                        +- *(28) HashAggregate(keys=[col_1#472,functions=[])
                           +- *(28) HashAggregate(keys=[col_1#472,functions=[])
                              +- *(28) Project [col_1#472,col_2#473]
                                 +- *(28) SortMergeJoin [col_1#472],Inner
                                    :- *(25) Sort [col_1#472 ASC NULLS FIRST],0
                                    :  +- Exchange hashpartitioning(col_1#472,[id=#1566]
                                    :     +- Union
                                    :        :- *(23) Project [col_1#472,col_2#473]
                                    :        :  +- *(23) Scan ExistingRDD[col_1#472,measure_2#475]
                                    :        +- *(24) Project [col_1#472,col_2#473]
                                    :           +- *(24) Scan ExistingRDD[col_1#472,measure_2#475]
                                    +- *(27) Sort [col_1#412 ASC NULLS FIRST],0
                                       +- ReusedExchange [col_1#412],[id=#1500]
"""

You can see in the query plan that the join + union is derived multiple times, and this is reflected in my job execution report, where I see stages with the same task counts running again and again.
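As an aside, the formatted explain mode (assuming Spark 3.0 or later) makes this repetition easier to count, since each physical operator is listed once with its details:

# 'formatted' output lists each operator once, which makes it easier to see
# how often the Union + Exchange subtree reappears in the plan
df.explain(mode="formatted")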

How can I stop this re-derivation from happening?

Solution

An inner loop of transformations that repeatedly joins against, and derives columns from, a base DataFrame benefits from PySpark's .cache() method. It explicitly instructs Spark to keep the derived DataFrame around instead of recomputing it, which means the initial union + join is computed once and the resulting DataFrame is reused in the subsequent transformations.

It is a one-line addition that will greatly benefit your execution.

from pyspark.sql import types as T, functions as F, SparkSession
spark = SparkSession.builder.getOrCreate()

schema = T.StructType([
  T.StructField("col_1", T.IntegerType(), False),
  T.StructField("col_2", T.IntegerType(), False),
  T.StructField("measure_1", T.FloatType(), False),
  T.StructField("measure_2", T.FloatType(), False),
])
data = [
  {"col_1": 1,"col_2": 2,"measure_1": 0.5,"measure_2": 1.5},{"col_1": 2,"col_2": 3,"measure_1": 2.5,"measure_2": 3.5}
]

df = spark.createDataFrame(data,schema)

right_schema = T.StructType([
  T.StructField("col_1",False)
])
right_data = [
  {"col_1": 1},{"col_1": 1},{"col_1": 2},{"col_1": 2}
]
right_df = spark.createDataFrame(right_data,right_schema)

df = df.unionByName(df)
df = df.join(right_df,on="col_1")

# ========= Added this line BEFORE the loop
df = df.cache()
# =========
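# Note: .cache() is lazy -- the cached data is only materialized the first time
# an action touches it (the df.show() below does this) and is reused afterwards;
# an optional df.count() here would populate the cache eagerly before the loop.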

group_cols = ["col_1","col_2"]
measure_cols = ["measure_1","measure_2"]
for col in measure_cols:
  stats = df.groupBy(group_cols).agg(
    F.max(col).alias("max_" + col),F.avg(col).alias("avg_" + col),)
  df = df.join(stats,group_cols)
df.show()

"""
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
|col_1|col_2|measure_1|measure_2|max_measure_1|avg_measure_1|max_measure_2|avg_measure_2|
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
|    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
|    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
|    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
|    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
|    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
|    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
|    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
|    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
"""

df.explain()
"""
>>> df.explain()
== Physical Plan ==
*(4) Project [col_1#1265,col_2#1266,measure_1#1267,measure_2#1268,max_measure_1#1312,avg_measure_1#1314,max_measure_2#1336,avg_measure_2#1338]
+- *(4) BroadcastHashJoin [col_1#1265,col_2#1266],[col_1#1343,col_2#1344],Inner,BuildRight,false
   :- *(4) Project [col_1#1265,avg_measure_1#1314]
   :  +- *(4) BroadcastHashJoin [col_1#1265,[col_1#1319,col_2#1320],BuildLeft,false
   :     :- BroadcastExchange HashedRelationBroadcastMode(List((shiftleft(cast(input[0,int,false] as bigint),32) | (cast(input[1,false] as bigint) & 4294967295))),false),[id=#2439]
   :     :  +- *(1) ColumnarToRow
   :     :     +- InMemoryTableScan [col_1#1265,measure_2#1268]
   :     :           +- InMemoryRelation [col_1#1265,measure_2#1268],StorageLevel(disk,memory,deserialized,1 replicas)
   :     :                 +- *(6) Project [col_1#1265,measure_2#1268]
   :     :                    +- *(6) SortMergeJoin [col_1#1265],[col_1#1273],Inner
   :     :                       :- *(3) Sort [col_1#1265 ASC NULLS FIRST],false,0
   :     :                       :  +- Exchange hashpartitioning(col_1#1265,200),ENSURE_REQUIREMENTS,[id=#2169]
   :     :                       :     +- Union
   :     :                       :        :- *(1) Scan ExistingRDD[col_1#1265,measure_2#1268]
   :     :                       :        +- *(2) Scan ExistingRDD[col_1#1265,measure_2#1268]
   :     :                       +- *(5) Sort [col_1#1273 ASC NULLS FIRST],0
   :     :                          +- Exchange hashpartitioning(col_1#1273,[id=#2175]
   :     :                             +- *(4) Scan ExistingRDD[col_1#1273]
   :     +- *(4) HashAggregate(keys=[col_1#1319,functions=[max(measure_1#1321),avg(cast(measure_1#1321 as double))])
   :        +- *(4) HashAggregate(keys=[col_1#1319,functions=[partial_max(measure_1#1321),partial_avg(cast(measure_1#1321 as double))])
   :           +- *(4) ColumnarToRow
   :              +- InMemoryTableScan [col_1#1319,col_2#1320,measure_1#1321]
   :                    +- InMemoryRelation [col_1#1319,measure_1#1321,measure_2#1322],1 replicas)
   :                          +- *(6) Project [col_1#1265,measure_2#1268]
   :                             +- *(6) SortMergeJoin [col_1#1265],Inner
   :                                :- *(3) Sort [col_1#1265 ASC NULLS FIRST],0
   :                                :  +- Exchange hashpartitioning(col_1#1265,[id=#2169]
   :                                :     +- Union
   :                                :        :- *(1) Scan ExistingRDD[col_1#1265,measure_2#1268]
   :                                :        +- *(2) Scan ExistingRDD[col_1#1265,measure_2#1268]
   :                                +- *(5) Sort [col_1#1273 ASC NULLS FIRST],0
   :                                   +- Exchange hashpartitioning(col_1#1273,[id=#2175]
   :                                      +- *(4) Scan ExistingRDD[col_1#1273]
   +- BroadcastExchange HashedRelationBroadcastMode(List((shiftleft(cast(input[0,[id=#2461]
      +- *(3) HashAggregate(keys=[col_1#1343,functions=[max(measure_2#1346),avg(cast(measure_2#1346 as double))])
         +- *(3) HashAggregate(keys=[col_1#1343,functions=[partial_max(measure_2#1346),partial_avg(cast(measure_2#1346 as double))])
            +- *(3) Project [col_1#1343,col_2#1344,measure_2#1346]
               +- *(3) BroadcastHashJoin [col_1#1343,false
                  :- *(3) ColumnarToRow
                  :  +- InMemoryTableScan [col_1#1343,measure_2#1346]
                  :        +- InMemoryRelation [col_1#1343,measure_1#1345,measure_2#1346],1 replicas)
                  :              +- *(6) Project [col_1#1265,measure_2#1268]
                  :                 +- *(6) SortMergeJoin [col_1#1265],Inner
                  :                    :- *(3) Sort [col_1#1265 ASC NULLS FIRST],0
                  :                    :  +- Exchange hashpartitioning(col_1#1265,[id=#2169]
                  :                    :     +- Union
                  :                    :        :- *(1) Scan ExistingRDD[col_1#1265,measure_2#1268]
                  :                    :        +- *(2) Scan ExistingRDD[col_1#1265,measure_2#1268]
                  :                    +- *(5) Sort [col_1#1273 ASC NULLS FIRST],0
                  :                       +- Exchange hashpartitioning(col_1#1273,[id=#2175]
                  :                          +- *(4) Scan ExistingRDD[col_1#1273]
                  +- BroadcastExchange HashedRelationBroadcastMode(List((shiftleft(cast(input[0,[id=#2454]
                     +- *(2) HashAggregate(keys=[col_1#1319,functions=[])
                        +- *(2) HashAggregate(keys=[col_1#1319,functions=[])
                           +- *(2) ColumnarToRow
                              +- InMemoryTableScan [col_1#1319,col_2#1320]
                                    +- InMemoryRelation [col_1#1319,1 replicas)
                                          +- *(6) Project [col_1#1265,measure_2#1268]
                                             +- *(6) SortMergeJoin [col_1#1265],Inner
                                                :- *(3) Sort [col_1#1265 ASC NULLS FIRST],0
                                                :  +- Exchange hashpartitioning(col_1#1265,[id=#2169]
                                                :     +- Union
                                                :        :- *(1) Scan ExistingRDD[col_1#1265,measure_2#1268]
                                                :        +- *(2) Scan ExistingRDD[col_1#1265,measure_2#1268]
                                                +- *(5) Sort [col_1#1273 ASC NULLS FIRST],0
                                                   +- Exchange hashpartitioning(col_1#1273,[id=#2175]
                                                      +- *(4) Scan ExistingRDD[col_1#1273]
"""

You can now see in the query plan that InMemoryTableScan over the cached InMemoryRelation takes the place of the repeated shuffles, and your job execution will reflect this as well.
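As a side note (a sketch, not part of the fix above): .persist() gives the same reuse behaviour as .cache() but lets you choose the storage level explicitly, which can help if the unioned and joined DataFrame is too large to hold comfortably in executor memory; the plan above shows the default level as StorageLevel(disk, memory, deserialized, 1 replicas).

from pyspark import StorageLevel

# Same reuse as .cache(), but with an explicitly chosen storage level;
# DISK_ONLY trades re-read time for executor memory.
df = df.persist(StorageLevel.DISK_ONLY)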

Note: .cache() does not change the logic of your query plan or truncate its lineage; it only changes how your data is created and reused.
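One small follow-up sketch (the cached_df name is hypothetical, not from the code above): because the loop rebinds df to the joined result, it can help to keep a separate handle on the cached DataFrame, so you can inspect it and release it once the results have been written out.

cached_df = df.cache()
df = cached_df

# ... run the groupBy/join loop and write out the results ...

print(cached_df.is_cached)     # True once a storage level has been set
print(cached_df.storageLevel)  # e.g. Disk Memory Deserialized 1x Replicated
cached_df.unpersist()          # frees the cached blocks when no longer needed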