Scala equivalent to Python generators?
Is it possible to implement in Scala something equivalent to the Python yield statement, where a function remembers its local state and "yields" the next value each time it is called?
I wanted to have something like this to convert a recursive function into an iterator. Sort of like this:
# this is python
def foo(i):
    yield i
    if i > 0:
        for j in foo(i - 1):
            yield j

for i in foo(5):
    print(i)
Except that foo may be more complex and recurse through some acyclic object graph.
Additional Edit: Let me add a more complex example (but still simple): I can write a simple recursive function printing things as it goes along:
// this is Scala
def printClass(clazz: Class[_], indent: String = ""): Unit = {
  clazz match {
    case null =>
    case _ =>
      println(indent + clazz)
      printClass(clazz.getSuperclass, indent + " ")
      for (c <- clazz.getInterfaces) {
        printClass(c, indent + " ")
      }
  }
}
Ideally I would like to have a library that allows me to easily change a few statements and have it work as an Iterator:
// this is not Scala
def yieldClass(clazz: Class[_]): Iterator[Class[_]] = {
  clazz match {
    case null =>
    case _ =>
      sudoYield clazz
      for (c <- yieldClass(clazz.getSuperclass)) sudoYield c
      for (c <- clazz.getInterfaces; d <- yieldClass(c)) sudoYield d
  }
}
It seems that continuations allow this, but I just don't understand the shift/reset concept. Will continuations eventually make it into the main compiler, and would it be possible to hide the complexity in a library?
Edit 2: check Rich's answer in that other thread.
Solution 1:
While Python generators are cool, trying to duplicate them really isn't the best way to go about it in Scala. For instance, the following code does the equivalent job to what you want:
def classStream(clazz: Class[_]): Stream[Class[_]] = clazz match {
  case null => Stream.empty
  case _ => (
    clazz
    #:: classStream(clazz.getSuperclass)
    #::: clazz.getInterfaces.toStream.flatMap(classStream)
    #::: Stream.empty
  )
}
Here the stream is generated lazily, so it won't process any of the elements until they are asked for, which you can verify by running this:
def classStream(clazz: Class[_]): Stream[Class[_]] = clazz match {
  case null => Stream.empty
  case _ => (
    clazz
    #:: { println(clazz.toString+": super"); classStream(clazz.getSuperclass) }
    #::: { println(clazz.toString+": interfaces"); clazz.getInterfaces.toStream.flatMap(classStream) }
    #::: Stream.empty
  )
}
The result can be converted into an Iterator simply by calling .iterator on the resulting Stream:
def classIterator(clazz: Class[_]): Iterator[Class[_]] = classStream(clazz).iterator
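For readers on Scala 2.13 or later, where Stream is deprecated in favour of LazyList, the same idea can be sketched as follows. This is only a sketch under that version assumption; classLazyList and classLazyIterator are hypothetical names introduced here, not part of the answer above.

```scala
// Assumes Scala 2.13+, where LazyList supersedes the deprecated Stream.
// classLazyList and classLazyIterator are illustrative names only.
def classLazyList(clazz: Class[_]): LazyList[Class[_]] =
  if (clazz == null) LazyList.empty
  else
    clazz #:: (
      // The whole parenthesised tail is deferred until it is first needed.
      classLazyList(clazz.getSuperclass) #:::
        clazz.getInterfaces.iterator.to(LazyList).flatMap(classLazyList)
    )

// Same conversion as in the answer: a lazy sequence viewed as an Iterator.
def classLazyIterator(clazz: Class[_]): Iterator[Class[_]] =
  classLazyList(clazz).iterator
```

As in the Stream version, the class itself comes first, then its superclass chain, then its interfaces, so classLazyIterator(classOf[String]) should start with String followed by Object.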
The foo definition, using Stream, would be rendered thus:
scala> def foo(i: Int): Stream[Int] = i #:: (if (i > 0) foo(i - 1) else Stream.empty)
foo: (i: Int)Stream[Int]
scala> foo(5) foreach println
5
4
3
2
1
0
Another alternative would be concatenating the various iterators, taking care to not pre-compute them. Here's an example, also with debugging messages to help trace the execution:
def yieldClass(clazz: Class[_]): Iterator[Class[_]] = clazz match {
  case null => println("empty"); Iterator.empty
  case _ =>
    def thisIterator = { println("self of "+clazz); Iterator(clazz) }
    def superIterator = { println("super of "+clazz); yieldClass(clazz.getSuperclass) }
    def interfacesIterator = { println("interfaces of "+clazz); clazz.getInterfaces.iterator flatMap yieldClass }
    thisIterator ++ superIterator ++ interfacesIterator
}
This is pretty close to your code. Instead of sudoYield, I have definitions, and then I just concatenate them as I wish.
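The concatenation stays lazy because Iterator's ++ takes its right-hand side by name. A minimal sketch of the same idea applied to the foo example from the question follows; the LazyFoo object and its calls counter are hypothetical additions, there only to make the laziness observable.

```scala
object LazyFoo {
  // Hypothetical counter, only here to observe how many calls actually run.
  var calls = 0

  // Iterator's ++ takes its argument by name, so the recursive call is not
  // made until the left-hand iterator has been exhausted.
  def foo(i: Int): Iterator[Int] = {
    calls += 1
    Iterator(i) ++ (if (i > 0) foo(i - 1) else Iterator.empty)
  }
}
```

LazyFoo.foo(5).toList gives List(5, 4, 3, 2, 1, 0), while taking only the first two elements of LazyFoo.foo(1000) should trigger just a handful of recursive calls rather than a thousand.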
So, while this is a non-answer, I just think you are barking up the wrong tree here. Trying to write Python in Scala is bound to be unproductive. Work harder at the Scala idioms that accomplish the same goals.
Solution 2:
Another solution based on the continuations plugin, this time with a more or less encapsulated Generator type:
import scala.continuations._
import scala.continuations.ControlContext._

object Test {
  def loopWhile(cond: =>Boolean)(body: =>(Unit @suspendable)): Unit @suspendable = {
    if (cond) {
      body
      loopWhile(cond)(body)
    } else ()
  }

  abstract class Generator[T] {
    var producerCont : (Unit => Unit) = null
    var consumerCont : (T => Unit) = null

    protected def body : Unit @suspendable

    reset {
      body
    }

    def generate(t : T) : Unit @suspendable =
      shift {
        (k : Unit => Unit) => {
          producerCont = k
          if (consumerCont != null)
            consumerCont(t)
        }
      }

    def next : T @suspendable =
      shift {
        (k : T => Unit) => {
          consumerCont = k
          if (producerCont != null)
            producerCont()
        }
      }
  }

  def main(args: Array[String]) {
    val g = new Generator[Int] {
      def body = {
        var i = 0
        loopWhile(i < 10) {
          generate(i)
          i += 1
        }
      }
    }

    reset {
      loopWhile(true) {
        println("Generated: "+g.next)
      }
    }
  }
}
Solution 3:
To do this in a general way, I think you need the continuations plugin.
A naive implementation (freehand, not compiled/checked):
def iterator = new {
  private[this] var done = false

  // Define your yielding state here
  // This generator yields: 3, 13, 0, 1, 3, 6, 26, 27
  private[this] var state: Unit=>Int = reset {
    var x = 3
    giveItUp(x)
    x += 10
    giveItUp(x)
    x = 0
    giveItUp(x)
    List(1,2,3).foreach { i => x += i; giveItUp(x) }
    x += 20
    giveItUp(x)
    x += 1
    done = true
    x
  }

  // Well, "yield" is a keyword, so how about giveItUp?
  private[this] def giveItUp(i: Int) = shift { k: (Unit=>Int) =>
    state = k
    i
  }

  def hasNext = !done
  def next = state()
}
What is happening is that any call to shift captures the control flow from where it is called to the end of the reset block that it is called in. This is passed as the k argument into the shift function.
So, in the example above, each giveItUp(x) returns the value of x (up to that point) and saves the rest of the computation in the state variable. It is driven from outside by the hasNext and next methods.
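To make "saves the rest of the computation" concrete without the plugin, the same shape can be written by hand in continuation-passing style. This is only a sketch of the idea, not how the plugin works internally; Step, Yielded, Done, toIterator, and example are hypothetical names introduced for illustration.

```scala
// Hypothetical types: a generator step either yields a value together with
// the rest of the computation, or signals that it has finished.
sealed trait Step[+A]
final case class Yielded[+A](value: A, rest: () => Step[A]) extends Step[A]
case object Done extends Step[Nothing]

// Drives a chain of steps from outside, like hasNext/next in the answer.
def toIterator[A](first: () => Step[A]): Iterator[A] = new Iterator[A] {
  private var pending: () => Step[A] = first   // saved "rest of the computation"
  private var current: Option[Step[A]] = None  // step already computed by hasNext

  private def step(): Step[A] = current match {
    case Some(s) => s
    case None =>
      val s = pending()
      current = Some(s)
      s
  }

  def hasNext: Boolean = step() match {
    case Done => false
    case _    => true
  }

  def next(): A = step() match {
    case Yielded(value, rest) =>
      pending = rest   // remember where to resume
      current = None
      value
    case Done =>
      throw new NoSuchElementException("generator exhausted")
  }
}

// Hand-written counterpart of the first three giveItUp calls: yields 3, 13, 0.
def example(): Step[Int] = {
  var x = 3
  Yielded(x, () => {
    x += 10
    Yielded(x, () => {
      x = 0
      Yielded(x, () => Done)
    })
  })
}
```

Each nested function plays the role of the k that shift would capture. The plugin's appeal is that it builds this nesting automatically from straight-line code.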
Go gentle, this is obviously a terrible way to implement this. But it is the best I could do late at night without a compiler handy.
Solution 4:
Scala's for-loop of the form for (e <- Producer) f(e) translates into a foreach call, and not directly into iterator / next.
With foreach we don't need to linearize object creation and gather it in one place, as would be needed for an iterator's next. The consumer function f can be inserted multiple times, exactly where it is needed (i.e. where an object is created).
This makes implementing the use cases for generators simple and efficient with Traversable / foreach in Scala.
The initial Foo-example:
case class Countdown(start: Int) extends Traversable[Int] {
  def foreach[U](f: Int => U) {
    var j = start
    while (j >= 0) {f(j); j -= 1}
  }
}
for (i <- Countdown(5)) println(i)
// or equivalent:
Countdown(5) foreach println
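On Scala 2.13 and later, Traversable survives only as a deprecated alias of Iterable, so a class that defines nothing but foreach can no longer extend it. A for loop without yield only needs a foreach method, though, so the same idea can be sketched with a plain class. This assumes Scala 2.13+; PlainCountdown and collect are hypothetical names used for illustration.

```scala
// Assumes Scala 2.13+. No collection trait is extended: a for loop without
// yield is rewritten into a call to foreach, which is all we provide.
final case class PlainCountdown(start: Int) {
  def foreach[U](f: Int => U): Unit = {
    var j = start
    while (j >= 0) { f(j); j -= 1 }
  }
}

// Hypothetical helper showing the for loop driving foreach.
def collect(start: Int): List[Int] = {
  val buf = List.newBuilder[Int]
  for (i <- PlainCountdown(start)) buf += i
  buf.result()
}
```

collect(3) gathers 3, 2, 1, 0 in that order. Note that this style pushes values into the consumer, so it cannot be paused and resumed the way an Iterator can.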
The initial printClass-example:
// v1 (without indentation)
case class ClassStructure(c: Class[_]) {
  def foreach[U](f: Class[_] => U) {
    if (c eq null) return
    f(c)
    ClassStructure(c.getSuperclass) foreach f
    c.getInterfaces foreach (ClassStructure(_) foreach f)
  }
}
for (c <- ClassStructure(<foo/>.getClass)) println(c)
// or equivalent:
ClassStructure(<foo/>.getClass) foreach println
Or with indentation:
// v2 (with indentation)
case class ClassWithIndent(c: Class[_], indent: String = "") {
  override def toString = indent + c
}
implicit def Class2WithIndent(c: Class[_]) = ClassWithIndent(c)

case class ClassStructure(cwi: ClassWithIndent) {
  def foreach[U](f: ClassWithIndent => U) {
    if (cwi.c eq null) return
    f(cwi)
    ClassStructure(ClassWithIndent(cwi.c.getSuperclass, cwi.indent + " ")) foreach f
    cwi.c.getInterfaces foreach (i => ClassStructure(ClassWithIndent(i, cwi.indent + " ")) foreach f)
  }
}
for (c <- ClassStructure(<foo/>.getClass)) println(c)
// or equivalent:
ClassStructure(<foo/>.getClass) foreach println
Output:
class scala.xml.Elem
class scala.xml.Node
class scala.xml.NodeSeq
class java.lang.Object
interface scala.collection.immutable.Seq
interface scala.collection.immutable.Iterable
interface scala.collection.immutable.Traversable
interface scala.collection.Traversable
interface scala.collection.TraversableLike
interface scala.collection.generic.HasNewBuilder
interface scala.collection.generic.FilterMonadic
interface scala.collection.TraversableOnce
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.collection.generic.GenericTraversableTemplate
interface scala.collection.generic.HasNewBuilder
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.collection.generic.GenericTraversableTemplate
interface scala.collection.generic.HasNewBuilder
interface scala.ScalaObject
interface scala.collection.TraversableLike
interface scala.collection.generic.HasNewBuilder
interface scala.collection.generic.FilterMonadic
interface scala.collection.TraversableOnce
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.Immutable
interface scala.ScalaObject
interface scala.collection.Iterable
interface scala.collection.Traversable
interface scala.collection.TraversableLike
interface scala.collection.generic.HasNewBuilder
interface scala.collection.generic.FilterMonadic
interface scala.collection.TraversableOnce
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.collection.generic.GenericTraversableTemplate
interface scala.collection.generic.HasNewBuilder
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.collection.generic.GenericTraversableTemplate
interface scala.collection.generic.HasNewBuilder
interface scala.ScalaObject
interface scala.collection.IterableLike
interface scala.Equals
interface scala.collection.TraversableLike
interface scala.collection.generic.HasNewBuilder
interface scala.collection.generic.FilterMonadic
interface scala.collection.TraversableOnce
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.collection.generic.GenericTraversableTemplate
interface scala.collection.generic.HasNewBuilder
interface scala.ScalaObject
interface scala.collection.IterableLike
interface scala.Equals
interface scala.collection.TraversableLike
interface scala.collection.generic.HasNewBuilder
interface scala.collection.generic.FilterMonadic
interface scala.collection.TraversableOnce
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.collection.Seq
interface scala.PartialFunction
interface scala.Function1
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.collection.Iterable
interface scala.collection.Traversable
interface scala.collection.TraversableLike
interface scala.collection.generic.HasNewBuilder
interface scala.collection.generic.FilterMonadic
interface scala.collection.TraversableOnce
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.collection.generic.GenericTraversableTemplate
interface scala.collection.generic.HasNewBuilder
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.collection.generic.GenericTraversableTemplate
interface scala.collection.generic.HasNewBuilder
interface scala.ScalaObject
interface scala.collection.IterableLike
interface scala.Equals
interface scala.collection.TraversableLike
interface scala.collection.generic.HasNewBuilder
interface scala.collection.generic.FilterMonadic
interface scala.collection.TraversableOnce
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.collection.generic.GenericTraversableTemplate
interface scala.collection.generic.HasNewBuilder
interface scala.ScalaObject
interface scala.collection.SeqLike
interface scala.collection.IterableLike
interface scala.Equals
interface scala.collection.TraversableLike
interface scala.collection.generic.HasNewBuilder
interface scala.collection.generic.FilterMonadic
interface scala.collection.TraversableOnce
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.collection.generic.GenericTraversableTemplate
interface scala.collection.generic.HasNewBuilder
interface scala.ScalaObject
interface scala.collection.SeqLike
interface scala.collection.IterableLike
interface scala.Equals
interface scala.collection.TraversableLike
interface scala.collection.generic.HasNewBuilder
interface scala.collection.generic.FilterMonadic
interface scala.collection.TraversableOnce
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.collection.SeqLike
interface scala.collection.IterableLike
interface scala.Equals
interface scala.collection.TraversableLike
interface scala.collection.generic.HasNewBuilder
interface scala.collection.generic.FilterMonadic
interface scala.collection.TraversableOnce
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.xml.Equality
interface scala.Equals
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.ScalaObject
interface scala.ScalaObject
interface java.io.Serializable