scala – How to compute a rolling 12-month sum of orders per customer in Spark, sliding by one month

I am new to Scala. I am currently trying to aggregate order data in Spark over a 12-month period that slides by one month.

Below is a simple example of my data, which I have tried to format so that you can test it easily:

import spark.implicits._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._


var sample = Seq(
  ("C1","01/01/2016",20),
  ("C1","02/01/2016",5),
  ("C1","03/01/2016",2),
  ("C1","04/01/2016",3),
  ("C1","05/01/2017",5),
  ("C1","08/01/2017",5),
  ("C1","01/02/2017",10),
  ("C1","01/02/2017",10),
  ("C1","01/03/2017",10)
).toDF("id","order_date","orders")

sample = sample.withColumn("order_date",to_date(unix_timestamp($"order_date","dd/MM/yyyy").cast("timestamp")))

sample.show
+---+----------+------+
| id|order_date|orders|
+---+----------+------+
| C1|2016-01-01|    20|
| C1|2016-01-02|     5|
| C1|2016-01-03|     2|
| C1|2016-01-04|     3|
| C1|2017-01-05|     5|
| C1|2017-01-08|     5|
| C1|2017-02-01|    10|
| C1|2017-02-01|    10|
| C1|2017-03-01|    10|
+---+----------+------+

The result I am expected to produce is as follows (note that both period boundaries are inclusive, so each period covers 13 monthly buckets):

id      period_start    period_end  rolling
C1      2015-01-01      2016-01-01  30
C1      2016-01-01      2017-01-01  40
C1      2016-02-01      2017-02-01  30
C1      2016-03-01      2017-03-01  40

What I have tried so far

I collapsed each customer's order dates to the first day of that month

(e.g. 2016-01-[1..31] >> 2016-01-01)

import org.joda.time._

val collapse_month = (month: Integer, year: Integer) => {
  val dt = new DateTime().withYear(year)
                         .withMonthOfYear(month)
                         .withDayOfMonth(1)
  dt.toString("yyyy-MM-dd")
}

val collapse_month_udf = udf(collapse_month)


sample = sample.withColumn("period_end",collapse_month_udf(
           month(col("order_date")),year(col("order_date"))
           ).as("date"))

sample.groupBy($"id", $"period_end")
  .agg(sum($"orders").as("orders"))
  .orderBy("period_end")
  .show
+---+----------+------+
| id|period_end|orders|
+---+----------+------+
| C1|2016-01-01|    30|
| C1|2017-01-01|    10|
| C1|2017-02-01|    20|
| C1|2017-03-01|    10|
+---+----------+------+

I tried the window functions Spark provides, but I was not able to get a 12-month window that slides by one month.

I am really not sure what the best approach is from this point, one that will not take 5 hours given the amount of data I have to process.

Any help would be appreciated.

Solution

"I tried the window functions Spark provides, but I was not able to get a 12-month window that slides by one month."

You can still use window with longer intervals, but all arguments have to be expressed in days or weeks:

window($"order_date", "365 days", "28 days")

Unfortunately a window like this does not respect month or year boundaries, so it will not be useful for you.
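For illustration only, here is roughly what that aggregation would look like; since the 365-day buckets are anchored to a fixed origin rather than the calendar, their boundaries drift away from month starts (a sketch, assuming the sample DataFrame above):

// Sketch: fixed-duration buckets of 365 days, sliding every 28 days.
// window() aligns buckets to an epoch-based origin, so bucket start/end
// drift relative to calendar month boundaries.
sample
  .groupBy($"id", window($"order_date", "365 days", "28 days"))
  .agg(sum($"orders").as("orders"))
  .select($"id", $"window.start", $"window.end", $"orders")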

Personally, I would aggregate the data first:

val byMonth = sample
  .groupBy($"id",trunc($"order_date","month").alias("order_month"))
  .agg(sum($"orders").alias("orders"))
+---+-----------+------+
| id|order_month|orders|
+---+-----------+------+
| C1| 2017-01-01|    10|
| C1| 2016-01-01|    30|
| C1| 2017-02-01|    20|
| C1| 2017-03-01|    10|
+---+-----------+------+

Create a reference date range:

import java.time.temporal.ChronoUnit
import org.apache.spark.sql.Row

val Row(start: java.sql.Date, end: java.sql.Date) = byMonth
  .select(min($"order_month"), max($"order_month"))
  .first

val months = (0L to ChronoUnit.MONTHS.between(
    start.toLocalDate, end.toLocalDate))
  .map(i => java.sql.Date.valueOf(start.toLocalDate.plusMonths(i)))
  .toDF("order_month")

Combine it with the distinct ids:

val ref = byMonth.select($"id").distinct.crossJoin(months)

And join it back to the source:

val expanded = ref.join(byMonth, Seq("id", "order_month"), "leftouter")
+---+-----------+------+ 
| id|order_month|orders|
+---+-----------+------+
| C1| 2016-01-01|    30|
| C1| 2016-02-01|  null|
| C1| 2016-03-01|  null|
| C1| 2016-04-01|  null|
| C1| 2016-05-01|  null|
| C1| 2016-06-01|  null|
| C1| 2016-07-01|  null|
| C1| 2016-08-01|  null|
| C1| 2016-09-01|  null|
| C1| 2016-10-01|  null|
| C1| 2016-11-01|  null|
| C1| 2016-12-01|  null|
| C1| 2017-01-01|    10|
| C1| 2017-02-01|    20|
| C1| 2017-03-01|    10|
+---+-----------+------+

With the data prepared like this, you can use window functions:

import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy($"id")
  .orderBy($"order_month")
  // 12 preceding rows plus the current row span 13 months, which matches
  // the inclusive period boundaries in the expected output
  .rowsBetween(-12, Window.currentRow)

expanded.withColumn("rolling", sum("orders").over(w))
  .na.drop(Seq("orders"))
  .select(
    $"order_month" - expr("INTERVAL 12 MONTHS") as "period_start",
    $"order_month" as "period_end",
    $"rolling")
+------------+----------+-------+
|period_start|period_end|rolling|
+------------+----------+-------+
|  2015-01-01|2016-01-01|     30|
|  2016-01-01|2017-01-01|     40|
|  2016-02-01|2017-02-01|     30|
|  2016-03-01|2017-03-01|     40|
+------------+----------+-------+

Please note that this is a very expensive operation, requiring at least two shuffles:

== Physical Plan ==
*Project [cast(cast(order_month#104 as timestamp) - interval 1 years as date) AS period_start#1387, order_month#104 AS period_end#1388, rolling#1375L]
+- *Filter AtLeastNNulls(n, orders#55L)
   +- Window [sum(orders#55L) windowspecdefinition(id#7, order_month#104 ASC NULLS FIRST, ROWS BETWEEN 12 PRECEDING AND CURRENT ROW) AS rolling#1375L], [id#7], [order_month#104 ASC NULLS FIRST]
      +- *Sort [id#7 ASC NULLS FIRST, order_month#104 ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(id#7, 200)
            +- *Project [id#7, order_month#104, orders#55L]
               +- *BroadcastHashJoin [id#7, order_month#104], [id#181, order_month#49], LeftOuter, BuildRight
                  :- BroadcastNestedLoopJoin BuildRight, Cross
                  :  :- *HashAggregate(keys=[id#7], functions=[])
                  :  :  +- Exchange hashpartitioning(id#7, 200)
                  :  :     +- *HashAggregate(keys=[id#7], functions=[])
                  :  :        +- *HashAggregate(keys=[id#7, trunc(order_date#14, month)#1394], functions=[])
                  :  :           +- Exchange hashpartitioning(id#7, trunc(order_date#14, month)#1394, 200)
                  :  :              +- *HashAggregate(keys=[id#7, trunc(order_date#14, month) AS trunc(order_date#14, month)#1394], functions=[])
                  :  :                 +- LocalTableScan [id#7, order_date#14]
                  :  +- BroadcastExchange IdentityBroadcastMode
                  :     +- LocalTableScan [order_month#104]
                  +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true], input[1, date, true]))
                     +- *HashAggregate(keys=[id#181, trunc(order_date#14, month)#1395], functions=[sum(cast(orders#183 as bigint))])
                        +- Exchange hashpartitioning(id#181, trunc(order_date#14, month)#1395, 200)
                           +- *HashAggregate(keys=[id#181, trunc(order_date#14, month) AS trunc(order_date#14, month)#1395], functions=[partial_sum(cast(orders#183 as bigint))])
                              +- LocalTableScan [id#181, order_date#14, orders#183]

It is also possible to express this with a rangeBetween frame, but you have to encode the data first:

val encoded = byMonth
  .withColumn("order_month_offset",
    // Choose a "zero" date appropriate for your scenario
    months_between($"order_month", to_date(lit("1970-01-01"))))


val w = Window.partitionBy($"id")
  .orderBy($"order_month_offset")
  .rangeBetween(-12, Window.currentRow)

encoded.withColumn("rolling", sum($"orders").over(w))
+---+-----------+------+------------------+-------+                             
| id|order_month|orders|order_month_offset|rolling|
+---+-----------+------+------------------+-------+
| C1| 2016-01-01|    30|             552.0|     30|
| C1| 2017-01-01|    10|             564.0|     40|
| C1| 2017-02-01|    20|             565.0|     30|
| C1| 2017-03-01|    10|             566.0|     40|
+---+-----------+------+------------------+-------+

This makes the join with the reference dates obsolete and simplifies the execution plan.
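If you need the same period_start / period_end presentation as before, it can be derived from the encoded variant exactly as in the rowsBetween version (a sketch reusing the select from above):

encoded.withColumn("rolling", sum($"orders").over(w))
  .select(
    $"order_month" - expr("INTERVAL 12 MONTHS") as "period_start",
    $"order_month" as "period_end",
    $"rolling")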
