Friday, November 7, 2008

Beautiful folding

> {-# LANGUAGE ExistentialQuantification #-}

> import Data.List (foldl')

If you're not a Haskeller, and were thus hoping to learn how to fold a shirt beautifully, I'm afraid you're out of luck. I don't know either.

Much has been said about writing a Haskell function to calculate the mean of a list of numbers. For example, see Don Stewart's "Write Haskell as fast as C". Basically, one wants to write "nice, declarative" code like this:

> naiveMean :: Fractional a => [a] -> a
> naiveMean xs = sum xs / fromIntegral (length xs)

but if xs is large, sum will bring the whole list into memory, and the garbage collector can't reclaim it, because we still need the list to calculate the length.

The solution is to calculate both the sum and the length in one pass, and it's usually written something like this:

> uglyMean :: Fractional a => [a] -> a
> uglyMean xs = divide $ foldl' f (P 0 0) xs
>   where
>     f :: Num a => Pair a Int -> a -> Pair a Int
>     f (P s l) x = P (s + x) (l + 1)
>     divide (P x y) = x / fromIntegral y

where P is a strict pair constructor. This works, but where is the elegance, abstraction and modularity that Haskell is supposed to be famous for? Don's solution is even uglier (sorry, Don): not only does he write the reductor (our f) explicitly, he writes the fold itself too.

What I hope to do here is to abstract this pattern away by making "combinable folds". I only handle foldl', although foldl1' could be handy.

To make folds combinable, we need to turn folds into data: a fold is a function (the reductor) with an initial value. To make folds more readily combinable, we add a post-processing function (here it is divide). Now that we have the post-processor, we don't need to look at the accumulator directly, so we make it existential. The type Fold b c is for folds over lists of type [b], with results of type c.

> data Fold b c = forall a. F (a -> b -> a) a (a -> c)
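Because the accumulator type is existential, folds with completely different accumulators all end up with the same type Fold b c, so they can, for example, sit in one list. The sketch below is self-contained (it repeats the post's Fold and cfoldl' so it compiles on its own); sumIntF, maxOrZeroF, countEvensF, intStats and runAll are hypothetical examples, not part of the post.

```haskell
{-# LANGUAGE ExistentialQuantification #-}
import Data.List (foldl')

-- The post's representation: the accumulator type 'a' is hidden,
-- so it does not show up in the type 'Fold b c'.
data Fold b c = forall a. F (a -> b -> a) a (a -> c)

-- Run a fold over a list (same as the post's cfoldl').
cfoldl' :: Fold b c -> [b] -> c
cfoldl' (F f x c) = c . foldl' f x

-- Accumulator type: Int.
sumIntF :: Fold Int Int
sumIntF = F (+) 0 id

-- Hypothetical fold. Accumulator type: Maybe Int (Nothing = no element yet);
-- the post-processor turns "no element" into 0.
maxOrZeroF :: Fold Int Int
maxOrZeroF = F step Nothing (maybe 0 id)
  where
    step Nothing y  = Just y
    step (Just m) y = Just $! max m y

-- Hypothetical fold. Accumulator type: Integer, converted to Int at the end.
countEvensF :: Fold Int Int
countEvensF = F step 0 fromInteger
  where
    step n y = if even y then n + 1 else n

-- Three different accumulator types, one element and result type:
-- they fit in a single list.
intStats :: [Fold Int Int]
intStats = [sumIntF, maxOrZeroF, countEvensF]

-- Run every fold in the list (one traversal per fold).
runAll :: [Int] -> [Int]
runAll xs = map (\fl -> cfoldl' fl xs) intStats
```

For instance, runAll [3, 1, 4, 1, 5, 9, 2, 6] gives [31, 9, 3]. Note that runAll still walks the list once per fold; the single-pass behaviour comes from combining folds with both, which the post defines below.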

We'll need a strict pair type, and I don't want to give my blog a dependency on the strict package, so I introduce my own:

> data Pair a b = P !a !b

Now that folds are data, we can start manipulating them. For example, we can combine two folds to get a pair of results (we make the result an ordinary tuple for convenience, but use strict pairs for the accumulator to get the right strictness). The (***) defined here is like the one in Control.Arrow, but takes a strict pair as input. The reductor (comb f g) is basically (first f) . (second g) for strict pairs, with both halves fed the same list element.

> both :: Fold b c -> Fold b c' -> Fold b (c, c')
> both (F f x c) (F g y c') = F (comb f g) (P x y) (c *** c')
>   where
>     comb f g (P a a') b = P (f a b) (g a' b)
>     (***) f g (P x y) = (f x, g y)

Our next combinator simply adds an extra post-processor.

> after :: Fold b c -> (c -> c') -> Fold b c'
> after (F f x c) d = F f x (d . c)

The next one, bothWith, is a combination of both and after.

> bothWith :: (c -> c' -> d) -> Fold b c -> Fold b c' -> Fold b d
> bothWith combiner f1 f2 = after (both f1 f2) (uncurry combiner)

Now that we have tools to build folds, we want to actually run them, so here is cfoldl', a foldl' for combinable folds:

> cfoldl' :: Fold b c -> [b] -> c
> cfoldl' (F f x c) = c . (foldl' f x)
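To make cfoldl' concrete, here is a self-contained sketch (it repeats the post's definitions, with both's helpers inlined) that runs the combined sum-and-length fold; the comments trace how the strict accumulator evolves. sumAndLength is a hypothetical name.

```haskell
{-# LANGUAGE ExistentialQuantification #-}
import Data.List (foldl')

data Fold b c = forall a. F (a -> b -> a) a (a -> c)
data Pair a b = P !a !b

-- The post's cfoldl': fold with the reductor, then post-process.
cfoldl' :: Fold b c -> [b] -> c
cfoldl' (F f x c) = c . foldl' f x

-- The post's both, with comb and (***) inlined.
both :: Fold b c -> Fold b c' -> Fold b (c, c')
both (F f x c) (F g y c') = F step (P x y) (\(P a a') -> (c a, c' a'))
  where
    step (P a a') e = P (f a e) (g a' e)

sumF :: Num a => Fold a a
sumF = F (+) 0 id

lengthF :: Fold a Int
lengthF = F (const . (+1)) 0 id

-- Hypothetical helper. For [1, 2, 3] the accumulator evolves as
--   P 0 0  ->  P 1 1  ->  P 3 2  ->  P 6 3
-- and the post-processor then turns P 6 3 into the tuple (6, 3).
-- foldl' forces each P to WHNF, and the strict fields force both
-- components, so no thunks pile up.
sumAndLength :: [Int] -> (Int, Int)
sumAndLength = cfoldl' (both sumF lengthF)
```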

Now let's see a few basic folds:

> sumF :: Num a => Fold a a
> sumF = F (+) 0 id

> productF :: Num a => Fold a a
> productF = F (*) 1 id

> lengthF :: Fold a Int
> lengthF = F (const . (+1)) 0 id

And, the moment we've all been waiting for, combining basic folds to get the mean of a list:

> meanF :: Fractional a => Fold a a
> meanF = bothWith (/) sumF (after lengthF fromIntegral)

> mean :: Fractional a => [a] -> a
> mean = cfoldl' meanF

Pretty simple, eh? Perhaps not quite as pretty as naiveMean, but best of all, it doesn't eat your memory and kill your swap like naiveMean does.

> main = do
>   let xs = [1..10000000]
>   print $ mean xs

Compiling with GHC 6.8.2 and -O2, this runs in about 1.2 seconds (on my three-year-old laptop) and uses less than a meg of memory. GHC generates the same code for mean and uglyMean. [Originally uglyMean was slightly faster, but this was because of type defaulting: the result of lengthF defaulted to Integer]

One thing remains. What do Haskellers do when there's a pretty way and a fast way (or at least a way that's more susceptible to optimisation) to do the same thing? We write rewrite rules. So we'd like to convert sum, length, etc. into combinable folds, and then combine them. Something like this:

> {-
> {-# RULES
> "sum/sumF" sum = cfoldl' sumF
> "product/productF" product = cfoldl' productF
> "length/lengthF" length = cfoldl' lengthF
> "multi-cfoldl'" forall c f g xs. c (cfoldl' f xs) (cfoldl' g xs)
> = cfoldl' (bothWith c f g) xs
> #-}
> -}

So why are these commented out? Unfortunately, GHC doesn't like the all-important "multi-cfoldl'" rule: it doesn't have a named function at its head (it has the variable c). GHC doesn't allow rules of this form, presumably for efficiency and simplicity in the compiler.
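Even though GHC won't accept the rule, the equation it relies on can still be checked on concrete inputs. This self-contained sketch repeats the post's definitions; fusionHolds is a hypothetical helper that compares the two-traversal left-hand side of "multi-cfoldl'" with the fused, single-traversal right-hand side.

```haskell
{-# LANGUAGE ExistentialQuantification #-}
import Data.List (foldl')

data Fold b c = forall a. F (a -> b -> a) a (a -> c)
data Pair a b = P !a !b

cfoldl' :: Fold b c -> [b] -> c
cfoldl' (F f x c) = c . foldl' f x

both :: Fold b c -> Fold b c' -> Fold b (c, c')
both (F f x c) (F g y c') = F step (P x y) (\(P a a') -> (c a, c' a'))
  where
    step (P a a') e = P (f a e) (g a' e)

after :: Fold b c -> (c -> c') -> Fold b c'
after (F f x c) d = F f x (d . c)

bothWith :: (c -> c' -> d) -> Fold b c -> Fold b c' -> Fold b d
bothWith k f1 f2 = after (both f1 f2) (uncurry k)

sumF :: Num a => Fold a a
sumF = F (+) 0 id

lengthF :: Fold a Int
lengthF = F (const . (+1)) 0 id

-- Hypothetical helper: the left side folds xs twice, the right side once.
-- The rule would rewrite the left side into the right side.
fusionHolds :: Eq d => (c -> c' -> d) -> Fold b c -> Fold b c' -> [b] -> Bool
fusionHolds k f g xs =
  k (cfoldl' f xs) (cfoldl' g xs) == cfoldl' (bothWith k f g) xs
```

A check like this only tests particular inputs; it is no substitute for the rewrite happening automatically.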

So unfortunately, we can't go back to writing pretty-but-naive code, but with these combinators at our disposal, we are at least saved from writing *ugly* code.


  1. Very interesting idea.

    The after combinator has (almost) the same type as fmap, so you could make Fold an instance of Functor:

    > instance Functor (Fold a) where
    >   fmap = flip after

    Similarly, your 'both' is like <*> from Applicative, so you could have an instance of that class as well:

    > instance Applicative (Fold a) where
    >   pure x = F (\_ _ -> ()) () (\_ -> x)
    >   f <*> g = uncurry ($) <$> both f g
    >   -- or inline both/fmap in the above

    Now you can write:

    > meanF :: Fractional a => Fold a a
    > meanF = (/) <$> sumF <*> (fromIntegral <$> lengthF)

    Which is very close to what you can write in point free style using the standard Applicative instance for (e ->):

    > mean :: Fractional a => [a] -> a
    > mean = (/) <$> sum <*> (fromIntegral <$> length)
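A self-contained sketch of these instances, assuming GHC 7.10 or later (where Functor, Applicative and <$> are in the Prelude and every Applicative instance needs a Functor instance). Here <*> is written out directly rather than going through both; the behaviour is the same single pass.

```haskell
{-# LANGUAGE ExistentialQuantification #-}
import Data.List (foldl')

data Fold b c = forall a. F (a -> b -> a) a (a -> c)
data Pair a b = P !a !b

cfoldl' :: Fold b c -> [b] -> c
cfoldl' (F f x c) = c . foldl' f x

-- fmap post-composes onto the final extraction step (the post's 'after').
instance Functor (Fold b) where
  fmap g (F f x c) = F f x (g . c)

-- pure ignores the input; <*> runs both folds side by side in one pass
-- (the post's 'both'), then applies the first result to the second.
instance Applicative (Fold b) where
  pure r = F (\_ _ -> ()) () (const r)
  F f x c <*> F g y d = F step (P x y) (\(P a a') -> c a (d a'))
    where
      step (P a a') e = P (f a e) (g a' e)

sumF :: Num a => Fold a a
sumF = F (+) 0 id

lengthF :: Fold a Int
lengthF = F (const . (+1)) 0 id

meanF :: Fractional a => Fold a a
meanF = (/) <$> sumF <*> (fromIntegral <$> lengthF)
```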

  2. Nice!

    The thing that tickles me about this example, though, is that there's an obvious optimization for this particular benchmark: since the list we're taking the mean of is actually a range, we can calculate its sum using the well-known formula for the sum of an arithmetic progression, and its length by a simple subtraction and division - no traversals are required at all!
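The closed forms mentioned in this comment can be sketched for a progression [lo, lo + step .. hi] with a positive step. progLength, progSum and progMean are hypothetical names; they take the bounds rather than a list, so they bypass the folds entirely.

```haskell
-- Closed forms for the arithmetic progression [lo, lo + step .. hi].
-- Hypothetical helpers; they assume step > 0.

-- Number of elements: a subtraction and a division.
progLength :: Integer -> Integer -> Integer -> Integer
progLength lo step hi
  | step <= 0 = error "progLength: step must be positive"
  | hi < lo   = 0
  | otherwise = (hi - lo) `div` step + 1

-- Sum: n * lo + step * (0 + 1 + ... + (n - 1)).
-- n * (n - 1) is always even, so the division is exact.
progSum :: Integer -> Integer -> Integer -> Integer
progSum lo step hi = n * lo + step * (n * (n - 1) `div` 2)
  where
    n = progLength lo step hi

-- Mean without traversing anything (NaN for an empty progression).
progMean :: Integer -> Integer -> Integer -> Double
progMean lo step hi =
  fromInteger (progSum lo step hi) / fromInteger (progLength lo step hi)
```

For the post's benchmark, progMean 1 1 10000000 gives 5000000.5 in constant time.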

    Now, I can see how to take advantage of this possibility automatically in object-oriented languages - make sum() and length() methods of your Enumerable class and override them in the Range class - but I can't see how to do this in functional languages without violating information hiding.

    Have you tried compiling with -funbox-strict-fields? Perhaps making Fold's fields strict too?

  4. Thank you all for your comments.

    Your D code is very much like my uglyMean, though one uses recursion and the other iteration. In Haskell, it's popular to abstract just about *everything*, including looping. The idea here is to abstract a loop that does multiple things with the same list.

    I did wonder if folds were an instance of an interesting typeclass, but never realised that they were Applicative. Thank you, I'll look into that.

    One option would be to make range types an instance of the Foldable typeclass, and use the more polymorphic Foldable.foldl' instead of List.foldl'.
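A minimal sketch of the Foldable generalisation, using Data.Foldable's foldl' in place of Data.List's. Tree and exampleTree are hypothetical; note that this only lets the same folds run over any Foldable container, and does not by itself avoid traversing a range.

```haskell
{-# LANGUAGE ExistentialQuantification, DeriveFoldable #-}
import qualified Data.Foldable as Fd

data Fold b c = forall a. F (a -> b -> a) a (a -> c)

-- Like the post's cfoldl', but for any Foldable container.
cfoldlF :: Foldable t => Fold b c -> t b -> c
cfoldlF (F f x c) = c . Fd.foldl' f x

sumF :: Num a => Fold a a
sumF = F (+) 0 id

lengthF :: Fold a Int
lengthF = F (const . (+1)) 0 id

-- Hypothetical container whose Foldable instance is derived.
data Tree a = Leaf | Node (Tree a) a (Tree a)
  deriving Foldable

exampleTree :: Tree Int
exampleTree = Node (Node Leaf 1 Leaf) 2 (Node Leaf 3 Leaf)
```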

    Surprisingly, -funbox-strict-fields didn't make enough difference to mention.

  5. Sorry to butt in on your own blog, Quiz, but I wanted to mention something to miles (and feepingcreature indirectly):

    The "range" isn't the tough part. Try doing it with the first million primes, for instance.

    The main point is that if you want the geometric mean instead of the arithmetic, all the functions you need to change are easy to get to, from a conceptual point-of-view.
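To illustrate this point, a sketch of the switch from arithmetic to geometric mean with the post's combinators: only the summing fold changes (to a hypothetical logSumF) and an exp is appended as post-processing. It assumes all inputs are positive; the definitions are repeated so the block compiles on its own.

```haskell
{-# LANGUAGE ExistentialQuantification #-}
import Data.List (foldl')

data Fold b c = forall a. F (a -> b -> a) a (a -> c)
data Pair a b = P !a !b

cfoldl' :: Fold b c -> [b] -> c
cfoldl' (F f x c) = c . foldl' f x

both :: Fold b c -> Fold b c' -> Fold b (c, c')
both (F f x c) (F g y c') = F step (P x y) (\(P a a') -> (c a, c' a'))
  where
    step (P a a') e = P (f a e) (g a' e)

after :: Fold b c -> (c -> c') -> Fold b c'
after (F f x c) d = F f x (d . c)

bothWith :: (c -> c' -> d) -> Fold b c -> Fold b c' -> Fold b d
bothWith k f1 f2 = after (both f1 f2) (uncurry k)

sumF :: Num a => Fold a a
sumF = F (+) 0 id

lengthF :: Fold a Int
lengthF = F (const . (+1)) 0 id

-- The post's arithmetic mean.
meanF :: Fractional a => Fold a a
meanF = bothWith (/) sumF (after lengthF fromIntegral)

-- Hypothetical: sum of logarithms, the only new basic fold needed.
logSumF :: Floating a => Fold a a
logSumF = F (\acc x -> acc + log x) 0 id

-- Geometric mean = exp (mean of the logs); inputs must be positive.
geoMeanF :: Floating a => Fold a a
geoMeanF = after (bothWith (/) logSumF (after lengthF fromIntegral)) exp
```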

  6. The reason this is not a problem in an imperative language, of course, is outside state - the counter can be trivially placed outside the loop.
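For comparison, the "counter outside the loop" shape this comment describes can also be written in Haskell, using local mutable references in the ST monad. This is a sketch, not something from the post; meanST is a hypothetical name.

```haskell
import Control.Monad (forM_)
import Control.Monad.ST (runST)
import Data.STRef (newSTRef, readSTRef, modifySTRef')

-- Hypothetical function: the running total and the counter live outside
-- the loop as mutable references, as in an imperative language.
-- runST keeps the mutation local, so meanST is still a pure function.
meanST :: [Double] -> Double
meanST xs = runST $ do
  total <- newSTRef 0
  count <- newSTRef (0 :: Int)
  forM_ xs $ \x -> do
    modifySTRef' total (+ x)  -- strict update, so no thunk chain builds up
    modifySTRef' count (+ 1)
  s <- readSTRef total
  n <- readSTRef count
  return (s / fromIntegral n)
```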

    This also, amusingly, makes it a good example to demonstrate that pure functions aren't the panacea of language design that FP aficionados make it out to be.

  7. BlackMeph: well, yeah, I thought that was kinda obvious. To find the mean of the first million primes you do indeed need to generate and traverse the list (unless there's some clever piece of number theory I don't know about...). But in this specific case there's a fantastic optimisation we can apply to give us constant-time performance, and it would be nice to be able to take advantage of it without sacrificing genericity.

  8. Nice.

    One tiny thing I found confusing is your `after` function:

    f . g reads like "f after g",

    but in your case the code

    after f g == f `after` g

    is actually "f before g" :-)