PySpark RDD:列数不匹配

问题描述

我想使用pyspark建立一个数据框,其中一个列是数据集的另外两个列的SipHash的结果。为此,我创建了一个rdd.map()函数调用函数,如下所示:

import siphash
from pyspark.sql import Row
from pyspark.sql import sqlContext
from pyspark.sql.types import *

sqlContext = sqlContext( spark )

# Hashing function
def hash_two_columns( row ):
    # Transform row to a dict
    row_dict = row.asDict()
    # Concat col1 and col2
    concat_str = 'E'.join( [str(row_dict['col1']),str(row_dict['col2'])] )
    # Fill string with 0 to get 16 bytes (otherwise error is raised)
    sixteenBytes_str = concat_str.zfill(16)
    # Preserve concatenated value for testing (this can be removed later)
    row_dict["hashcols_str"] = sixteenBytes_str
    # Calculate siphash
    row_dict["hashcols_id"] = siphash.SipHash_2_4( sixteenBytes_str.encode('utf-8') ).hash()
    return Row( **row_dict )

# Create test dataframe
test_df = spark.createDataFrame([
         (1,"text1",58965,11111),(3,"text2",78652,888888),(4,"text3",],("id","item","col1","col2"))

# Build the schema 
# Using this to avoid "ValueError: Some of types cannot be determined by the first 100 rows" when pyspark tries to deduct schema by itself
test_df_schema = StructType([
    StructField("id",IntegerType(),True),StructField("item",StringType(),StructField("col1",StructField("col2",StructField("hashcols_str",StructField("hashcols_id",LongType(),True)
])

# Create the final Dataframe
final_test_df = sqlContext \
     .createDataFrame(
          test_df.rdd.map(hash_two_columns).collect(),test_df_schema) \
     .toDF()

final_test_df.show(truncate=False)

尽管架构定义与最终的数据框结构匹配,但是运行此代码失败,并出现以下错误

IllegalArgumentException:要求失败:列数不匹配。旧列名称(6):id,item,col1,col2,hashcols_str,hashcols_id新列名称(0):(java.lang.RuntimeException)

有人对如何正确实施此方法有任何想法吗?非常感谢您的支持

解决方法

我找到了基于this post的解决方案:

以这种方式更新功能:

def hash_two_columns( col1,col2 ):
    # Concat col1 and col2
    concat_str = 'E'.join( [col1,col2] )
    # Fill string with 0 to get 16 bytes (otherwise error is raised)
    sixteenBytes_str = concat_str.zfill(16)
    # Calculate siphash
    hashcols_id = siphash.SipHash_2_4( sixteenBytes_str.encode('utf-8') ).hash()
    return hashcols_id

然后,使用withColumn功能,使用UDF(用户定义函数)将新列添加到数据框中。

from pyspark.sql.functions import udf

example_udf = udf( hash_two_columns,LongType() )

test_df = test_df \
    .withColumn( "hashcols_id",example_udf( test_df.col1,test_df.col2 ) )

test_df.show()