Efficient graph traversal with LINQ - eliminating recursion

Today I was going to implement a method to traverse an arbitrarily deep graph and flatten it into a single enumerable. Instead, I did a little searching first and found this:

public static IEnumerable<T> Traverse<T>(this IEnumerable<T> enumerable, Func<T, IEnumerable<T>> recursivePropertySelector)
{
    foreach (T item in enumerable)
    {
        yield return item;

        IEnumerable<T> seqRecurse = recursivePropertySelector(item);

        if (seqRecurse == null) continue;
        foreach (T itemRecurse in Traverse(seqRecurse, recursivePropertySelector))
        {
            yield return itemRecurse;
        }
    }
}

In theory this looks good, but in practice I've found it performs significantly worse than equivalent hand-written code that walks the graph and does whatever needs to be done as the situation arises. I suspect this is because, for every item this method returns, control has to pass back up through an arbitrarily deep chain of nested iterators.

I also suspect that this method would run a lot more efficiently if the recursion were eliminated. I also happen to not be very good at eliminating recursion.

Does anyone know how to rewrite this method to eliminate the recursion?

Thanks for any help.

EDIT: Thanks very much for all the detailed responses. I benchmarked the original solution, Eric's solution, and a version that skips the iterator method entirely and traverses recursively with a lambda. Oddly, the lambda recursion is significantly faster than either of the other two methods.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;

class Node
{
    public List<Node> ChildNodes { get; set; } 

    public Node()
    {
        ChildNodes = new List<Node>();
    }
}

class Foo
{
    public static void Main(String[] args) 
    {
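        // Build 100 roots x 100 children x 100 grandchildren x 12 leaves
        // (100 + 10,000 + 1,000,000 + 12,000,000 = 13,010,100 nodes).
        var nodes = new List<Node>();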
        for(int i = 0; i < 100; i++)
        {
            var nodeA = new Node();
            nodes.Add(nodeA);
            for (int j = 0; j < 100; j++)
            {
                var nodeB = new Node();
                nodeA.ChildNodes.Add(nodeB);
                for (int k = 0; k < 100; k++)
                {
                    var nodeC = new Node();
                    nodeB.ChildNodes.Add(nodeC);
                    for(int l = 0; l < 12; l++)
                    {
                        var nodeD = new Node();
                        nodeC.ChildNodes.Add(nodeD);
                    }
                }
            }
        }            

        // Warm-up runs so JIT compilation is not included in the timings below.
        nodes.TraverseOld(node => node.ChildNodes).ToList();
        nodes.TraverseNew(node => node.ChildNodes).ToList();

        var watch = Stopwatch.StartNew();
        nodes.TraverseOld(node => node.ChildNodes).ToList();
        watch.Stop();
        var recursiveTraversalTime = watch.ElapsedMilliseconds;
        watch.Restart();
        nodes.TraverseNew(node => node.ChildNodes).ToList();
        watch.Stop();
        var noRecursionTraversalTime = watch.ElapsedMilliseconds;

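        // Recursive lambda: walks the same tree directly, without an iterator
        // and without building a result list.
        Action<Node> visitNode = null;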
        visitNode = node =>
        {
            foreach (var child in node.ChildNodes)
                visitNode(child);
        };

        watch.Restart();
        foreach(var node in nodes)
            visitNode(node);
        watch.Stop();
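        var lambdaRecursionTime = watch.ElapsedMilliseconds;

        System.Console.WriteLine("TraverseOld: {0} ms, TraverseNew: {1} ms, lambda: {2} ms",
            recursiveTraversalTime, noRecursionTraversalTime, lambdaRecursionTime);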
    }
}

TraverseOld is the original method, TraverseNew is Eric's method, and visitNode is the recursive lambda.

On my machine, TraverseOld takes 10127 ms, TraverseNew takes 3038 ms, and the lambda recursion takes 1181 ms.

Is it typical for iterator methods (methods using yield return) to take 3x as long as immediate execution? Or is something else going on here?
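One thing worth checking before comparing the numbers: the two iterator runs call ToList, which copies all 13,010,100 nodes into a new list, while the lambda only walks the tree, so the three timings do not measure the same work. Below is a minimal sketch of a like-for-like comparison in which every approach just counts nodes. BenchNode, BuildForest, CountOld, CountNew, CountRecursive and TimeMs are hypothetical names; the two traversals are copies of the methods discussed in this thread, not library APIs.

```csharp
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;

// Hypothetical node type mirroring the question's Node class.
public class BenchNode
{
    public List<BenchNode> ChildNodes { get; } = new List<BenchNode>();
}

public static class TraversalBenchmark
{
    // Builds fanOut[0] roots; every node on level i gets fanOut[i + 1] children.
    // BuildForest(100, 100, 100, 12) reproduces the question's 13,010,100-node tree.
    public static List<BenchNode> BuildForest(params int[] fanOut)
    {
        var roots = new List<BenchNode>();
        for (int i = 0; i < fanOut[0]; i++)
            roots.Add(new BenchNode());
        var level = roots;
        for (int depth = 1; depth < fanOut.Length; depth++)
        {
            var next = new List<BenchNode>();
            foreach (var parent in level)
            {
                for (int i = 0; i < fanOut[depth]; i++)
                {
                    var child = new BenchNode();
                    parent.ChildNodes.Add(child);
                    next.Add(child);
                }
            }
            level = next;
        }
        return roots;
    }

    // The question's original nested-iterator traversal.
    static IEnumerable<T> TraverseOld<T>(IEnumerable<T> items, Func<T, IEnumerable<T>> children)
    {
        foreach (T item in items)
        {
            yield return item;
            IEnumerable<T> sub = children(item);
            if (sub == null) continue;
            foreach (T nested in TraverseOld(sub, children))
                yield return nested;
        }
    }

    // Eric's explicit-stack traversal, applied to each root in turn.
    static IEnumerable<T> TraverseNew<T>(IEnumerable<T> roots, Func<T, IEnumerable<T>> children)
    {
        foreach (T root in roots)
        {
            var stack = new Stack<T>();
            stack.Push(root);
            while (stack.Count != 0)
            {
                T item = stack.Pop();
                yield return item;
                foreach (T child in children(item))
                    stack.Push(child);
            }
        }
    }

    // All three counters do the same job: count nodes without building a list.
    public static long CountOld(List<BenchNode> roots) => TraverseOld(roots, n => n.ChildNodes).LongCount();
    public static long CountNew(List<BenchNode> roots) => TraverseNew(roots, n => n.ChildNodes).LongCount();

    public static long CountRecursive(IEnumerable<BenchNode> nodes)
    {
        long count = 0;
        foreach (var node in nodes)
            count += 1 + CountRecursive(node.ChildNodes);
        return count;
    }

    // Runs `work` once; returns elapsed milliseconds and passes back its result.
    public static long TimeMs(Func<long> work, out long result)
    {
        var watch = Stopwatch.StartNew();
        result = work();
        watch.Stop();
        return watch.ElapsedMilliseconds;
    }
}
```

After a warm-up call, timing each counter with TimeMs separates traversal cost from list allocation; whether the iterator versions then close the gap is something to measure rather than assume.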


Solution 1:

First off, you are absolutely correct; if the graph has n nodes of average depth d then the naive nested iterators yield a solution which is O(n*d) in time, and O(d) in stack. If d is a large fraction of n then this can become an O(n²) algorithm, and if d is large then you can blow the stack entirely.
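To make the O(n*d) cost concrete, the sketch below instruments the naive nested-iterator Traverse from the question with a hypothetical YieldCount counter and runs it on a chain (each node has exactly one child), so the depth equals n. An item at depth k is re-yielded by each of the k+1 iterators between it and the caller, so the total number of yield return executions is n(n+1)/2, which is quadratic in n.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class NestedIteratorCost
{
    // Hypothetical instrumentation: counts every `yield return` at any nesting level.
    public static long YieldCount;

    // Same shape as the question's recursive Traverse, plus the counter.
    public static IEnumerable<T> TraverseCounted<T>(IEnumerable<T> items, Func<T, IEnumerable<T>> children)
    {
        foreach (T item in items)
        {
            YieldCount++;
            yield return item;
            IEnumerable<T> sub = children(item);
            if (sub == null) continue;
            foreach (T nested in TraverseCounted(sub, children))
            {
                YieldCount++;   // passing a deeper item up one more level
                yield return nested;
            }
        }
    }

    // Builds the chain 0 -> 1 -> ... -> n-1, traverses it fully,
    // and returns the total number of yields performed.
    public static long CountForChain(int n)
    {
        YieldCount = 0;
        Func<int, IEnumerable<int>> next = i => i + 1 < n ? new[] { i + 1 } : null;
        TraverseCounted(new[] { 0 }, next).Count();
        return YieldCount;
    }
}
```

For n = 1000 that is 500,500 yields to deliver 1,000 items, whereas a single explicit-stack iterator yields each item exactly once.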

If you are interested in a performance analysis of nested iterators, see former C# compiler developer Wes Dyer's blog post:

http://blogs.msdn.microsoft.com/wesdyer/2007/03/23/all-about-iterators

dasblinkenlight's solution is a variation on the standard approach. I would typically write the program like this:

public static IEnumerable<T> Traverse<T>(
    T root, 
    Func<T, IEnumerable<T>> children)
{
    var stack = new Stack<T>();
    stack.Push(root);
    while(stack.Count != 0)
    {
        T item = stack.Pop();
        yield return item;
        // Children are pushed in order, so they are popped (and yielded) last-first.
        foreach(var child in children(item))
            stack.Push(child);
    }
}

And then if you have multiple roots:

public static IEnumerable<T> Traverse<T>(
    IEnumerable<T> roots, 
    Func<T, IEnumerable<T>> children)
{
    return from root in roots 
           from item in Traverse(root, children)
           select item;
}
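The Traverse above pushes children in their natural order, so each node's children come off the stack last-first. The result is still a depth-first, pre-order walk, but siblings are visited in the reverse order compared with the recursive original, and a null child collection throws a NullReferenceException where the original skipped it. If either difference matters, one option is to push the children in reverse and skip nulls. The multiple-roots overload can equally be written with SelectMany. This is a sketch; OrderedTraversal is a hypothetical name.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class OrderedTraversal
{
    // Explicit-stack pre-order traversal that visits children in the same
    // order as the recursive version: children are pushed in reverse, so the
    // first child ends up on top of the stack.
    public static IEnumerable<T> Traverse<T>(T root, Func<T, IEnumerable<T>> children)
    {
        var stack = new Stack<T>();
        stack.Push(root);
        while (stack.Count != 0)
        {
            T item = stack.Pop();
            yield return item;
            IEnumerable<T> kids = children(item);
            if (kids == null) continue;   // tolerate a null child collection, like the original
            foreach (T child in kids.Reverse())
                stack.Push(child);
        }
    }

    // Multiple roots, using method syntax instead of a query expression.
    public static IEnumerable<T> Traverse<T>(IEnumerable<T> roots, Func<T, IEnumerable<T>> children)
    {
        return roots.SelectMany(root => Traverse(root, children));
    }
}
```

Enumerable.Reverse buffers each child sequence before yielding it, so this trades a small per-node copy for the original ordering.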

Now, note that a traversal is not what you want if you have a highly interconnected graph or a cyclic graph! If you have a graph with downward pointing arrows:

          A
         / \
        B-->C
         \ /
          D

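then the traversal is A, B, D, C, D, C, D: a shared node is visited once per path that reaches it, and on a cyclic graph the traversal never terminates. If you have a cyclic or interconnected graph then what you want is the transitive closure.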

public static IEnumerable<T> Closure<T>(
    T root, 
    Func<T, IEnumerable<T>> children)
{
    var seen = new HashSet<T>();
    var stack = new Stack<T>();
    stack.Push(root);

    while(stack.Count != 0)
    {
        T item = stack.Pop();
        if (seen.Contains(item))
            continue;
        seen.Add(item);
        yield return item;
        foreach(var child in children(item))
            stack.Push(child);
    }
}

This variation only yields items that have not been yielded before.
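A small sketch that checks both methods against the diagram above. The edges are A->B, A->C, B->C, B->D and C->D; the order within each child list is an assumption, chosen so that Traverse reproduces the sequence given in the text (a different order changes the sequence but not how often each node appears). DiamondDemo and Children are hypothetical names; Closure uses the bool returned by HashSet.Add in place of the separate Contains check, which behaves the same.

```csharp
using System;
using System.Collections.Generic;

public static class DiamondDemo
{
    // The diagram's edges, with child order chosen so the stack pops B before C.
    static readonly Dictionary<string, string[]> Graph = new Dictionary<string, string[]>
    {
        ["A"] = new[] { "C", "B" },
        ["B"] = new[] { "C", "D" },
        ["C"] = new[] { "D" },
        ["D"] = new string[0],
    };

    public static IEnumerable<string> Children(string node) => Graph[node];

    // Traverse: no memory of visited nodes.
    public static IEnumerable<T> Traverse<T>(T root, Func<T, IEnumerable<T>> children)
    {
        var stack = new Stack<T>();
        stack.Push(root);
        while (stack.Count != 0)
        {
            T item = stack.Pop();
            yield return item;
            foreach (var child in children(item))
                stack.Push(child);
        }
    }

    // Closure: skips anything already yielded.
    public static IEnumerable<T> Closure<T>(T root, Func<T, IEnumerable<T>> children)
    {
        var seen = new HashSet<T>();
        var stack = new Stack<T>();
        stack.Push(root);
        while (stack.Count != 0)
        {
            T item = stack.Pop();
            if (!seen.Add(item))   // Add returns false if the item was already present
                continue;
            yield return item;
            foreach (var child in children(item))
                stack.Push(child);
        }
    }
}
```

Traverse yields "A,B,D,C,D,C,D" for this graph, while Closure yields "A,B,D,C".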

Regarding your comment, "I also happen to not be very good at eliminating recursion":

I have written a number of articles on ways to eliminate recursion, and about recursive programming in general. If this subject interests you, see:

http://blogs.msdn.com/b/ericlippert/archive/tags/recursion/

In particular:

http://blogs.msdn.com/b/ericlippert/archive/2005/08/01/recursion-part-two-unrolling-a-recursive-function-with-an-explicit-stack.aspx

http://blogs.msdn.com/b/ericlippert/archive/2005/08/04/recursion-part-three-building-a-dispatch-engine.aspx

http://blogs.msdn.com/b/ericlippert/archive/2005/08/08/recursion-part-four-continuation-passing-style.aspx

Solution 2:

You are right: walking trees and graphs recursively in code that uses yield return is a big source of inefficiency.

Generally, you rewrite recursive code with an explicit stack, much as compiled code uses the call stack to implement recursion.
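As an illustration of that general technique, here is a recursive depth computation in the style of the question's visitNode lambda, next to an equivalent loop that keeps the per-call state (the node and its depth) on an explicit Stack instead of in call frames. This is a sketch; TreeNode, MaxDepthRecursive and MaxDepthIterative are hypothetical names.

```csharp
using System;
using System.Collections.Generic;

public class TreeNode
{
    public List<TreeNode> ChildNodes { get; } = new List<TreeNode>();
}

public static class RecursionToStack
{
    // Recursive version: each call frame implicitly holds `node` and `depth`.
    public static int MaxDepthRecursive(TreeNode node, int depth = 1)
    {
        int max = depth;
        foreach (var child in node.ChildNodes)
            max = Math.Max(max, MaxDepthRecursive(child, depth + 1));
        return max;
    }

    // Explicit-stack version: the state each frame would have held is pushed
    // onto a Stack instead, so very deep trees cannot overflow the call stack.
    public static int MaxDepthIterative(TreeNode root)
    {
        int max = 0;
        var stack = new Stack<(TreeNode Node, int Depth)>();
        stack.Push((root, 1));
        while (stack.Count != 0)
        {
            var (node, depth) = stack.Pop();
            max = Math.Max(max, depth);
            foreach (var child in node.ChildNodes)
                stack.Push((child, depth + 1));
        }
        return max;
    }
}
```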

I did not get a chance to try it out, but this should work:

public static IEnumerable<T> Traverse<T>(this IEnumerable<T> enumerable, Func<T, IEnumerable<T>> recursivePropertySelector) {
    var stack = new Stack<IEnumerable<T>>();
    stack.Push(enumerable);
    while (stack.Count != 0) {
        enumerable = stack.Pop();
        foreach (T item in enumerable) {
            yield return item;
            var seqRecurse = recursivePropertySelector(item);
            if (seqRecurse != null) {
                stack.Push(seqRecurse);
            }
        }
    }
}
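On a tree, this version does visit every node exactly once, but in a different order from the recursive original: it yields every item of one sequence (a sibling list) before descending, and it resumes with the most recently pushed child sequence. A small sketch to check this, assuming a hypothetical tree 1 -> {2, 3}, 2 -> {4}, 3 -> {5}; StackOfSequences and TraverseRecursive are hypothetical names for copies of the two methods in this thread.

```csharp
using System;
using System.Collections.Generic;

public static class StackOfSequences
{
    // The traversal above: a stack of child *sequences* rather than single nodes.
    public static IEnumerable<T> Traverse<T>(this IEnumerable<T> enumerable, Func<T, IEnumerable<T>> recursivePropertySelector)
    {
        var stack = new Stack<IEnumerable<T>>();
        stack.Push(enumerable);
        while (stack.Count != 0)
        {
            enumerable = stack.Pop();
            foreach (T item in enumerable)
            {
                yield return item;
                var seqRecurse = recursivePropertySelector(item);
                if (seqRecurse != null)
                    stack.Push(seqRecurse);
            }
        }
    }

    // The question's original recursive traversal, for comparison.
    public static IEnumerable<T> TraverseRecursive<T>(this IEnumerable<T> enumerable, Func<T, IEnumerable<T>> recursivePropertySelector)
    {
        foreach (T item in enumerable)
        {
            yield return item;
            IEnumerable<T> seqRecurse = recursivePropertySelector(item);
            if (seqRecurse == null) continue;
            foreach (T itemRecurse in TraverseRecursive(seqRecurse, recursivePropertySelector))
                yield return itemRecurse;
        }
    }
}
```

For that tree the stack version yields 1, 2, 3, 5, 4 and the recursive version yields 1, 2, 4, 3, 5: the same nodes, in a different order.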