How to check whether schema.field.dataType is an array of strings in Spark using Scala

Problem Description

Here is the source code:

override def createWriter(writeUUID: String, schema: StructType, saveMode: SaveMode, options: DataSourceOptions): Optional[DataSourceWriter] = {
  saveMode match {
    case SaveMode.Append =>
      schema.fields.foreach { field =>
        if (field.dataType.typeName == "array") {
          throw ArrayDataTypeNotSupportedException(s"${field.name} column is ArrayType, " +
            "writing arrays to CSV isn't supported. Please convert this column to a different data type.")
        }
      }

      val sparkSession = SparkSession.active
      val hadoopConf = new SerializableConfiguration(sparkSession.sparkContext.hadoopConfiguration)

      val optionsMap = options.asMap()
      val csvOptionsMap = optionsMap.asScala.toMap // convert Java HashMap to Scala Map
      // needed for the univocityGenerator
      val csvOptions = new CSVOptions(
        csvOptionsMap,
        columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
        sparkSession.sessionState.conf.sessionLocalTimeZone)

      Optional.of(new KinesisCSVDataSourceWriter(writeUUID, KinesisCSVDataSourceOptions(csvOptionsMap, schema), csvOptions, hadoopConf))

    case _ => throw UnsupportedSaveModeException("Only SaveMode.Append is supported")
  }
}

The test case is:

test("testArrayInSchema") {
    val df = spark.createDataFrame(Seq(
      TestDataSetArrays(
        Array(1,2,3),Array("a","b","c"),Array(new Timestamp(0),new Timestamp(1),new Timestamp(3))
      )
    ))

    assertThrows[ArrayDataTypeNotSupportedException] {
      writeDataFrame(df)
    }
  }

Please help me check whether schema.field.dataType is an array of strings, rather than just any array.

Arrays weren't supported before, but now I want to support string arrays only, and a string array should be converted to a comma-separated string.

Solution

Try this:

import org.apache.spark.sql.types.{ArrayType, StringType}

val df = spark.sql("select array('a','b') as arr")
df.printSchema()
/**
  * root
  * |-- arr: array (nullable = false)
  * |    |-- element: string (containsNull = false)
  */

val arr = df.schema("arr")
println(arr.dataType.isInstanceOf[ArrayType]
  && arr.dataType.asInstanceOf[ArrayType].elementType == StringType)

/**
  * true
  */

If you want to check all the fields, you can also use a match expression:

df.schema.fields.foreach(f => f.dataType match {
  case arrayType: ArrayType if arrayType.elementType == StringType =>
    println(s"field $f is of type array<String>")
  case _ => println(s"field $f is of type ${f.dataType}")
})

/**
  * field StructField(arr,ArrayType(StringType,false),false) is of type array<String>
  */

Update based on the comments

Convert array<string> to a comma-separated string:

import org.apache.spark.sql.functions.{col, concat_ws}

val cols = df.schema.map(f => f.dataType match {
  case arrayType: ArrayType if arrayType.elementType == StringType =>
    // convert array<string> to string
    concat_ws(",", col(f.name)).as(f.name)
  case _ => col(f.name)
})
df.select(cols: _*)
  .show(false)
/**
  * +---+
  * |arr|
  * +---+
  * |a,b|
  * +---+
  */