Memoization in Haskell?
Solution 1:
We can do this very efficiently by making a structure that we can index in sub-linear time.
But first,
{-# LANGUAGE BangPatterns #-}
import Data.Function (fix)
Let's define f, but make it use 'open recursion' rather than call itself directly.
f :: (Int -> Int) -> Int -> Int
f mf 0 = 0
f mf n = max n $ mf (n `div` 2) +
                 mf (n `div` 3) +
                 mf (n `div` 4)
You can get an unmemoized f by using fix f. This lets you test that f does what you mean for small values of n by calling, for example: fix f 123 = 144
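As a minimal sketch of that check, the open-recursive definition can be tied back into an ordinary function with fix and evaluated on a few small inputs. The names plainF and smallValues are illustrative and not part of the original answer.

```haskell
import Data.Function (fix)

-- Open-recursive form: the recursive call goes through the 'mf' argument.
f :: (Int -> Int) -> Int -> Int
f _  0 = 0
f mf n = max n (mf (n `div` 2) + mf (n `div` 3) + mf (n `div` 4))

-- Tying the knot with 'fix' gives the plain, unmemoized function.
-- ('plainF' is a hypothetical name used only for this sketch.)
plainF :: Int -> Int
plainF = fix f

-- A few small inputs to sanity-check the definition.
smallValues :: [Int]
smallValues = map plainF [0, 1, 2, 12, 123]
```

Evaluating smallValues should give [0,1,2,13,144], which matches the fix f 123 = 144 example above.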
We could memoize this by defining:
f_list :: [Int]
f_list = map (f faster_f) [0..]
faster_f :: Int -> Int
faster_f n = f_list !! n
That performs passably well, and replaces what was going to take O(n^3) time with something that memoizes the intermediate results.
But it still takes linear time just to index into the list to find the memoized answer for mf. This means that results like:
*Main Data.List> faster_f 123801
248604
are tolerable, but the result doesn't scale much better than that. We can do better!
First, let's define an infinite tree:
data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)
And then we'll define a way to index into it, so we can find a node with index n in O(log n) time instead:
index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q
... and we may find a tree full of natural numbers to be convenient so we don't have to fiddle around with those indices:
nats :: Tree Int
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2
Since we can index, you can just convert a tree into a list:
toList :: Tree a -> [a]
toList as = map (index as) [0..]
You can check the work so far by verifying that toList nats gives you [0..]
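The check can be sketched as a self-contained module. It repeats the definitions above and adds two illustrative helpers that are not in the original answer: path, which records the left/right branches index follows, and checkNats, which compares a finite prefix of toList nats with [0..].

```haskell
{-# LANGUAGE BangPatterns #-}

data Tree a = Tree (Tree a) a (Tree a)

-- Index 0 is the root; odd indices live in the left subtree,
-- even (non-zero) indices in the right subtree.
index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
  (q, 0) -> index l q
  (q, _) -> index r q

nats :: Tree Int
nats = go 0 1
  where
    go !n !s = Tree (go l s') n (go r s')
      where
        l  = n + s
        r  = l + s
        s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0 ..]

-- Hypothetical helper: the branches 'index' takes to reach position n.
path :: Int -> String
path 0 = ""
path n = case (n - 1) `divMod` 2 of
  (q, 0) -> 'L' : path q
  (q, _) -> 'R' : path q

-- Hypothetical check: only a finite prefix can be compared.
checkNats :: Bool
checkNats = take 1000 (toList nats) == [0 .. 999]
```

For instance, path 5 is "LR": index 5 is found by going left from the root and then right, and the length of the path grows logarithmically with the index.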
Now,
f_tree :: Tree Int
f_tree = fmap (f fastest_f) nats
fastest_f :: Int -> Int
fastest_f = index f_tree
works just like the list version above, but instead of taking linear time to find each node, it can chase it down in logarithmic time.
The result is considerably faster:
*Main> fastest_f 12380192300
67652175206
*Main> fastest_f 12793129379123
120695231674999
In fact it is so much faster that you can go through and replace Int with Integer above and get ridiculously large answers almost instantaneously:
*Main> fastest_f' 1230891823091823018203123
93721573993600178112200489
*Main> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358
For an out-of-the-box library that implements the tree-based memoization, use MemoTrie:
$ stack repl --package MemoTrie
Prelude> import Data.MemoTrie
Prelude Data.MemoTrie> :set -XLambdaCase
Prelude Data.MemoTrie> :{
Prelude Data.MemoTrie| fastest_f' :: Integer -> Integer
Prelude Data.MemoTrie| fastest_f' = memo $ \case
Prelude Data.MemoTrie| 0 -> 0
Prelude Data.MemoTrie| n -> max n (fastest_f'(n `div` 2) + fastest_f'(n `div` 3) + fastest_f'(n `div` 4))
Prelude Data.MemoTrie| :}
Prelude Data.MemoTrie> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358
Solution 2:
Edward's answer is such a wonderful gem that I've duplicated it and provided implementations of memoList and memoTree combinators that memoize a function in open-recursive form.
{-# LANGUAGE BangPatterns #-}
import Data.Function (fix)
f :: (Integer -> Integer) -> Integer -> Integer
f mf 0 = 0
f mf n = max n $ mf (div n 2) +
                 mf (div n 3) +
                 mf (div n 4)
-- Memoizing using a list
-- The memoizing functionality depends on this being in eta reduced form!
memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoList f = memoList_f
  where memoList_f = (memo !!) . fromInteger
        memo = map (f memoList_f) [0..]
faster_f :: Integer -> Integer
faster_f = memoList f
-- Memoizing using a tree
data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)
index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q
nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2
toList :: Tree a -> [a]
toList as = map (index as) [0..]
-- The memoizing functionality depends on this being in eta reduced form!
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree f = memoTree_f
  where memoTree_f = index memo
        memo = fmap (f memoTree_f) nats
fastest_f :: Integer -> Integer
fastest_f = memoTree f
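Because memoTree takes any function written in open-recursive form, it can be reused for other recurrences. The sketch below repeats the tree definitions so that it is self-contained and applies memoTree to a Fibonacci function; fibOpen and fastFib are hypothetical examples and not part of the answer.

```haskell
{-# LANGUAGE BangPatterns #-}

data Tree a = Tree (Tree a) a (Tree a)

instance Functor Tree where
  fmap g (Tree l m r) = Tree (fmap g l) (g m) (fmap g r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
  (q, 0) -> index l q
  (q, _) -> index r q

nats :: Tree Integer
nats = go 0 1
  where
    go !n !s = Tree (go l s') n (go r s')
      where
        l  = n + s
        r  = l + s
        s' = s * 2

-- Same combinator as above: the memo tree is shared by every call
-- made through 'memoized', including the recursive ones.
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree g = memoized
  where
    memoized = index memo
    memo     = fmap (g memoized) nats

-- Hypothetical example: Fibonacci in open-recursive form.
fibOpen :: (Integer -> Integer) -> Integer -> Integer
fibOpen _    0 = 0
fibOpen _    1 = 1
fibOpen self n = self (n - 1) + self (n - 2)

-- Written without an explicit argument so the memo tree is built once.
fastFib :: Integer -> Integer
fastFib = memoTree fibOpen
```

With this definition, fastFib 50 evaluates to 12586269025 without the exponential blow-up of the naive recursion.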
Solution 3:
Not the most efficient way, but does memoize:
f = 0 : [ g n | n <- [1..] ]
  where g n = max n $ f!!(n `div` 2) + f!!(n `div` 3) + f!!(n `div` 4)
When requesting f !! 144, it is checked that f !! 143 exists, but its exact value is not calculated. It's still set as some unknown result of a calculation. The only exact values calculated are the ones needed.
So initially, as far as how much has been calculated, the program knows nothing.
f = ....
When we make the request f !! 12, it starts doing some pattern matching:
f = 0 : g 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...
Now it starts calculating
f !! 12 = g 12 = max 12 $ f!!6 + f!!4 + f!!3
This recursively makes another demand on f, so we calculate
f !! 6 = g 6 = max 6 $ f !! 3 + f !! 2 + f !! 1
f !! 3 = g 3 = max 3 $ f !! 1 + f !! 1 + f !! 0
f !! 1 = g 1 = max 1 $ f !! 0 + f !! 0 + f !! 0
f !! 0 = 0
Now we can trickle back up some
f !! 1 = g 1 = max 1 $ 0 + 0 + 0 = 1
Which means the program now knows:
f = 0 : 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...
Continuing to trickle up:
f !! 3 = g 3 = max 3 $ 1 + 1 + 0 = 3
Which means the program now knows:
f = 0 : 1 : g 2 : 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...
Now we continue with our calculation of f!!6:
f !! 6 = g 6 = max 6 $ 3 + f !! 2 + 1
f !! 2 = g 2 = max 2 $ f !! 1 + f !! 0 + f !! 0 = max 2 $ 1 + 0 + 0 = 2
f !! 6 = g 6 = max 6 $ 3 + 2 + 1 = 6
Which means the program now knows:
f = 0 : 1 : 2 : 3 : g 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...
Now we continue with our calculation of f!!12:
f !! 12 = g 12 = max 12 $ 6 + f!!4 + 3
f !! 4 = g 4 = max 4 $ f !! 2 + f !! 1 + f !! 1 = max 4 $ 2 + 1 + 1 = 4
f !! 12 = g 12 = max 12 $ 6 + 4 + 3 = 13
Which means the program now knows:
f = 0 : 1 : 2 : 3 : 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : 13 : ...
So the calculation is done fairly lazily. The program knows that some value for f !! 8 exists, that it's equal to g 8, but it has no idea what g 8 is.
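One way to observe this behaviour, sketched here as an assumption rather than as part of the answer, is to wrap each list cell in Debug.Trace.trace so that a message is written to stderr the first time the cell is evaluated. The name tracedF and the message text are illustrative.

```haskell
import Debug.Trace (trace)

-- Same list as above, but each cell reports when it is first evaluated.
-- 'tracedF' is a hypothetical name used only for this sketch.
tracedF :: [Int]
tracedF = 0 : [ g n | n <- [1 ..] ]
  where
    g n = trace ("evaluating cell " ++ show n) $
            max n ( tracedF !! (n `div` 2)
                  + tracedF !! (n `div` 3)
                  + tracedF !! (n `div` 4) )
```

Evaluating tracedF !! 12 should report only cells 1, 2, 3, 4, 6 and 12, each once and in the order they are demanded; cells such as 5 or 8 stay unevaluated, as in the walkthrough above.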
Solution 4:
This is an addendum to Edward Kmett's excellent answer.
When I tried his code, the definitions of nats and index seemed pretty mysterious, so I wrote an alternative version that I found easier to understand.
I define index and nats in terms of index' and nats'.
index' t n is defined over the range [1..]. (Recall that index t is defined over the range [0..].) It searches the tree by treating n as a string of bits, and reading through the bits in reverse. If the bit is 1, it takes the right-hand branch. If the bit is 0, it takes the left-hand branch. It stops when it reaches the last bit (which must be a 1).
index' (Tree l m r) 1 = m
index' (Tree l m r) n = case n `divMod` 2 of
    (n', 0) -> index' l n'
    (n', 1) -> index' r n'
Just as nats is defined for index so that index nats n == n is always true, nats' is defined for index'.
nats' = Tree l 1 r
  where
    l = fmap (\n -> n*2) nats'
    r = fmap (\n -> n*2 + 1) nats'
Now, nats and index are simply nats' and index' but with the values shifted by 1:
index t n = index' t (n+1)
nats = fmap (\n -> n-1) nats'
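The claimed invariants can be checked on a finite range. The following sketch gathers the definitions from this answer into one module, fixes the element type to Int for simplicity, and adds roundTrips, a hypothetical helper that is not part of the original.

```haskell
data Tree a = Tree (Tree a) a (Tree a)

instance Functor Tree where
  fmap g (Tree l m r) = Tree (fmap g l) (g m) (fmap g r)

-- Reads the bits of n from least significant to most significant:
-- 0 goes left, 1 goes right, and the final 1 bit selects the node.
index' :: Tree a -> Int -> a
index' (Tree _ m _) 1 = m
index' (Tree l _ r) n = case n `divMod` 2 of
  (n', 0) -> index' l n'
  (n', _) -> index' r n'

nats' :: Tree Int
nats' = Tree l 1 r
  where
    l = fmap (\n -> n * 2)     nats'
    r = fmap (\n -> n * 2 + 1) nats'

-- The zero-based versions are the one-based ones shifted by 1.
index :: Tree a -> Int -> a
index t n = index' t (n + 1)

nats :: Tree Int
nats = fmap (\n -> n - 1) nats'

-- Hypothetical check of both invariants on a finite range.
roundTrips :: Bool
roundTrips = all (\n -> index' nats' n == n) [1 .. 1000]
          && all (\n -> index  nats  n == n) [0 .. 999]
```

For example, 6 is 110 in binary, so index' nats' 6 goes left for the low 0 bit, right for the next 1 bit, and stops at the leading 1.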
Solution 5:
As stated in Edward Kmett's answer, to speed things up, you need to cache costly computations and be able to access them quickly.
To keep the function non-monadic, the solution of building an infinite lazy tree, with an appropriate way to index it (as shown in previous posts), fulfills that goal. If you give up the non-monadic nature of the function, you can use the standard associative containers available in Haskell in combination with “state-like” monads (like State or ST).
The main drawback is that the function becomes monadic; in return, you no longer have to index the structure yourself and can just use standard implementations of associative containers.
To do so, you first need to rewrite your function to accept any kind of monad:
fm :: (Integral a, Monad m) => (a -> m a) -> a -> m a
fm _ 0 = return 0
fm recf n = do
  recs <- mapM recf $ div n <$> [2, 3, 4]
  return $ max n (sum recs)
For your tests, you can still define a function that does no memoization using Data.Function.fix, although it is a bit more verbose:
import Data.Function (fix)
import Data.Functor.Identity (runIdentity)
noMemoF :: (Integral n) => n -> n
noMemoF = runIdentity . fix fm
You can then use the State monad in combination with Data.Map to speed things up:
import Control.Monad.State (evalState, get, modify)
import qualified Data.Map.Strict as MS
withMemoStMap :: (Integral n) => n -> n
withMemoStMap n = evalState (fm recF n) MS.empty
  where
    recF i = do
      v <- MS.lookup i <$> get
      case v of
        Just v' -> return v'
        Nothing -> do
          v' <- fm recF i
          modify $ MS.insert i v'
          return v'
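To try the State-based version in isolation, the pieces can be assembled into one module. This is a sketch under stated assumptions: it needs the mtl and containers packages, it fixes the number type to Integer to keep the signatures short, and agreeOnSmallInputs is a hypothetical helper that compares the memoized function against the unmemoized one.

```haskell
import Control.Monad.State (State, evalState, gets, modify)
import Data.Function (fix)
import Data.Functor.Identity (runIdentity)
import qualified Data.Map.Strict as MS

-- Monadic, open-recursive version of the function.
fm :: (Integral a, Monad m) => (a -> m a) -> a -> m a
fm _    0 = return 0
fm recf n = do
  recs <- mapM recf (map (div n) [2, 3, 4])
  return (max n (sum recs))

-- Unmemoized reference implementation.
noMemoF :: Integer -> Integer
noMemoF = runIdentity . fix fm

-- Memoized version: the Map carried by State is the cache.
withMemoStMap :: Integer -> Integer
withMemoStMap n = evalState (fm recF n) MS.empty
  where
    recF :: Integer -> State (MS.Map Integer Integer) Integer
    recF i = do
      cached <- gets (MS.lookup i)
      case cached of
        Just v  -> return v
        Nothing -> do
          v <- fm recF i
          modify (MS.insert i v)
          return v

-- Hypothetical check: both versions agree on small inputs.
agreeOnSmallInputs :: Bool
agreeOnSmallInputs = all (\n -> withMemoStMap n == noMemoF n) [0 .. 200]
```

Note that the cache lives only for the duration of one call to withMemoStMap; nothing is shared between separate top-level calls.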
With minor changes, you can adapt the code to work with Data.HashMap instead:
import Data.Hashable (Hashable)
import qualified Data.HashMap.Strict as HMS
withMemoStHMap :: (Integral n, Hashable n) => n -> n
withMemoStHMap n = evalState (fm recF n) HMS.empty
  where
    recF i = do
      v <- HMS.lookup i <$> get
      case v of
        Just v' -> return v'
        Nothing -> do
          v' <- fm recF i
          modify $ HMS.insert i v'
          return v'
Instead of persistent data structures, you may also try mutable data structures (like the Data.HashTable) in combination with the ST monad:
import Control.Monad.ST (runST)
import qualified Data.HashTable.ST.Linear as MHM
withMemoMutMap :: (Integral n, Hashable n) => n -> n
withMemoMutMap n = runST $
  do ht <- MHM.new
     recF ht n
  where
    recF ht i = do
      k <- MHM.lookup ht i
      case k of
        Just k' -> return k'
        Nothing -> do
          k' <- fm (recF ht) i
          MHM.insert ht i k'
          return k'
Compared to the implementation without any memoization, any of these implementations lets you get results for huge inputs in microseconds instead of having to wait several seconds.
Using Criterion for benchmarking, I observed that the implementation with Data.HashMap actually performed slightly better (around 20%) than those with Data.Map and Data.HashTable, for which the timings were very similar.
I found the results of the benchmark a bit surprising. My initial feeling was that the HashTable would outperform the HashMap implementation because it is mutable. There might be some performance defect hidden in this last implementation.