问题描述
为了满足“使用给定的外部库处理给定数据”的需求,我用 Spark(Scala)编写了一个 UDAF。它一直运行良好,直到遇到如下场景:
TestwindowFunc.scala
import org.apache.spark.sql.SparkSession
/**
 * Demo driver: registers the [[CustAvg]] UDAF and applies it as two window-function columns
 * (`a` and `b`) over `students_mark.csv`, then prints the result.
 *
 * Both columns use the same window spec (partition by STUDENT_ID order by ACT_MARK). Judging
 * by the observed trace, Spark evaluates them together in one pass over each partition, so
 * their initialize/update/evaluate calls interleave instead of running column after column.
 * Both columns are also backed by the single `CustAvg` instance registered below. Its debug
 * counters are therefore presumably shared between the two columns within a task.
 */
object TestwindowFunc {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("TestwindowFunc")
      .master("local[3]")
      .config("spark.driver.memory", "5g")
      .getOrCreate()

    try {
      spark.udf.register("custAvg", new CustAvg)

      // inferSchema reads ACT_MARK / OUT_OF as numbers, so `order by ACT_MARK` sorts
      // numerically rather than lexicographically as strings (where "100" < "90").
      val df = spark.read
        .option("delimiter", "|")
        .option("header", "true")
        .option("inferSchema", "true")
        .csv("./src/main/resources/students_mark.csv")
      df.createOrReplaceTempView("testwindowFunc")

      val df1 = spark.sql("select X.*" +
        ",custAvg(ACT_MARK,OUT_OF) over (partition by STUDENT_ID order by ACT_MARK) a" +
        ",custAvg(ACT_MARK,OUT_OF) over (partition by STUDENT_ID order by ACT_MARK) b" +
        " from testwindowFunc X")
      df1.show()
    } finally {
      // Release the local Spark context even if the query fails.
      spark.stop()
    }
  }
}
CustAvg.scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer,UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType,IntegerType,LongType,StructField,StructType}
/**
 * Demo UDAF over (act_mark, out_of), intended to be used as a window function, with
 * println tracing of its lifecycle calls.
 *
 * The aggregation buffer holds running totals of both inputs, and `evaluate` returns the
 * act_mark total as a Long. Real per-window processing is meant to be delegated to an
 * external library at the marked placeholder points.
 *
 * NOTE(review): `initializeCounter` and `updateCounter` are fields of this instance, not part
 * of the aggregation buffer. If one instance backs several window expressions, these counters
 * are shared by all of them. In that case `initialize` for any expression resets
 * `updateCounter` for every expression. An example is two SQL columns that both call a
 * function registered from a single `new CustAvg`. Treat the counters as a rough trace only;
 * per-window state must live in the buffer.
 */
class CustAvg extends UserDefinedAggregateFunction {
  var initializeCounter = 0
  var updateCounter = 0

  override def inputSchema: StructType = StructType(Array(
    StructField("act_mark", IntegerType),
    StructField("out_of", IntegerType)
  ))

  override def bufferSchema: StructType = StructType(Array(
    StructField("act_mark_tot", LongType),
    StructField("out_of_tot", LongType)
  ))

  override def dataType: DataType = LongType

  override def deterministic: Boolean = false

  /** Starts a new aggregation: zeroes the running totals and resets the debug update counter. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    initializeCounter += 1
    println("initialize:::" + initializeCounter)
    updateCounter = 0
    // Placeholder: initialize the external library for this window here.
    buffer(0) = 0L
    buffer(1) = 0L
  }

  /** Adds one input row to the running totals; null input values are skipped. */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    updateCounter += 1
    println("update:::" + updateCounter)
    // Placeholder: send this row to the external library here.
    // Row.getInt throws on null, so guard each column before reading it.
    if (!input.isNullAt(0)) buffer(0) = buffer.getLong(0) + input.getInt(0)
    if (!input.isNullAt(1)) buffer(1) = buffer.getLong(1) + input.getInt(1)
  }

  /**
   * Merging partial buffers is deliberately unsupported.
   * NOTE(review): presumably never invoked for window evaluation — confirm before reusing
   * this UDAF in a plain GROUP BY aggregation.
   *
   * @throws UnsupportedOperationException always
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    throw new UnsupportedOperationException("Merge Not Allowed")

  /** Returns the act_mark running total as a Long, matching `dataType`. */
  override def evaluate(buffer: Row): Any = {
    println("evaluate:::" + updateCounter)
    // Placeholder: call the external library to process the data here.
    buffer.getLong(0)
  }
}
students_mark.csv
STUDENT_ID|ACT_MARK|OUT_OF
1|70|100
1|68|100
1|90|100
预期输出
initialize:::1
update:::1
evaluate:::1
update:::2
evaluate:::2
update:::3
evaluate:::3
initialize:::2
update:::1
evaluate:::1
update:::2
evaluate:::2
update:::3
evaluate:::3
实际输出
initialize:::1
initialize:::2
update:::1
update:::2
evaluate:::2
evaluate:::2
update:::3
update:::4
evaluate:::4
evaluate:::4
update:::5
update:::6
evaluate:::6
evaluate:::6
这是 Spark 在这种情况下的正常行为,还是我哪里做错了?
能否有人给出恰当的解释,帮助我解决这个问题?
版本详情:
- Scala:2.11
- Spark:2.4.0
提前致谢。
解决方法
暂未找到可以解决该程序问题的有效方法,小编正在努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)