我尝试了各种不同的方法,从直接的递归解决方案开始,尝试使用流而不是列表的类似方法,并最终尝试通过使用索引进行更多操作来减少列表操作.我最终在更大的测试中有堆栈溢出异常,我可以使用Scala的TailCall修复.但是,虽然解决方案正确地解决了问题,但在1000毫秒内完成的速度太慢.除此之外,还有一个C实现显示比较快(<50ms).现在我明白在很多情况下Scala会比C慢,而且我也明白我可以在Scala中编写一个更具势在线的解决方案,它可能会表现得更好.不过,我想知道我是否遗漏了一些更基本的东西,因为我很难相信函数式编程总体上要慢得多(而且我对函数式编程很新). 这是我可以在REPL中粘贴的scala代码,包括> 1000ms的示例:
import scala.util.control.TailCalls._ def solve(l: List[(Int,Int)]): Int = { def go(from: Int,to: Int,prevHeight: Int): TailRec[Int] = { val max = to - from val currHeight = l.slice(from,to).minBy(_._1)._1 val hStrokes = currHeight - prevHeight val splits = l.slice(from,to).filter(_._1 - currHeight == 0).map(_._2) val indices = from :: splits.flatMap(x => List(x,x+1)) ::: List(to) val subLists = indices.grouped(2).filter(xs => xs.last - xs.head > 0) val trampolines = subLists.map(xs => tailcall(go(xs.head,xs.last,currHeight))) val sumTrampolines = trampolines.foldLeft(done(hStrokes))((b,a) => b.flatMap(bVal => a.map(aVal => aVal + bVal))) sumTrampolines.flatMap(v => done(max).map(m => Math.min(m,v))) } go(0,l.size,0).result } val lst = (1 to 5000).toList.zipWithIndex val res = solve(lst)
为了比较,这是一个实现Bugman编写的相同内容的C示例(包括我在上面的Scala版本中未包含的控制台的一些读/写):
#include <iostream> #include <cstdio> #include <cstdlib> #include <algorithm> #include <vector> #include <string> #include <set> #include <map> #include <cmath> #include <memory.h> using namespace std; typedef long long ll; const int N = 1e6+6; const int T = 1e6+6; int a[N]; int t[T],d; int rmq(int i,int j){ int r = i; for(i+=d,j+=d; i<=j; ++i>>=1,--j>>=1){ if(i&1) r=a[r]>a[t[i]]?t[i]:r; if(~j&1) r=a[r]>a[t[j]]?t[j]:r; } return r; } int calc(int l,int r,int h){ if(l>r) return 0; int m = rmq(l,r); int mn = a[m]; int res = min(r-l+1,calc(l,m-1,mn)+calc(m+1,r,mn)+mn-h); return res; } int main(){ //freopen("input.txt","r",stdin);// freopen("output.txt","w",stdout); int n,m; scanf("%d",&n); for(int i=0;i<n;++i) scanf("%d",&a[i]); a[n] = 2e9; for(d=1;d<n;d<<=1); for(int i=0;i<n;++i) t[i+d]=i; for(int i=n+d;i<d+d;++i) t[i]=n; for(int i=d-1;i;--i) t[i]=a[t[i*2]]<a[t[i*2+1]]?t[i*2]:t[i*2+1]; printf("%d\n",calc(0,n-1,0)); return 0; }
至少在我介绍显式尾部调用之前,对于我来说,解决问题比使用更强制性的解决方案更自然.因此,我非常乐意在编写功能代码时更多地了解我应该注意什么,以便仍能获得可接受的性能.
解决方法
这是一个无索引的实现:
import scala.util.control.TailCalls._ def solve(xs: Vector[Int]): Int = { def go(xs: Vector[Int],previous: Int): TailRec[Int] = { val min = xs.min splitOn(xs,min).foldLeft(done(min - previous)) { case (acc,part) => for { total <- acc cost <- go(part,min) } yield total + cost }.map(math.min(xs.size,_)) } go(xs,0).result }
尽管如此,这并不是完整的故事 – 我将分裂部分分解为一个名为splitOn的方法,该方法采用序列和分隔符.因为这是一个非常简单和通用的操作,所以它是优化的良好候选者.以下是一个快速尝试:
def splitOn[A](xs: Vector[A],delim: A): Vector[Vector[A]] = { val builder = Vector.newBuilder[Vector[A]] var i = 0 var start = 0 while (i < xs.size) { if (xs(i) == delim) { if (i != start) { builder += xs.slice(start,i) } start = i + 1 } i += 1 } if (i != start) builder += xs.slice(start,i) builder.result }
虽然这种实现是必要的,但从外部来看,该方法功能完善 – 它没有副作用等.
这通常是提高功能代码性能的好方法:我们将程序分为通用部分(在分隔符上拆分列表)和特定于问题的逻辑.因为前者非常简单,我们可以要求它(并测试它)作为一个黑盒子,同时保持我们用来解决问题的代码清洁和功能.
在这种情况下,性能仍然不是很好 – 这个实现的速度大约是我的机器的两倍 – 但我不认为在使用TailCalls进行蹦床时你会比这更好.