问题描述
我的输入火花数据框是;
Client Feature1 Feature2
1 10 1
1 15 3
1 20 5
1 25 7
1 30 9
2 1 10
2 2 11
2 3 12
2 4 13
2 5 14
3 100 0
3 150 1
3 200 2
3 250 3
3 300 4
我想为每个客户端将 pyspark 数据帧转换为 3d numpy 矩阵。 我根据上面的数据分享了想要的输出;
[[[10,1],[15,3],[20,5],[25,7],[30,9]],[[1,10],[2,11],[3,12],[4,13],[5,14]],[[100,0],[150,[200,2],[250,[300,4]]]
你能帮我解决这个问题吗?
解决方法
您可以在将数据帧收集到 Python 并将结果转换为 Numpy 数组之前进行 collect_list
聚合:
import numpy as np
import pyspark.sql.functions as F
a = np.array([
i[1] for i in
df.groupBy('Client')
.agg(F.collect_list(F.array(*df.columns[1:])))
.orderBy('Client')
.collect()
])
print(a)
array([[[ 10,1],[ 15,3],[ 20,5],[ 25,7],[ 30,9]],[[ 1,10],[ 2,11],[ 3,12],[ 4,13],[ 5,14]],[[100,0],[150,[200,2],[250,[300,4]]])