In Java, how do I efficiently and elegantly stream a tree node's descendants?
Assume we have a collection of objects that are identified by unique Strings, along with a class Tree that defines a hierarchy on them. That class is implemented using a Map from nodes (represented by their IDs) to Collections of their respective children's IDs.
class Tree {
private Map<String, Collection<String>> edges;
// ...
public Stream<String> descendants(String node) {
// To be defined.
}
}
I would like to enable streaming a node's descendants. A simple solution is this:
private Stream<String> children(String node) {
return edges.getOrDefault(node, Collections.emptyList()).stream();
}
public Stream<String> descendants(String node) {
return Stream.concat(
Stream.of(node),
children(node).flatMap(this::descendants)
);
}
Before continuing, I would like to make the following assertions about this solution. (Am I correct about these?)
1. Walking the Stream returned from descendants consumes resources (time and memory), relative to the size of the tree, in the same order of complexity as hand-coding the recursion would. In particular, the intermediate objects representing the iteration state (Streams, Spliterators, ...) form a stack, and therefore the memory requirement at any given time is in the same order of complexity as the tree's depth.
2. As I understand this, as soon as I perform a terminating operation on the Stream returned from descendants, the root-level call to flatMap will cause all contained Streams, one for each (recursive) call to descendants, to be realized immediately. Thus, the resulting Stream is only lazy on the first level of recursion, but not beyond. (Edited according to Tagir Valeev's answer.)
If I understood these points correctly, my question is this: How can I define descendants so that the resulting Stream is lazy?
I would like the solution to be as elegant as possible, in the sense that I prefer a solution which leaves the iteration state implicit. (To clarify what I mean by that: I know that I could write a Spliterator that walks the tree while maintaining an explicit stack of Spliterators on each level. I would like to avoid that.)
(Is there possibly a way in Java to formulate this as a producer-consumer workflow, like one could use in languages like Julia and Go?)
To me, your solution is already as elegant as possible, and its limited laziness is not your fault. The simplest solution is to wait until it gets fixed by the JRE developers; this has been done in Java 10.
However, if the limited laziness of today's implementation really is a concern, it's perhaps time to solve this in a general way. Well, it is about implementing a Spliterator, but not one specific to your task. Instead, it's a re-implementation of the flatMap operation, serving all cases where the limited laziness of the original implementation matters:
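import java.util.Objects;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
public class FlatMappingSpliterator<E,S> extends Spliterators.AbstractSpliterator<E>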
implements Consumer<S> {
static final boolean USE_ORIGINAL_IMPL
= Boolean.getBoolean("stream.flatmap.usestandard");
public static <T,R> Stream<R> flatMap(
Stream<T> in, Function<? super T,? extends Stream<? extends R>> mapper) {
if(USE_ORIGINAL_IMPL)
return in.flatMap(mapper);
Objects.requireNonNull(in);
Objects.requireNonNull(mapper);
return StreamSupport.stream(
new FlatMappingSpliterator<>(sp(in), mapper), in.isParallel()
).onClose(in::close);
}
final Spliterator<S> src;
final Function<? super S, ? extends Stream<? extends E>> f;
Stream<? extends E> currStream;
Spliterator<E> curr;
private FlatMappingSpliterator(
Spliterator<S> src, Function<? super S, ? extends Stream<? extends E>> f) {
// actually, the mapping function can change the size to anything,
// but it seems, with the current stream implementation, we are
// better off with an estimate being wrong by magnitudes than with
// reporting unknown size
super(src.estimateSize()+100, src.characteristics()&ORDERED);
this.src = src;
this.f = f;
}
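private void closeCurr() {
// currStream is null when curr was handed over by trySplit(), so guard the close
try { if(currStream!=null) currStream.close(); } finally { currStream=null; curr=null; }
}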
public void accept(S s) {
curr=sp(currStream=f.apply(s));
}
@Override
public boolean tryAdvance(Consumer<? super E> action) {
do {
if(curr!=null) {
if(curr.tryAdvance(action))
return true;
closeCurr();
}
} while(src.tryAdvance(this));
return false;
}
@Override
public void forEachRemaining(Consumer<? super E> action) {
if(curr!=null) {
curr.forEachRemaining(action);
closeCurr();
}
src.forEachRemaining(s->{
try(Stream<? extends E> str=f.apply(s)) {
if(str!=null) str.spliterator().forEachRemaining(action);
}
});
}
@SuppressWarnings("unchecked")
private static <X> Spliterator<X> sp(Stream<? extends X> str) {
return str!=null? ((Stream<X>)str).spliterator(): null;
}
@Override
public Spliterator<E> trySplit() {
Spliterator<S> split = src.trySplit();
if(split==null) {
Spliterator<E> prefix = curr;
while(prefix==null && src.tryAdvance(s->curr=sp(f.apply(s))))
prefix=curr;
curr=null;
return prefix;
}
FlatMappingSpliterator<E,S> prefix=new FlatMappingSpliterator<>(split, f);
if(curr!=null) {
prefix.curr=curr;
curr=null;
}
return prefix;
}
}
All you need to do to use it is add an import static of the flatMap method to your code and change expressions of the form stream.flatMap(function) to flatMap(stream, function).
I.e. in your code
public Stream<String> descendants(String node) {
return Stream.concat(
Stream.of(node),
flatMap(children(node), this::descendants)
);
}
then you have full lazy behavior. I tested it even with infinite streams…
Note that I added a toggle to allow turning back to the original implementation, e.g. when specifying -Dstream.flatmap.usestandard=true on the command line.
You're a little bit wrong in saying that the flatMap stream is not lazy. It is somewhat lazy, though its laziness is really limited. Let's use a custom Collection to track the requested elements inside your Tree class:
private final Set<String> requested = new LinkedHashSet<>();
private class MyList extends AbstractList<String> implements RandomAccess
{
private final String[] data;
public MyList(String... data) {
this.data = data;
}
@Override
public String get(int index) {
requested.add(data[index]);
return data[index];
}
@Override
public int size() {
return data.length;
}
}
Now let's pre-initialize your class with some tree data:
public Tree() {
// "1" is the root node; it has three immediate children
edges.put("1", new MyList("2", "3", "4"));
edges.put("2", new MyList("5", "6", "7"));
edges.put("3", new MyList("8", "9", "10"));
edges.put("8", new MyList("11", "12"));
edges.put("5", new MyList("13", "14", "15"));
edges.put("7", new MyList("16", "17", "18"));
edges.put("6", new MyList("19", "20"));
}
Finally, let's check how many elements are actually requested from your lists for different limit values:
public static void main(String[] args) {
for(int i=1; i<=20; i++) {
Tree tree = new Tree();
tree.descendants("1").limit(i).toArray();
System.out.println("Limit = " + i + "; requested = (" + tree.requested.size()
+ ") " + tree.requested);
}
}
The output is the following:
Limit = 1; requested = (0) []
Limit = 2; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 3; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 4; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 5; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 6; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 7; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 8; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 9; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 10; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 11; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 12; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 13; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 14; requested = (18) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18, 3, 8, 11, 12, 9, 10]
Limit = 15; requested = (18) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18, 3, 8, 11, 12, 9, 10]
Limit = 16; requested = (18) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18, 3, 8, 11, 12, 9, 10]
Limit = 17; requested = (18) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18, 3, 8, 11, 12, 9, 10]
Limit = 18; requested = (18) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18, 3, 8, 11, 12, 9, 10]
Limit = 19; requested = (18) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18, 3, 8, 11, 12, 9, 10]
Limit = 20; requested = (19) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18, 3, 8, 11, 12, 9, 10, 4]
Thus, when only the root node is requested, no access to children is performed (as Stream.concat is smart). When the first immediate child is requested, the whole subtree of this child is processed, even if that is unnecessary. Nevertheless, the second immediate child is not processed until the first one finishes. This could be problematic in short-circuiting scenarios, but in most cases your terminal operation is not short-circuiting, so it's still a fine approach.
As for your concerns about memory consumption: yes, it consumes memory in proportion to the tree depth (and, more importantly, it consumes stack space). If your tree has thousands of nesting levels, you will have a problem with your solution, as you may hit a StackOverflowError with the default -Xss setting. For several hundred levels of depth it would work fine.
We are using a similar approach in the business-logic layer of our application, and it works fine for us, though our trees are rarely deeper than 10 levels.
Not a real answer, but just a thought:
If you peek into the value collection and on the next step "resolve" that last seen value to a new value collection returning the next values in the same way recursively, then however this is implemented, it will always end up with some kind of "pointer" to the current element in the value collection on the current "level" of depth in the tree, and also with some kind of stack holding all those "pointers".
This is because you need both the information about the higher levels in the tree (the stack) and a "pointer" to the current element at the current level. In this case, one causes the other.
Of course, you can implement this as a Spliterator that holds a Stack of Iterators (each pointing into the corresponding value collection), but I suppose there will always be a "pointer" state at each depth level, even if it's hidden in Java's flatMap (or related) temporary objects.
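To make the "stack of pointers" idea concrete, here is a minimal sketch of that explicit-stack variant. It is written as an Iterator wrapped into a Stream rather than as a hand-written Spliterator, which is a simplification of my own. The class name StackTree and the edge helper are hypothetical and exist only for this illustration; the edges map mirrors the one from the question.

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

class StackTree {
    private final Map<String, Collection<String>> edges = new HashMap<>();

    // Hypothetical helper for building an example tree.
    StackTree edge(String parent, String... children) {
        edges.put(parent, Arrays.asList(children));
        return this;
    }

    // Pre-order traversal; each stack entry is the "pointer" into one tree level.
    Stream<String> descendants(String root) {
        Iterator<String> it = new Iterator<String>() {
            private final Deque<Iterator<String>> stack = new ArrayDeque<>();
            { stack.push(Collections.singleton(root).iterator()); }

            @Override
            public boolean hasNext() {
                // Drop exhausted levels until one still has elements left.
                while (!stack.isEmpty() && !stack.peek().hasNext()) {
                    stack.pop();
                }
                return !stack.isEmpty();
            }

            @Override
            public String next() {
                if (!hasNext()) throw new NoSuchElementException();
                String node = stack.peek().next();
                // The children of a node are looked up only when the node itself is emitted.
                stack.push(edges.getOrDefault(node, Collections.emptyList()).iterator());
                return node;
            }
        };
        return StreamSupport.stream(
            Spliterators.spliteratorUnknownSize(it, Spliterator.ORDERED), false);
    }
}
```

Because the state lives in a heap-allocated Deque instead of nested flatMap calls, the depth of the tree costs heap memory rather than call-stack frames, and a short-circuiting operation such as limit only pulls as many nodes as it needs. The price is exactly the explicit iteration state the question wanted to avoid.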
As an alternative: how about using a "real" tree with nodes that each hold a reference to their parent node? Plus, adding a map to the root of the tree which holds a reference to every single node, to simplify access to a sub-sub-sub-child. I guess the Spliterator implementation would then be really easy, because it just needs a reference to the current node for traversing and a stop criterion (the initial node) to avoid walking too "high" up in the tree.
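A rough sketch of that parent-pointer idea follows, under a few assumptions of my own: the class name Node is hypothetical, each node additionally stores its index among its parent's children (otherwise finding the next sibling would require a search), the ID-to-node map on the root is left out, and the traversal is again an Iterator wrapped into a Stream instead of a hand-written Spliterator.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

class Node {
    final String id;
    final Node parent;        // null for the root
    final int indexInParent;  // position among the parent's children, -1 for the root
    final List<Node> children = new ArrayList<>();

    Node(String id, Node parent) {
        this.id = id;
        this.parent = parent;
        this.indexInParent = parent == null ? -1 : parent.children.size();
        if (parent != null) parent.children.add(this);
    }

    // Pre-order successor of this node within the subtree rooted at 'stop', or null at the end.
    private Node successor(Node stop) {
        if (!children.isEmpty()) return children.get(0);
        Node n = this;
        // Climb until some ancestor has a next sibling, but never above 'stop'.
        while (n != stop) {
            List<Node> siblings = n.parent.children;
            if (n.indexInParent + 1 < siblings.size()) {
                return siblings.get(n.indexInParent + 1);
            }
            n = n.parent;
        }
        return null;
    }

    // Streams the IDs of this node and all its descendants in pre-order.
    Stream<String> descendants() {
        Node start = this;
        Iterator<String> it = new Iterator<String>() {
            private Node upcoming = start;  // the only traversal state: one node reference

            @Override
            public boolean hasNext() { return upcoming != null; }

            @Override
            public String next() {
                if (upcoming == null) throw new NoSuchElementException();
                Node current = upcoming;
                upcoming = current.successor(start);
                return current.id;
            }
        };
        return StreamSupport.stream(
            Spliterators.spliteratorUnknownSize(it, Spliterator.ORDERED), false);
    }
}
```

The per-level "pointers" have not disappeared here; they have moved into the tree itself as the parent and indexInParent fields, which is why the iterator gets away with a single node reference. The sketch assumes the tree is not modified while a traversal is in progress.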