I am trying to speed up the following function:
{-# LANGUAGE BangPatterns #-}
import Data.Word
import Data.Bits
import Data.List (foldl1')
import System.Random
import qualified Data.List as L
data Tree a = AB (Tree a) (Tree a) | A (Tree a) | B (Tree a) | C !a
    deriving Show
merge :: Tree a -> Tree a -> Tree a
merge (C x) _               = C x
merge _ (C y)               = C y
merge (A ta) (A tb)         = A (merge ta tb)
merge (A ta) (B tb)         = AB ta tb
merge (A ta) (AB tb tc)     = AB (merge ta tb) tc
merge (B ta) (A tb)         = AB tb ta
merge (B ta) (B tb)         = B (merge ta tb)
merge (B ta) (AB tb tc)     = AB tb (merge ta tc)
merge (AB ta tb) (A tc)     = AB (merge ta tc) tb
merge (AB ta tb) (B tc)     = AB ta (merge tb tc)
merge (AB ta tb) (AB tc td) = AB (merge ta tc) (merge tb td)
To stress-test its performance, I implemented a sort using merge:
fold ab a b c list = go list where
    go (AB a' b') = ab (go a') (go b')
    go (A a')     = a (go a')
    go (B b')     = b (go b')
    go (C x)      = c x
mergeAll :: [Tree a] -> Tree a
mergeAll = foldl1' merge
foldrBits :: (Word32 -> t -> t) -> t -> Word32 -> t
foldrBits cons nil word = go 32 word nil where
    go 0 w !r = r
    go l w !r = go (l-1) (shiftR w 1) (cons (w.&.1) r)
word32ToTree :: Word32 -> Tree Word32
word32ToTree w = foldrBits cons (C w) w where
    cons 0 t = A t
    cons 1 t = B t
toList = fold (++) id id (\ a -> [a])
sort = toList . mergeAll . map word32ToTree
main = do
    is <- mapM (const randomIO :: a -> IO Word32) [0..500000]
    print $ sum $ sort is
The performance was quite good from the start, around 2.5x slower than Data.List's sort. Nothing I tried sped it up further, though: inlining several functions, adding bang annotations in many places, and putting UNPACK on C !a. Is there any way to speed this function up?
The easiest way to make a function faster is to make it do less work. One way to do that is to use a function tailored to a more specific type of input or output, or to a more specific problem.
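One way to apply this to the code in the question is to specialise the tree to Word32. In the polymorphic version, UNPACK on C !a has no effect, because GHC cannot unpack a field whose type is a type variable; with a concrete leaf type it can. The sketch below also replaces toList, which uses (++) and therefore re-copies left subtrees at every AB node, with an accumulator-passing traversal. The names WTree, mergeW, toTreeW, toListW and sortW are hypothetical, and whether this actually beats the original on your machine is something to measure rather than assume. Like the original sort, it keeps only one copy of each key, because merge keeps a single C leaf per path.

```haskell
{-# LANGUAGE BangPatterns #-}
import Data.Bits (testBit)
import Data.List (foldl1')
import Data.Word (Word32)

-- Hypothetical tree specialised to Word32 keys, with every field strict.
-- Because the leaf type is concrete, GHC can honour the UNPACK pragma on WC;
-- on the polymorphic `C !a` from the question it is ignored.
data WTree
  = WAB !WTree !WTree -- both a 0-branch and a 1-branch
  | WA !WTree         -- only a 0-branch
  | WB !WTree         -- only a 1-branch
  | WC {-# UNPACK #-} !Word32

-- The question's merge, clause for clause, on WTree.
mergeW :: WTree -> WTree -> WTree
mergeW t@(WC _) _              = t
mergeW _ t@(WC _)              = t
mergeW (WA ta) (WA tb)         = WA (mergeW ta tb)
mergeW (WA ta) (WB tb)         = WAB ta tb
mergeW (WA ta) (WAB tb tc)     = WAB (mergeW ta tb) tc
mergeW (WB ta) (WA tb)         = WAB tb ta
mergeW (WB ta) (WB tb)         = WB (mergeW ta tb)
mergeW (WB ta) (WAB tb tc)     = WAB tb (mergeW ta tc)
mergeW (WAB ta tb) (WA tc)     = WAB (mergeW ta tc) tb
mergeW (WAB ta tb) (WB tc)     = WAB ta (mergeW tb tc)
mergeW (WAB ta tb) (WAB tc td) = WAB (mergeW ta tc) (mergeW tb td)

-- Build the path for one key: bit 0 ends up next to the leaf and bit 31 at
-- the root, as in the question's word32ToTree.
toTreeW :: Word32 -> WTree
toTreeW w = go 0 (WC w)
  where
    go :: Int -> WTree -> WTree
    go 32 !t = t
    go i !t  = go (i + 1) (if testBit w i then WB t else WA t)

-- In-order traversal with an accumulator: each key is consed exactly once,
-- instead of (++) re-copying the left half at every AB node.
toListW :: WTree -> [Word32]
toListW t0 = go t0 []
  where
    go (WAB l r) acc = go l (go r acc)
    go (WA l) acc    = go l acc
    go (WB r) acc    = go r acc
    go (WC x) acc    = x : acc

-- Like the question's sort, duplicate keys collapse to one.
sortW :: [Word32] -> [Word32]
sortW [] = []
sortW xs = toListW (foldl1' mergeW (map toTreeW xs))

main :: IO ()
main = print (sortW [3, 1, 2, 3])
```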
You are definitely allocating too many thunks. Here is how to analyze the code:
merge (A ta) (A tb)         = A (merge ta tb)
Here you allocate constructor A with one argument, which is a thunk. Can you say when the merge ta tb thunk will be forced? Probably only at the very end, when the resulting tree is consumed. Try adding a bang to each argument of each Tree constructor to make the tree spine-strict:
data Tree a = AB !(Tree a) !(Tree a) | A !(Tree a) | B !(Tree a) | C !a
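The effect of these bangs can be observed directly: a strict field is evaluated whenever its constructor is evaluated, so a merge result can no longer be a constructor wrapped around an unevaluated merge. The small program below is a sketch with illustrative names (LTree, STree and throwsAtWHNF are not from the question). It forces a lazy-field and a strict-field constructor to weak head normal form with evaluate and reports whether an error hidden in the subtree fires.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}
import Control.Exception (ErrorCall, evaluate, try)

-- Cut-down stand-ins for the question's Tree: lazy field vs strict field.
data LTree = LA LTree | LC Int
data STree = SA !STree | SC !Int

-- Force a value to weak head normal form and report whether that raised
-- an ErrorCall (the exception type thrown by `error`).
throwsAtWHNF :: a -> IO Bool
throwsAtWHNF x = do
  r <- try (evaluate x)
  return $ case r of
    Left (_ :: ErrorCall) -> True
    Right _               -> False

main :: IO ()
main = do
  lazyThrows   <- throwsAtWHNF (LA (error "unevaluated subtree"))
  strictThrows <- throwsAtWHNF (SA (error "unevaluated subtree"))
  print (lazyThrows, strictThrows)
```

Running it prints (False,True): the lazy constructor happily holds an unevaluated subtree, while the strict one forces its subtree as soon as it is built.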
The next example:
go l w !r = go (l-1) (shiftR w 1) (cons (w.&.1) r)
Here you are allocating thunks for l-1, shiftR w 1 and cons (w.&.1) r. The first is forced on the next iteration, when l is compared with 0; the second is forced when the third thunk is forced in the next iteration (it uses w); and the third is forced on the next iteration because of the bang on r. So this particular clause is probably OK.
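If you would rather not rely on that reasoning (or on GHC's strictness analysis at -O), the loop can state its strictness explicitly. The sketch below bangs all three loop arguments and also gives the counter the type Int: in the question's version the literal 32 carries no annotation, so under Haskell's defaulting rules l becomes an Integer. foldrBits' and bitsMSBFirst are hypothetical names used for illustration; whether this changes the optimised code is worth checking, for example by comparing the Core GHC produces for both versions.

```haskell
{-# LANGUAGE BangPatterns #-}
import Data.Bits (shiftR, (.&.))
import Data.Word (Word32)

-- Hypothetical variant of foldrBits with every loop argument banged and the
-- counter pinned to Int (the question's unannotated 32 defaults to Integer).
foldrBits' :: (Word32 -> t -> t) -> t -> Word32 -> t
foldrBits' cons nil word = go (32 :: Int) word nil
  where
    go 0 _ !r   = r
    go !l !w !r = go (l - 1) (shiftR w 1) (cons (w .&. 1) r)

-- Bit 0 is consed first, so it ends up last: the list is MSB first.
bitsMSBFirst :: Word32 -> [Word32]
bitsMSBFirst = foldrBits' (:) []

main :: IO ()
main = print (bitsMSBFirst 5)
```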