问题描述
我正在尝试实现一个示例:
https://portal.klewel.com/watch/webcast/scala-days-2019/talk/37/
使用Scala延续:
object ReverseGrad_cpsImproved {
import scala.util.continuations._
case class Num(
x: Double,var d: Double = 0.0
) {
def +(that: Num) = shift { (cont: Num => Unit) =>
val y = Num(x + that.x)
cont(y)
this.d += y.d
that.d += y.d
}
def *(that: Num) = shift { (cont: Num => Unit) =>
val y = Num(x * that.x)
cont(y)
this.d += that.x * y.d
that.d += this.x * y.d
}
}
object Num {
implicit def fromX(x: Double): Num = Num(x)
}
def grad(f: Num => Num @cps[Unit])(x: Double): Double = {
val _x = Num(x)
reset { f(_x).d = 1.0 }
_x.d
}
}
只要我使用的是简单表达式,它就可以工作:
it("simple") {
val fn = { x: Num =>
val result = (x + 3) * (x + 4)
result
}
val gg = grad(fn)(3)
println(gg)
}
但是一旦我开始使用循环,一切都会崩溃:
it("benchmark") {
import scala.util.continuations._
for (i <- 1 to 20) {
val n = Math.pow(2,i).toInt
val fn = { x: Num =>
var result = x + 1
for (j <- 2 to n) {
result = result * (x + j)
}
result
}
val nanoFrom = System.nanoTime()
val gg = grad(fn)(3)
val nanoTo = System.nanoTime()
println(s"diff = $gg,\t time = ${nanoTo - nanoFrom}")
}
}
[Error] /home/peng/git-spike/scalaspike/Meta/src/test/scala/com/tribbloids/spike/Meta/multistage/lms/ReverseGrad_cpsImproved.scala:78: found cps expression in non-cps position
one error found
我的印象是,延续库应该有自己的循环实现,可以重写为递归,但是在最新版本(scala 2.12)中,我找不到它。在这种情况下,使用循环最简单的方法是什么?
解决方法
在CPS中,您必须重写代码,以免在同一上下文中执行嵌套/迭代/递归调用,而仅执行计算的一个步骤并将部分结果向前传递。
例如如果您想计算数字A到B的乘积,则可以通过以下方式实现:
import scala.util.continuations._
case class Num(toDouble: Double) {
def get = shift { cont: (Num => Num) =>
cont(this)
}
def +(num: Num) = reset {
val a = num.get
Num(toDouble + a.toDouble)
}
def *(num: Num) = reset {
val a = num.get
Num(toDouble * a.toDouble)
}
}
// type annotation required because of recursive call
def product(from: Int,to: Int): Num @cps[Num] = reset {
if (from > to) Num(1.toDouble)
else Num(from.toDouble) * product(from + 1,to)
}
def run: Num = reset {
product(2,10)
}
println(run)
(请参阅此scastie)。
最有趣的是这个片段:
reset {
if (from > to) Num(1.toDouble)
else Num(from.toDouble) * product(from + 1,to)
}
在这里,编译器(插件)将其重写为类似于:
input: (Num => Num) => {
if (from > to) Num(1.toDouble)
else {
Num(from.toDouble) * product(from + 1,to) // this is virtually (Num => Num) => Num function!
} (input)
}
编译器之所以可以这样做,是因为:
- 它观察
shift
和reset
呼叫的内容- 都创建一些带有参数
A
并返回中间结果B
(例如在此reset
中使用)和最终结果C
(得到的结果)当您运行合成的最终结果时(表示为A @ cpsParam[B,C]
-如果B =:= C
可以使用类型别名A @ cps[A]
) -
reset
使处理参数({{1}中的A
)并将其传递给所有嵌套的CPS调用并获得中间结果变得更加轻松,而不会传递参数(因此A @ cpsParam[B,C]
中的B
)并使整个块返回最终结果-A @ cpsParam[B,C]
C
-
A @ cpsParam[B,C]
将功能shift
提升到(A => B) => C
- 都创建一些带有参数
- 当看到返回类型为
A @ cpsParam[B,C]
时,它知道应该重写代码以引入参数并将其传递给该参数。
实际上,它的底层要复杂得多,但基本上就是这样。
与此同时
Input @cpsParam[Output1,Output2]
在此上下文之外,编译器不执行任何转换。您至少必须在 for (j <- 2 to n) {
result = result * (x + j)
}
中编写所有CPS操作。 (此外,您可以循环运行并进行变异,也可以委派给CPS。)
表示CPS(例如:此特定实现)已死。它已在Scala 2.13中删除,没有人支持它,使用一些基于蹦床的monad(例如Cats的reset
)更容易理解,因此,我仍然看到的唯一地方是过时的课程或有关历史琐事的文章。