Spark mapInPandas 中有多少个迭代器?

问题描述

我想了解“mapInPandas”在 Spark 中的工作原理。 Databricks 博客上引用的示例是:

from typing import Iterator
import pandas as pd

df = spark.createDataFrame([(1,21),(2,30)],("id","age"))

def pandas_filter(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for pdf in iterator:
        yield pdf[pdf.id == 1]

df.mapInPandas(pandas_filter,schema=df.schema).show()

问题是,迭代器中将有多少个“pdf”? 我猜也许他们会和分区的数量一样多 但是当我进一步测试代码时,它们似乎太多了(在具有 ~100 m 记录的不同数据集上)

那么有没有办法知道迭代次数是如何确定的? 有没有办法让它等于分区数?

解决方法

您可以在 documentation 中找到:

Spark 中的数据分区被转换为 Arrow 记录批次,这会暂时导致 JVM 中的高内存使用率。为了避免可能出现的内存不足异常,可以通过将 conf “spark.sql.execution.arrow.maxRecordsPerBatch” 设置为一个整数来调整 Arrow 记录批次的大小,该整数将确定最大数量每个批次的行。默认值为每批次 10,000 条记录。如果列数较大,则应相应调整该值。使用此限制,每个数据分区将分成 1 个或多个记录批次进行处理

因此,如果您有 10M 条记录,那么您将拥有大约 10,000 个迭代器