问题描述
+-----------+---------+----------+-----------+
| userID|grouping1| grouping2| features|
+-----------+---------+----------+-----------+
|12462563356| 1| A | [5.0,43.0]|
|12462563701| 2| A | [1.0,8.0]|
|12462563701| 1| B | [2.0,12.0]|
|12462564356| 1| C | [1.0,1.0]|
|12462565487| 3| C | [2.0,3.0]|
|12462565698| 2| D | [1.0,1.0]|
|12462565698| 1| E | [1.0,1.0]|
|12462566081| 2| C | [1.0,2.0]|
|12462566081| 1| D | [1.0,15.0]|
|12462566225| 2| E | [1.0,1.0]|
|12462566225| 1| A | [9.0,85.0]|
|12462566526| 2| C | [1.0,1.0]|
|12462566526| 1| D | [3.0,79.0]|
|12462567006| 2| D |[11.0,15.0]|
|12462567006| 1| B |[10.0,15.0]|
|12462567006| 3| A |[10.0,15.0]|
|12462586595| 2| B | [2.0,42.0]|
|12462586595| 3| D | [2.0,16.0]|
|12462589343| 3| E | [1.0,1.0]|
+-----------+---------+----------+-----------+
对于 grouping2
A
、B
、C
和 D
中的值,我需要应用 UDF_A
、UDF_B
, UDF_C
和 UDF_D
分别。有没有办法可以写一些类似
dataset = dataset.withColumn('outputColName',selectUDF(**params))
其中 selectUDF 定义为
def selectUDF(**params):
if row[grouping2] == A:
return UDF_A(**params)
elif row[grouping2] == B:
return UDF_B(**params)
elif row[grouping2] == C:
return UDF_C(**params)
elif row[grouping2] == D:
return UDF_D(**params)
使用以下示例来说明我正在尝试做的事情 是的,我也是这么想的。我正在使用以下玩具代码来检查这个
>>> df = sc.parallelize([[1,2,3],[2,3,4]]).toDF(("a","b","c"))
>>> df.show()
+---+---+---+
| a| b| c|
+---+---+---+
| 1| 2| 3|
| 2| 3| 4|
+---+---+---+
>>> def udf1(col):
... return col1*col1
...
>>> def udf2(col):
... return col2*col2*col2
...
>>> def select_udf(col1,col2):
... if col1 == 2:
... return udf1(col2)
... elif col1 == 3:
... return udf2(col2)
... else:
... return 0
...
>>> from pyspark.sql.functions import col
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType
>>> select_udf = udf(select_udf,IntegerType())
>>> udf1 = udf(udf1,IntegerType())
>>> udf2 = udf(udf2,IntegerType())
>>> df.withColumn("outCol",select_udf(col("b"),col("c"))).show()
[Stage 9:============================================> (3 + 1) / 4]
这似乎永远停留在这个阶段。任何人都可以建议这里可能有什么问题吗?
解决方法
您不需要 selectUDF
,只需使用 when
表达式根据 grouping2
列的值应用所需的 udf:
from pyspark.sql.functions import col,when
df = df.withColumn(
"outCol",when(col("grouping2") == "A",UDF_A(*params))
.when(col("grouping2") == "B",UDF_B(*params))
.when(col("grouping2") == "C",UDF_C(*params))
.when(col("grouping2") == "D",UDF_D(*params))
)