利用PySpark,确定一个数组列中的多少个元素包含在另一列中的数组中

问题描述

我的数据集如下:

+--------------------+--------------------+
|                col1|                col2|
+--------------------+--------------------+
|[[563],[242,178]] |          [563,178]|
|[[563],178]] |     [563,178,242]|
|[[563],242,178]] |     [242,563]|
+--------------------+--------------------+

我想做的是确定 col1 中包含 col2 中的按顺序的值。 col1 中的顺序仅在顶级阵列上起作用,而在较低级别的阵列上无关紧要。

例如,上述数据框的输出应为:

+--------------------+--------------------|------+
|                col1|                col2|Output+
+--------------------+--------------------+------+
|[[563],178]|     2+
|[[563],242]|     3+
|[[563],178]|     3+
|[[563],563]|     2+
+--------------------+--------------------+------+

我相当确定这需要UDF,但是我在如何遍历col1中的子数组方面苦苦挣扎。

任何帮助将不胜感激!

喷枪

解决方法

spark-2。4 中使用 array_intersect 函数和flatten函数,然后使用{ {1}}功能。

size

Example: