Tail-recursive function to find the depth of a tree in OCaml
You can trivially do this by turning the function into CPS (Continuation Passing Style). The idea is that instead of calling depth left and then computing things based on this result, you call depth left (fun dleft -> ...), where the second argument is "what to compute once the result (dleft) is available".
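(* Tree type inferred from the patterns below: leaves and nodes both carry a value *)
type 'a tree = Leaf of 'a | Node of 'a * 'a tree * 'a tree

let depth tree =
  let rec depth tree k = match tree with
    | Leaf _ -> k 0
    | Node(_,left,right) ->
      depth left (fun dleft ->
        depth right (fun dright ->
          k (1 + (max dleft dright))))
  in depth tree (fun d -> d)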
Voilà, it's tail-recursive. This is a well-known trick that can make any function tail-recursive.
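To see what this buys you, here is a small self-contained check on a very deep tree. It assumes the tree type is type 'a tree = Leaf of 'a | Node of 'a * 'a tree * 'a tree (inferred from the Leaf/Node patterns; the question's exact definition is not shown), and left_comb is a hypothetical helper that builds a degenerate tree whose whole depth lies along the left spine. Every call in the CPS depth is a tail call, so the pending work lives in heap-allocated closures rather than on the call stack.

```ocaml
(* Assumed tree type, inferred from the Leaf/Node patterns in the answer *)
type 'a tree = Leaf of 'a | Node of 'a * 'a tree * 'a tree

(* CPS depth, as in the answer: every call below is a tail call *)
let depth tree =
  let rec depth tree k =
    match tree with
    | Leaf _ -> k 0
    | Node (_, left, right) ->
        depth left (fun dleft ->
            depth right (fun dright -> k (1 + max dleft dright)))
  in
  depth tree (fun d -> d)

(* Hypothetical helper: a tree of depth n hanging entirely off the left spine.
   It uses an accumulator, so building the tree is tail-recursive too. *)
let left_comb n =
  let rec build i acc =
    if i = 0 then acc else build (i - 1) (Node (i, acc, Leaf 0))
  in
  build n (Leaf 0)

let () = Printf.printf "depth = %d\n" (depth (left_comb 1_000_000))
```

Running it prints depth = 1000000.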
The next well-known trick in the bag is to "defunctionalize" the CPS result. Representing continuations (the (fun dleft -> ...) parts) as functions is neat, but you may want to see what they look like as data. So we replace each of these closures with a concrete constructor of a datatype that captures the free variables used in it.
Here we have three continuation closures: (fun dleft -> depth right (fun dright -> k ...)), which only reuses the environment variables right and k; (fun dright -> ...), which reuses k and the now-available left result dleft; and (fun d -> d), the initial computation, which doesn't capture anything.
type ('a, 'b) cont =
| Kleft of 'a tree * ('a, 'b) cont (* right and k *)
| Kright of 'b * ('a, 'b) cont (* dleft and k *)
| Kid
The defunctionalized function looks like this:
let depth tree =
  let rec depth tree k = match tree with
    | Leaf _ -> eval k 0
    | Node(_,left,right) ->
      depth left (Kleft(right, k))
  and eval k d = match k with
    | Kleft(right, k) ->
      depth right (Kright(d, k))
    | Kright(dleft, k) ->
      eval k (1 + max d dleft)
    | Kid -> d
  in depth tree Kid
;;
Instead of building a function k and applying it to the leaves (k 0), I build a value of type ('a, int) cont, which later needs to be evaluated to compute a result. When eval is passed a Kleft, it does what the closure (fun dleft -> ...) was doing, that is, it recursively calls depth on the right subtree. eval and depth are mutually recursive.
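As a worked example, the comment in the block below traces how the defunctionalized version evaluates the smallest non-trivial tree, Node (1, Leaf 0, Leaf 0), step by step. The definitions are repeated so the block runs on its own; the tree type is again an assumption inferred from the patterns.

```ocaml
(* Assumed tree type, inferred from the Leaf/Node patterns in the answer *)
type 'a tree = Leaf of 'a | Node of 'a * 'a tree * 'a tree

(* Defunctionalized continuations, as in the answer *)
type ('a, 'b) cont =
  | Kleft of 'a tree * ('a, 'b) cont  (* right subtree still to visit, and k *)
  | Kright of 'b * ('a, 'b) cont      (* depth of the left subtree, and k *)
  | Kid                               (* initial continuation *)

let depth tree =
  let rec depth tree k =
    match tree with
    | Leaf _ -> eval k 0
    | Node (_, left, right) -> depth left (Kleft (right, k))
  and eval k d =
    match k with
    | Kleft (right, k) -> depth right (Kright (d, k))
    | Kright (dleft, k) -> eval k (1 + max d dleft)
    | Kid -> d
  in
  depth tree Kid

(* Trace of depth (Node (1, Leaf 0, Leaf 0)):
     depth (Node (1, Leaf 0, Leaf 0)) Kid
   = depth (Leaf 0) (Kleft (Leaf 0, Kid))   remember to visit the right subtree
   = eval (Kleft (Leaf 0, Kid)) 0           the left subtree has depth 0
   = depth (Leaf 0) (Kright (0, Kid))       keep dleft = 0, visit the right
   = eval (Kright (0, Kid)) 0               the right subtree has depth 0
   = eval Kid (1 + max 0 0)                 combine both results
   = 1 *)
let () = assert (depth (Node (1, Leaf 0, Leaf 0)) = 1)
```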
Now look hard at ('a, 'b) cont: what is this datatype? It's a list!
type ('a, 'b) next_item =
| Kleft of 'a tree
| Kright of 'b
type ('a, 'b) cont = ('a, 'b) next_item list
let depth tree =
  let rec depth tree k = match tree with
    | Leaf _ -> eval k 0
    | Node(_,left,right) ->
      depth left (Kleft(right) :: k)
  and eval k d = match k with
    | Kleft(right) :: k ->
      depth right (Kright(d) :: k)
    | Kright(dleft) :: k ->
      eval k (1 + max d dleft)
    | [] -> d
  in depth tree []
;;
And a list is a stack. What we have here is actually a reification (transformation into data) of the call stack of the previous recursive function, with two different cases corresponding to the two different kinds of non-tail-recursive calls.
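Taking that observation literally, the list can be swapped for the standard library's mutable Stack, and the two mutually recursive functions can be merged into a single loop. This is a sketch rather than the author's code: the mode type and its Descend/Return constructors are hypothetical names, the tree type is assumed as before, and Stack.pop_opt requires OCaml 4.08 or later.

```ocaml
(* Assumed tree type, inferred from the Leaf/Node patterns in the answer *)
type 'a tree = Leaf of 'a | Node of 'a * 'a tree * 'a tree

(* The two kinds of pending work, as in the list version *)
type ('a, 'b) next_item = Kleft of 'a tree | Kright of 'b

(* Hypothetical machine state: about to descend into a subtree (the old depth
   function) or holding a finished sub-result (the old eval function) *)
type 'a mode = Descend of 'a tree | Return of int

let depth tree =
  (* Explicit stack of pending work instead of an immutable list *)
  let stack = Stack.create () in
  let rec loop mode =
    match mode with
    | Descend (Leaf _) -> loop (Return 0)
    | Descend (Node (_, left, right)) ->
        Stack.push (Kleft right) stack;  (* what a non-tail call on left would save *)
        loop (Descend left)
    | Return d -> (
        match Stack.pop_opt stack with
        | None -> d                      (* nothing pending: final result *)
        | Some (Kleft right) ->
            Stack.push (Kright d) stack; (* keep dleft, now visit the right *)
            loop (Descend right)
        | Some (Kright dleft) -> loop (Return (1 + max d dleft)))
  in
  loop (Descend tree)
```

The only recursion left is loop calling itself in tail position, so all pending work lives in the Stack data structure rather than on the call stack.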
Note that the defunctionalization is only there for fun. In practice, the CPS version is short, easy to derive by hand and rather easy to read, and I would recommend using it. Closures must be allocated in memory, but so must the elements of ('a, 'b) cont, albeit those might be represented more compactly. I would stick to the CPS version unless there are very good reasons to do something more complicated.
In this case (depth computation), you can accumulate over pairs (subtree depth * subtree content) to obtain the following tail-recursive function:
let depth tree =
  let rec aux depth = function
    | [] -> depth
    | (d, Leaf _) :: t -> aux (max d depth) t
    | (d, Node (_,left,right)) :: t ->
      let accu = (d+1, left) :: (d+1, right) :: t in
      aux depth accu in
  aux 0 [(0, tree)]
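This version counts depth top-down (each pending subtree carries its distance from the root), while the CPS version computes it bottom-up, so it is worth checking that both follow the same convention (a lone leaf has depth 0). Below is a quick randomized cross-check under the same assumed tree type; random_tree is a hypothetical helper and the seed is arbitrary.

```ocaml
(* Assumed tree type, inferred from the Leaf/Node patterns *)
type 'a tree = Leaf of 'a | Node of 'a * 'a tree * 'a tree

(* Worklist version: each pending subtree is paired with its distance from the root *)
let depth_worklist tree =
  let rec aux depth = function
    | [] -> depth
    | (d, Leaf _) :: t -> aux (max d depth) t
    | (d, Node (_, left, right)) :: t ->
        aux depth ((d + 1, left) :: (d + 1, right) :: t)
  in
  aux 0 [ (0, tree) ]

(* CPS version from the previous answer, for comparison *)
let depth_cps tree =
  let rec go tree k =
    match tree with
    | Leaf _ -> k 0
    | Node (_, left, right) ->
        go left (fun dl -> go right (fun dr -> k (1 + max dl dr)))
  in
  go tree (fun d -> d)

(* Hypothetical helper: a random tree with at most [budget] levels of nodes *)
let rec random_tree budget =
  if budget = 0 || Random.int 4 = 0 then Leaf 0
  else Node (0, random_tree (budget - 1), random_tree (budget - 1))

let () =
  Random.init 42;
  for _i = 1 to 1000 do
    let t = random_tree 12 in
    assert (depth_worklist t = depth_cps t)
  done;
  print_endline "both versions agree"
```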
For more general cases, you will indeed need to use the CPS transformation described by Gabriel.