将 Spark 2.4 的 UDAF 转换为 Spark 3.0 的 Aggregator

问题描述

我使用 Spark 2.4 在 Scala 中编写了一个 UDAF。由于我们 Databricks 集群所用的 6.4 运行时已不再受支持,我们需要迁移到 7.3 LTS——它提供长期支持并基于 Spark 3。UDAF(`UserDefinedAggregateFunction`)在 Spark 3 中已被弃用,将来很可能会被移除,所以我正尝试将这个 UDAF 转换为 Aggregator 函数。

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer,UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{IntegerType,StringType,StructField,StructType,DataType}

/**
 * Aggregates (id, name) rows and keeps the pair with the largest non-null `id`.
 *
 * Input, buffer and result all share the same two-field schema:
 * `id: Int` (nullable) and `name: String` (nullable). The result is a struct,
 * e.g. `{"id": 1282, "name": "McCormick Christmas"}`.
 *
 * Ties: a candidate replaces the current value only when its id is strictly
 * greater, so among equal ids whichever row reached the buffer first is kept
 * (which row that is depends on Spark's partition/merge order).
 *
 * Note: `UserDefinedAggregateFunction` is deprecated as of Spark 3.0; the
 * recommended replacement is an `Aggregator` registered via `functions.udaf`.
 */
object MaxCampaignIdAggregator extends UserDefinedAggregateFunction with java.io.Serializable {

  // Single schema reused for input, buffer and output: both fields nullable.
  // The buffer must have two fields because initialize/update/merge write
  // slots 0 (id) and 1 (name).
  private val idNameSchema: StructType = new StructType()
    .add("id", IntegerType, true)
    .add("name", StringType, true)

  override def inputSchema: StructType = idNameSchema

  override def bufferSchema: StructType = idNameSchema

  // Returned data type: a struct of (id, name), matching `evaluate`.
  override def dataType: DataType = idNameSchema

  override def deterministic: Boolean = true

  /** Resets the buffer for a new group: no id and no name seen yet. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = null
    buffer(1) = null
  }

  /** Folds one input row of the group into the buffer. */
  override def update(buffer: MutableAggregationBuffer, inputRow: Row): Unit =
    keepMax(buffer, inputRow)

  /** Merges a partial aggregate (`buffer2`) into `buffer1`. */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    keepMax(buffer1, buffer2)

  /**
   * Shared step for `update` and `merge`: stores `candidate`'s (id, name) into
   * `buffer` when the buffer has no id yet, or when the candidate has a non-null
   * id strictly greater than the buffered id. A candidate with a null id never
   * displaces a non-null buffered id.
   */
  private def keepMax(buffer: MutableAggregationBuffer, candidate: Row): Unit =
    if (buffer.isNullAt(0)) {
      // Nothing buffered yet: adopt the candidate as-is (its id may be null too).
      buffer(0) = candidate.get(0)
      buffer(1) = candidate.getString(1)
    } else if (!candidate.isNullAt(0) && candidate.getInt(0) > buffer.getInt(0)) {
      buffer(0) = candidate.getInt(0)
      buffer(1) = candidate.getString(1)
    }

  /** Called once all rows of a group are consumed; emits the (id, name) struct. */
  override def evaluate(buffer: Row): Any =
    Row(buffer.get(0), buffer.getString(1))
}

使用后输出如下:

{"id": 1282,"name": "McCormick Christmas"}

{"id": 1305,"name": "McCormick Perfect Pinch"}

{"id": 1677,"name": "Viking Cruises Viking Cruises"}

解决方法

暂未找到可以解决该问题的有效方法,小编正在努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)