How to optimize for-comprehensions and loops in Scala?
Solution 1:
The problem in this particular case is that you return from within the for-expression. That in turn gets translated into a throw of a NonLocalReturnControl exception, which is caught at the enclosing method. The optimizer can eliminate the foreach but cannot yet eliminate the throw/catch, and throw/catch is expensive. Since such nested returns are rare in Scala programs, the optimizer has not yet addressed this case. Work is under way to improve the optimizer, which will hopefully solve this issue soon.
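To make the cost concrete, here is a rough hand-written approximation of what the compiler produces for a return inside a for-expression. It is only a sketch: the real generated code differs in detail, and the object name Desugared is made up for illustration. It uses scala.runtime.NonLocalReturnControl, the throwable that Scala 2 uses for non-local returns.

```scala
import scala.runtime.NonLocalReturnControl

object Desugared {
  // Roughly what `for (i <- 2 to b) if (a % i != 0) return false; true`
  // becomes: the loop body turns into a closure passed to foreach, and the
  // `return` turns into a throw that the enclosing method catches.
  def isEvenlyDivisible(a: Int, b: Int): Boolean = {
    val key = new AnyRef // identifies this particular invocation
    try {
      (2 to b).foreach { i =>
        if (a % i != 0) throw new NonLocalReturnControl[Boolean](key, false)
      }
      true
    } catch {
      // Only handle the throwable that belongs to this invocation
      case ex: NonLocalReturnControl[_] if ex.key eq key =>
        ex.value.asInstanceOf[Boolean]
    }
  }
}
```

Every early exit allocates a throwable and unwinds through foreach. That is the throw/catch the optimizer cannot yet remove.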
Solution 2:
The problem is most likely the use of a for comprehension in the method isEvenlyDivisible. Replacing for by an equivalent while loop should eliminate the performance difference with Java.
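Here is a minimal sketch of that replacement. It assumes isEvenlyDivisible has the signature shown in Solution 4 further down: a is the candidate and b is the largest divisor to check. The object name WhileVersion is illustrative.

```scala
object WhileVersion {
  // True if a is divisible by every integer from 2 to b.
  // The while loop compiles to a plain JVM loop: no Range, no closure.
  def isEvenlyDivisible(a: Int, b: Int): Boolean = {
    var i = 2
    while (i <= b) {
      if (a % i != 0) return false // local return: no exception involved
      i += 1
    }
    true
  }
}
```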
As opposed to Java's for loops, Scala's for comprehensions are actually syntactic sugar for higher-order methods; in this case, you're calling the foreach method on a Range object. Scala's for is very general, but it sometimes leads to painful performance.
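The rewrite is mechanical. In the sketch below (the names are illustrative), the two methods are equivalent: the compiler turns the first into the second.

```scala
object ForDesugar {
  // A for loop without yield over a Range...
  def sumFor(n: Int): Int = {
    var total = 0
    for (i <- 1 to n) total += i
    total
  }

  // ...is rewritten into a foreach call on the Range, with the loop body
  // passed as a closure (i => total += i).
  def sumForeach(n: Int): Int = {
    var total = 0
    (1 to n).foreach(i => total += i)
    total
  }
}
```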
You might want to try the -optimize flag in Scala version 2.9. Observed performance may depend on the particular JVM in use, and on whether the JIT optimizer has had sufficient "warm-up" time to identify and optimize hot spots.
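A crude way to account for JIT warm-up when timing by hand is to run the code a number of times before measuring it. This is only a sketch: the name timed and the warm-up count are arbitrary. A dedicated harness such as JMH gives far more trustworthy numbers.

```scala
object WarmupTiming {
  // Runs `body` warmupRuns times without timing it, giving the JIT a chance
  // to compile the hot path, then times one further run.
  // Returns the result of the timed run and the elapsed time in nanoseconds.
  def timed[A](warmupRuns: Int)(body: => A): (A, Long) = {
    var i = 0
    while (i < warmupRuns) {
      body
      i += 1
    }
    val start = System.nanoTime()
    val result = body
    (result, System.nanoTime() - start)
  }
}
```

For example, WarmupTiming.timed(10)(find(2)) would time the eleventh call of a search like the ones in the answers below.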
Recent discussions on the mailing list indicate that the Scala team is working on improving for performance in simple cases:
- http://groups.google.com/group/scala-user/browse_thread/thread/86adb44d72ef4498
- http://groups.google.com/group/scala-language/browse_thread/thread/94740a10205dddd2
Here is the issue in the bug tracker: https://issues.scala-lang.org/browse/SI-4633
Update 5/28:
- As a short-term solution, the ScalaCL plugin (alpha) will transform simple Scala loops into the equivalent of while loops.
- As a potential longer-term solution, teams from EPFL and Stanford are collaborating on a project enabling run-time compilation of "virtual" Scala for very high performance. For example, multiple idiomatic functional loops can be fused at run time into optimal JVM bytecode, or compiled for another target such as a GPU. The system is extensible, allowing user-defined DSLs and transformations. Check out the publications and Stanford course notes. Preliminary code is available on GitHub, with a release intended in the coming months.
Solution 3:
As a follow-up, I tried the -optimize flag and it reduced running time from 103 to 76 seconds, but that's still 107x slower than Java or a while loop.
Then I was looking at the "functional" version:
object P005 extends App {
  def isDivis(x: Int) = (1 to 20) forall { x % _ == 0 }
  def find(n: Int): Int = if (isDivis(n)) n else find(n + 2)
  println(find(2))
}
and trying to figure out how to get rid of the "forall" in a concise manner. I failed miserably and came up with
object P005_V2 extends App {
  def isDivis(x: Int): Boolean = {
    var i = 1
    while (i <= 20) {
      if (x % i != 0) return false
      i += 1
    }
    true
  }
  def find(n: Int): Int = if (isDivis(n)) n else find(n + 2)
  println(find(2))
}
whereby my cunning 5-line solution has ballooned to 12 lines. However, this version runs in 0.71 seconds, the same speed as the original Java version, and 56 times faster than the version above using "forall" (40.2 s)! (See the EDIT below for an explanation of earlier timing differences with Java.)
Obviously my next step was to translate the above back into Java, but Java can't handle it. The JVM does not eliminate tail calls, so the recursive find throws a StackOverflowError with n around the 22000 mark.
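The Scala version survives because scalac compiles the self-recursive call in find into a jump. Annotating the method with @tailrec (from scala.annotation) makes compilation fail if that rewrite cannot be done. In this sketch, the hard-coded bound of 20 becomes a parameter upTo so the example can be tried with smaller values. The object name and the parameter are illustrative.

```scala
import scala.annotation.tailrec

object TailRecFind {
  // True if x is divisible by every integer from 1 to upTo.
  def isDivis(x: Int, upTo: Int): Boolean = {
    var i = 1
    while (i <= upTo) {
      if (x % i != 0) return false
      i += 1
    }
    true
  }

  // @tailrec makes compilation fail unless this recursion becomes a loop,
  // so a deep search cannot overflow the stack.
  @tailrec
  def find(n: Int, upTo: Int): Int =
    if (isDivis(n, upTo)) n else find(n + 2, upTo)
}
```

TailRecFind.find(2, 20) reproduces the original search and should return 232792560. Even find(2, 16) recurses about 360,000 times, far beyond the depth at which the Java translation overflowed.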
I then scratched my head for a bit and replaced the "while" with a bit more tail recursion, which saves a couple of lines, runs just as fast, but let's face it, is more confusing to read:
object P005_V3 extends App {
  def isDivis(x: Int, i: Int): Boolean =
    if (i > 20) true
    else if (x % i != 0) false
    else isDivis(x, i + 1)
  def find(n: Int): Int = if (isDivis(n, 2)) n else find(n + 2)
  println(find(2))
}
So Scala's tail recursion wins the day, but I'm surprised that something as simple as a "for" loop (and the "forall" method) is essentially broken and has to be replaced by inelegant and verbose "while" loops or tail recursion. Much of the reason I'm trying Scala is its concise syntax, but that's no good if my code is going to run 100 times slower!
EDIT: (deleted)
EDIT OF EDIT: The earlier discrepancies between run times of 2.5 s and 0.7 s were entirely due to whether the 32-bit or the 64-bit JVM was being used. Scala run from the command line uses whatever JVM JAVA_HOME points to, while Java uses the 64-bit JVM whenever it is available. IDEs have their own settings. Some measurements here: Scala execution times in Eclipse
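When you compare timings like these, it helps to print which JVM the program is actually running on. The sketch below does that. Note that sun.arch.data.model is a HotSpot-specific property and may be absent on other JVMs, which is why every lookup falls back to "unknown".

```scala
object JvmInfo {
  // Describes the running JVM: name, Java version, OS architecture and,
  // on HotSpot, the data model ("32" or "64").
  def describe: String = {
    def prop(name: String): String =
      Option(System.getProperty(name)).getOrElse("unknown")
    s"${prop("java.vm.name")} ${prop("java.version")} " +
      s"(os.arch=${prop("os.arch")}, data model=${prop("sun.arch.data.model")})"
  }

  def main(args: Array[String]): Unit = println(describe)
}
```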
Solution 4:
The answer about for comprehensions is right, but it's not the whole story. You should note that the use of return in isEvenlyDivisible is not free. Using return inside the for forces the Scala compiler to generate a non-local return (i.e. to return outside its function).
This is done through the use of an exception to exit the loop. The same happens if you build your own control abstractions, for example:
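def loop[T](times: Int, default: T)(body: () => T): T = {
  var count = 0
  var result: T = default
  while (count < times) {
    result = body()
    count += 1
  }
  result
}

def foo(): Int = {
  // The "() =>" makes the block a function literal; without it the block
  // would be evaluated once at the call site instead of being passed to loop
  loop(5, 0) { () =>
    println("Hi")
    return 5
  }
}

foo()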
This prints "Hi" only once.
Note that the return in foo exits foo (which is what you would expect). Since the bracketed expression is a function literal, as you can see from the signature of loop, the compiler is forced to generate a non-local return; that is, the return forces you to exit foo, not just the body.
In Java (i.e. the JVM) the only way to implement such behavior is to throw an exception.
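Control abstractions are more commonly written with a by-name parameter (body: => T), so callers don't need an explicit function literal. The compiler still wraps the argument in a function, so the return behaves exactly as described: it leaves foo through a non-local return. In the sketch below, a hypothetical counter hiCount stands in for println("Hi") so that the behavior can be checked.

```scala
object ByNameLoop {
  // Same as loop above, but body is a by-name parameter that is
  // re-evaluated on every iteration.
  def loop[T](times: Int, default: T)(body: => T): T = {
    var count = 0
    var result = default
    while (count < times) {
      result = body
      count += 1
    }
    result
  }

  var hiCount = 0 // hypothetical counter replacing println("Hi")

  def foo(): Int =
    loop(5, 0) {
      hiCount += 1
      return 5 // non-local: leaves foo during the first iteration
    }
}
```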
Going back to isEvenlyDivisible:
def isEvenlyDivisible(a: Int, b: Int): Boolean = {
  for (i <- 2 to b)
    if (a % i != 0) return false
  return true
}
The body if (a % i != 0) return false is compiled into a function literal that contains a return. So each time the return is hit, the runtime has to throw and catch an exception, which causes quite a bit of GC overhead.
Solution 5:
Some ways I discovered to speed up the forall method:
The original: 41.3 s
def isDivis(x:Int) = (1 to 20) forall {x % _ == 0}
Pre-instantiating the range, so we don't create a new range every time: 9.0 s
val r = (1 to 20)
def isDivis(x:Int) = r forall {x % _ == 0}
Converting to a List instead of a Range: 4.8 s
val rl = (1 to 20).toList
def isDivis(x:Int) = rl forall {x % _ == 0}
I tried a few other collections, but List was fastest (although it is still 7x slower than avoiding the Range and the higher-order function altogether).
While I am new to Scala, I'd guess the compiler could easily gain a quick and significant performance boost by automatically replacing Range literals in methods (as above) with Range constants in the outermost scope, or, better still, by interning them like String literals in Java.
Footnote:
Arrays were about the same as Range, but interestingly, pimping a new forall method (shown below) resulted in 24% faster execution on 64-bit and 8% faster execution on 32-bit. When I reduced the calculation size by lowering the number of factors from 20 to 15, the difference disappeared, so maybe it's a garbage collection effect. Whatever the cause, it's significant when operating under full load for extended periods.
A similar pimp for List also resulted in about 10% better performance.
val ra = (1 to 20).toArray
def isDivis(x: Int) = ra forall2 { x % _ == 0 }

case class PimpedSeq[A](s: IndexedSeq[A]) {
  def forall2(p: A => Boolean): Boolean = {
    var i = 0
    while (i < s.length) {
      if (!p(s(i))) return false
      i += 1
    }
    true
  }
}

implicit def arrayToPimpedSeq[A](in: Array[A]): PimpedSeq[A] = PimpedSeq(in)
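On Scala 2.10 and later, the same idea is usually written as an implicit value class. The value class works on the array directly, and the compiler can usually avoid allocating a wrapper for each call. This is a sketch; the names FastForall, ArrayForall and ListForall are illustrative. The List variant is only a guess at what the List pimp mentioned above might look like, since its code isn't shown.

```scala
object FastForall {
  // Extension method on arrays: a while loop over the indices,
  // stopping at the first element that fails the predicate.
  implicit class ArrayForall[A](val arr: Array[A]) extends AnyVal {
    def forall2(p: A => Boolean): Boolean = {
      var i = 0
      while (i < arr.length) {
        if (!p(arr(i))) return false
        i += 1
      }
      true
    }
  }

  // Extension method on lists: walks the list with a local loop,
  // stopping at the first element that fails the predicate.
  implicit class ListForall[A](val xs: List[A]) extends AnyVal {
    def forall2(p: A => Boolean): Boolean = {
      var rest = xs
      while (rest.nonEmpty) {
        if (!p(rest.head)) return false
        rest = rest.tail
      }
      true
    }
  }
}
```

After import FastForall._, the original call ra forall2 { x % _ == 0 } compiles unchanged, and it also works on a List such as rl.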