In Java, how do I efficiently and elegantly stream a tree node's descendants?

Assume we have a collection of objects that are identified by unique Strings, along with a class Tree that defines a hierarchy on them. That class is implemented using a Map from nodes (represented by their IDs) to Collections of their respective children's IDs.

class Tree {
  private Map<String, Collection<String>> edges;

  // ...

  public Stream<String> descendants(String node) {
    // To be defined.
  }
}

I would like to enable streaming a node's descendants. A simple solution is this:

private Stream<String> children(String node) {
    return edges.getOrDefault(node, Collections.emptyList()).stream();
}

public Stream<String> descendants(String node) {
    return Stream.concat(
        Stream.of(node),
        children(node).flatMap(this::descendants)
    );
}

Before continuing, I would like to make the following assertions about this solution. (Am I correct about these?)

  1. Walking the Stream returned from descendants consumes resources (time and memory) – relative to the size of the tree – in the same order of complexity as hand-coding the recursion would. In particular, the intermediate objects representing the iteration state (Streams, Spliterators, ...) form a stack and therefore the memory requirement at any given time is in the same order of complexity as the tree's depth.

  2. As I understand it, as soon as I perform a terminal operation on the Stream returned from descendants, the root-level call to flatMap will cause each contained Stream – one per (recursive) call to descendants – to be realized in full as soon as it is first pulled from. Thus, the resulting Stream is only lazy on the first level of recursion, but not beyond; a small check of this follows below. (Edited according to Tagir Valeev's answer.)
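For illustration, here is the check (my own example, separate from the tree code; this is the behavior I would expect on Java 8/9):

Stream.of("a", "b")
    .flatMap(s -> Stream.of(1, 2, 3).map(i -> s + i).peek(System.out::println))
    .findFirst();
// prints a1, a2, a3: the entire substream for "a" is realized even though
// findFirst() needs only one element; the substream for "b" is never created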

If I understood these points correctly, my question is this: How can I define descendants so that the resulting Stream is lazy?

I would like the solution to be as elegant as possible, in the sense that I prefer a solution which leaves the iteration state implicit. (To clarify what I mean by that: I know that I could write a Spliterator that walks the tree while maintaining an explicit stack of Spliterators on each level. I would like to avoid that.)
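For reference, the kind of explicit-state solution I mean (and would like to avoid) is roughly the following sketch. It keeps a Deque of node IDs instead of a stack of Spliterators, but the iteration state is just as explicit:

public Stream<String> descendantsExplicit(String node) {
    Deque<String> stack = new ArrayDeque<>();
    stack.push(node);
    return StreamSupport.stream(new Spliterators.AbstractSpliterator<String>(
            Long.MAX_VALUE, Spliterator.ORDERED) {
        @Override
        public boolean tryAdvance(Consumer<? super String> action) {
            if (stack.isEmpty())
                return false;
            String next = stack.pop();
            // visit the children before the remaining siblings
            // (this reverses the sibling order; pushing them in reverse would fix that)
            edges.getOrDefault(next, Collections.emptyList()).forEach(stack::push);
            action.accept(next);
            return true;
        }
    }, false);
}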

(Is there possibly a way in Java to formulate this as a producer-consumer workflow, like one could use in languages like Julia and Go?)
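The closest thing I can imagine in current Java, sketched here with made-up names as a member of Tree, is a background thread that feeds node IDs into a BlockingQueue while an Iterator drains it on demand. Unlike a Go channel this needs an explicit end marker, and the producer thread stays blocked if the consumer abandons the stream early (it is a daemon thread, so at least the JVM can still exit), so I am not sure this counts as elegant:

private static final String END = new String("<end>");

public Stream<String> descendantsViaQueue(String node) {
    BlockingQueue<String> queue = new ArrayBlockingQueue<>(16);
    Thread producer = new Thread(() -> {
        try {
            walk(node, queue);   // recursive depth-first walk, see below
            queue.put(END);      // end marker
        } catch (InterruptedException ignored) {}
    });
    producer.setDaemon(true);    // let the JVM exit even if the consumer gives up early
    producer.start();
    Iterator<String> it = new Iterator<String>() {
        String next;
        public boolean hasNext() {
            if (next == null) {
                try { next = queue.take(); }
                catch (InterruptedException e) { throw new IllegalStateException(e); }
            }
            return next != END;  // reference comparison against the unique marker
        }
        public String next() {
            if (!hasNext()) throw new NoSuchElementException();
            String result = next;
            next = null;
            return result;
        }
    };
    return StreamSupport.stream(
        Spliterators.spliteratorUnknownSize(it, Spliterator.ORDERED), false);
}

private void walk(String node, BlockingQueue<String> queue) throws InterruptedException {
    queue.put(node);
    for (String child : edges.getOrDefault(node, Collections.emptyList()))
        walk(child, queue);
}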


To me, your solution is already as elegant as possible, and its limited laziness is not your fault. The simplest solution is to wait until the JRE developers fix it, which they did in Java 10.

However, if this limited laziness of today's implementation really is a concern, it's perhaps time to solve this in a general way. Admittedly, it is about implementing a Spliterator, but not one specific to your task. Instead, it is a re-implementation of the flatMap operation, serving all cases where the limited laziness of the original implementation matters:

public class FlatMappingSpliterator<E,S> extends Spliterators.AbstractSpliterator<E>
implements Consumer<S> {

    static final boolean USE_ORIGINAL_IMPL
        = Boolean.getBoolean("stream.flatmap.usestandard");

    public static <T,R> Stream<R> flatMap(
        Stream<T> in, Function<? super T,? extends Stream<? extends R>> mapper) {

        if(USE_ORIGINAL_IMPL)
            return in.flatMap(mapper);

        Objects.requireNonNull(in);
        Objects.requireNonNull(mapper);
        return StreamSupport.stream(
            new FlatMappingSpliterator<>(sp(in), mapper), in.isParallel()
        ).onClose(in::close);
    }

    final Spliterator<S> src;
    final Function<? super S, ? extends Stream<? extends E>> f;
    Stream<? extends E> currStream;
    Spliterator<E> curr;

    private FlatMappingSpliterator(
        Spliterator<S> src, Function<? super S, ? extends Stream<? extends E>> f) {
        // actually, the mapping function can change the size to anything,
        // but it seems, with the current stream implementation, we are
        // better off with an estimate being wrong by magnitudes than with
        // reporting unknown size
        super(src.estimateSize()+100, src.characteristics()&ORDERED);
        this.src = src;
        this.f = f;
    }

    private void closeCurr() {
        try { currStream.close(); } finally { currStream=null; curr=null; }
    }

    public void accept(S s) {
        curr=sp(currStream=f.apply(s));
    }

    @Override
    public boolean tryAdvance(Consumer<? super E> action) {
        do {
            if(curr!=null) {
                if(curr.tryAdvance(action))
                    return true;
                closeCurr();
            }
        } while(src.tryAdvance(this));
        return false;
    }

    @Override
    public void forEachRemaining(Consumer<? super E> action) {
        if(curr!=null) {
            curr.forEachRemaining(action);
            closeCurr();
        }
        src.forEachRemaining(s->{
            try(Stream<? extends E> str=f.apply(s)) {
                if(str!=null) str.spliterator().forEachRemaining(action);
            }
        });
    }

    @SuppressWarnings("unchecked")
    private static <X> Spliterator<X> sp(Stream<? extends X> str) {
        return str!=null? ((Stream<X>)str).spliterator(): null;
    }

    @Override
    public Spliterator<E> trySplit() {
        Spliterator<S> split = src.trySplit();
        if(split==null) {
            Spliterator<E> prefix = curr;
            while(prefix==null && src.tryAdvance(s->curr=sp(f.apply(s))))
                prefix=curr;
            curr=null;
            return prefix;
        }
        FlatMappingSpliterator<E,S> prefix=new FlatMappingSpliterator<>(split, f);
        if(curr!=null) {
            prefix.curr=curr;
            curr=null;
        }
        return prefix;
    }
}

All you need to use it is to add a static import of the flatMap method to your code and change expressions of the form stream.flatMap(function) into flatMap(stream, function).

I.e., in your code:

public Stream<String> descendants(String node) {
    return Stream.concat(
        Stream.of(node),
        flatMap(children(node), this::descendants)
    );
}

then you get fully lazy behavior. I tested it even with infinite streams…
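A test along these lines (a sketch of such a check, assuming a static import of the flatMap defined above) shows the difference: an infinite inner stream combined with a short-circuiting terminal operation returns immediately, whereas the JDK 8 flatMap would never finish pushing the inner stream:

Optional<Integer> first = flatMap(Stream.of(1), x -> Stream.iterate(x, i -> i + 1))
    .filter(i -> i > 10)
    .findFirst();
System.out.println(first); // prints Optional[11]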

Note that I added a toggle to allow turning back to the original implementation, e.g. when specifying -Dstream.flatmap.usestandard=true on the command line.


You're a little bit wrong in saying that the flatMap stream is not lazy. It is somewhat lazy, though its laziness is really limited. Let's use a custom Collection to track the requested elements inside your Tree class:

private final Set<String> requested = new LinkedHashSet<>();

private class MyList extends AbstractList<String> implements RandomAccess {
    private final String[] data;

    public MyList(String... data) {
        this.data = data;
    }

    @Override
    public String get(int index) {
        requested.add(data[index]);
        return data[index];
    }

    @Override
    public int size() {
        return data.length;
    }
}

Now let's pre-initialize your class with some tree data:

public Tree() {
    edges = new HashMap<>();   // the field must be initialized somewhere; doing it here for completeness
    // "1" is the root node, contains three immediate descendants
    edges.put("1", new MyList("2", "3", "4"));
    edges.put("2", new MyList("5", "6", "7"));
    edges.put("3", new MyList("8", "9", "10"));
    edges.put("8", new MyList("11", "12"));
    edges.put("5", new MyList("13", "14", "15"));
    edges.put("7", new MyList("16", "17", "18"));
    edges.put("6", new MyList("19", "20"));
}

Finally let's check how many elements are actually requested from your list on different limit values:

public static void main(String[] args) {
    for(int i=1; i<=20; i++) {
        Tree tree = new Tree();
        tree.descendants("1").limit(i).toArray();
        System.out.println("Limit = " + i + "; requested = (" + tree.requested.size()
                + ") " + tree.requested);
    }
}

The output is the following:

Limit = 1; requested = (0) []
Limit = 2; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 3; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 4; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 5; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 6; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 7; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 8; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 9; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 10; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 11; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 12; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 13; requested = (12) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18]
Limit = 14; requested = (18) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18, 3, 8, 11, 12, 9, 10]
Limit = 15; requested = (18) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18, 3, 8, 11, 12, 9, 10]
Limit = 16; requested = (18) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18, 3, 8, 11, 12, 9, 10]
Limit = 17; requested = (18) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18, 3, 8, 11, 12, 9, 10]
Limit = 18; requested = (18) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18, 3, 8, 11, 12, 9, 10]
Limit = 19; requested = (18) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18, 3, 8, 11, 12, 9, 10]
Limit = 20; requested = (19) [2, 5, 13, 14, 15, 6, 19, 20, 7, 16, 17, 18, 3, 8, 11, 12, 9, 10, 4]

Thus, when only the root node is requested, no access to children is performed at all (Stream.concat is smart about this). When the first immediate child is requested, the whole subtree of that child is processed, even where that is unnecessary. Nevertheless, the second immediate child is not processed until the first one finishes. This could be problematic for short-circuiting scenarios, but in most cases your terminal operation is not short-circuiting, so it's still a fine approach.

As for your concerns about memory consumption: yes, it consumes memory proportional to the tree depth (and, more importantly, it consumes stack). If your tree has thousands of nesting levels, you will have a problem with this solution, as you may hit a StackOverflowError with the default -Xss setting. For several hundred levels of depth it works fine.
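To illustrate the depth concern, here is a rough sketch (my own example, assuming it lives inside the Tree class so it can reach the private edges map): a degenerate chain of 100,000 nodes makes the traversal overflow the default stack.

static void deepChainDemo() {
    Tree tree = new Tree();
    // build a degenerate chain n0 -> n1 -> ... -> n100000
    for (int i = 0; i < 100_000; i++)
        tree.edges.put("n" + i, Collections.singletonList("n" + (i + 1)));
    // walking the stream recurses once per level of the chain
    tree.descendants("n0").count();   // throws StackOverflowError with the default -Xss
}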

We are using a similar approach in the business-logic layer of our application; it works fine for us, though our trees are rarely deeper than 10 levels.


Not a real answer, but just a thought:

If you peek into the value collection and in the next step "resolve" the last seen value to a new value collection, returning the next values in the same way recursively, then however this is implemented, you will always end up with some kind of "pointer" to the current element in the value collection on the current "level" of depth in the tree, and also with some kind of stack holding all of those "pointers".

This is because you need both the information about the higher levels in the tree (the stack) and a "pointer" to the current element at the current level. In this case, one implies the other.

Of course, you can implement this as a Spliterator that holds a Stack of Iterators (each pointing into the corresponding value collection), but I suppose there will always be a "pointer" state at each depth level, even if it is hidden in Java's flatMap (or related) temporary objects.

As an alternative: how about using a "real" tree with nodes that hold a reference to their parent node? Plus a map at the root of the tree that holds a reference to every single node, to simplify access to a sub-sub-sub-child. I guess the Spliterator implementation would then be really easy, because it only needs a reference to the current node for traversing and a stop criterion (the initial node) so it does not walk too "high" up the tree.
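A rough sketch of what I mean (all names hypothetical; the original Tree uses plain String IDs rather than node objects): with parent references, the pre-order successor can be computed without any stack by descending to the first child, otherwise moving to the next sibling, otherwise climbing up until a sibling exists or the start node is reached again.

class Node {
    String id;
    Node parent;
    List<Node> children = new ArrayList<>();
}

static Node next(Node current, Node start) {
    if (!current.children.isEmpty())
        return current.children.get(0);          // descend to the first child
    while (current != start) {
        Node parent = current.parent;
        int i = parent.children.indexOf(current) + 1;
        if (i < parent.children.size())
            return parent.children.get(i);       // move to the next sibling
        current = parent;                        // climb up and retry
    }
    return null;                                 // stop criterion: back at the start node
}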