根据行值将 UDF 应用于 pyspark 数据帧

问题描述

我有一个具有以下架构的 pyspark 数据框

+-----------+---------+----------+-----------+
|     userID|grouping1| grouping2|   features|
+-----------+---------+----------+-----------+
|12462563356|        1|        A | [5.0,43.0]|
|12462563701|        2|        A |  [1.0,8.0]|
|12462563701|        1|        B | [2.0,12.0]|
|12462564356|        1|        C |  [1.0,1.0]|
|12462565487|        3|        C |  [2.0,3.0]|
|12462565698|        2|        D |  [1.0,1.0]|
|12462565698|        1|        E |  [1.0,1.0]|
|12462566081|        2|        C |  [1.0,2.0]|
|12462566081|        1|        D | [1.0,15.0]|
|12462566225|        2|        E |  [1.0,1.0]|
|12462566225|        1|        A | [9.0,85.0]|
|12462566526|        2|        C |  [1.0,1.0]|
|12462566526|        1|        D | [3.0,79.0]|
|12462567006|        2|        D |[11.0,15.0]|
|12462567006|        1|        B |[10.0,15.0]|
|12462567006|        3|        A |[10.0,15.0]|
|12462586595|        2|        B | [2.0,42.0]|
|12462586595|        3|        D | [2.0,16.0]|
|12462589343|        3|        E |  [1.0,1.0]|
+-----------+---------+----------+-----------+

对于 grouping2 ABCD 中的值,我需要应用 UDF_AUDF_BUDF_CUDF_D 分别。有没有办法可以写一些类似

的东西

dataset = dataset.withColumn('outputColName',selectUDF(**params))

其中 selectUDF 定义为

def selectUDF(**params):
    if row[grouping2] == A:
        return UDF_A(**params)
    elif row[grouping2] == B:
        return UDF_B(**params)
    elif row[grouping2] == C:
        return UDF_C(**params)
    elif row[grouping2] == D:
        return UDF_D(**params)

使用以下示例来说明我正在尝试做的事情 是的,我也是这么想的。我正在使用以下玩具代码来检查这个

>>> df = sc.parallelize([[1,2,3],[2,3,4]]).toDF(("a","b","c"))
>>> df.show()
+---+---+---+
|  a|  b|  c|
+---+---+---+
|  1|  2|  3|
|  2|  3|  4|
+---+---+---+
>>> def udf1(col):
...     return col1*col1
...
>>> def udf2(col):
...     return col2*col2*col2
...
>>> def select_udf(col1,col2):
...     if col1 == 2:
...         return udf1(col2)
...     elif col1 == 3:
...         return udf2(col2)
...     else:
...         return 0
...
>>> from pyspark.sql.functions import col
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType
>>> select_udf = udf(select_udf,IntegerType())
>>> udf1 = udf(udf1,IntegerType())
>>> udf2 = udf(udf2,IntegerType())
>>> df.withColumn("outCol",select_udf(col("b"),col("c"))).show()
[Stage 9:============================================>              (3 + 1) / 4]

这似乎永远停留在这个阶段。任何人都可以建议这里可能有什么问题吗?

解决方法

您不需要 selectUDF,只需使用 when 表达式根据 grouping2 列的值应用所需的 udf:

from pyspark.sql.functions import col,when

df = df.withColumn(
    "outCol",when(col("grouping2") == "A",UDF_A(*params))
        .when(col("grouping2") == "B",UDF_B(*params))
        .when(col("grouping2") == "C",UDF_C(*params))
        .when(col("grouping2") == "D",UDF_D(*params))
)