I have an AST of elementary arithmetic expressions:
data Expr = Constant Int
          | Variable String
          | Add Expr Expr
          | Multiply Expr Expr
          deriving (Show)
I also have a really simple function which simplifies a given expression:
simplify :: Expr -> Expr
simplify (Add (Constant 0) e) = simplify e
simplify (Add e (Constant 0)) = simplify e
simplify (Add (Constant a) (Constant b)) = Constant (a + b)
simplify (Add e1 e2) = Add (simplify e1) (simplify e2)
simplify (Multiply (Constant 0) _) = Constant 0
simplify (Multiply _ (Constant 0)) = Constant 0
simplify (Multiply (Constant 1) e) = e
simplify (Multiply e (Constant 1)) = e
simplify (Multiply (Constant a) (Constant b)) = Constant (a * b)
simplify (Multiply e1 e2) = Multiply (simplify e1) (simplify e2)
simplify e = e
Unfortunately, this function is not very effective, because it simplifies expressions from the root to the leaves (from top to bottom). Consider this expression:
exampleExpr :: Expr
exampleExpr = Add
  (Multiply (Constant 1) (Variable "redrum"))
  (Multiply (Constant 0) (Constant 451))
It takes two calls (simplify (simplify exampleExpr)) to reduce this expression to Variable "redrum". With a bottom-up approach, it should take only one call.
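To see the two-pass behaviour concretely, here is a minimal runnable sketch that reuses the question's top-down simplify unchanged and prints the result of one call and of two calls. The deriving (Eq) clause and the main function are additions for this demonstration only. The first call strips the Multiply nodes but leaves an Add whose right child has only just become Constant 0; the Add pattern was checked before that happened, so a second call is needed.

```haskell
-- The question's AST; Eq is derived here only so results can be compared.
data Expr = Constant Int
          | Variable String
          | Add Expr Expr
          | Multiply Expr Expr
          deriving (Show, Eq)

-- The question's top-down simplifier, unchanged.
simplify :: Expr -> Expr
simplify (Add (Constant 0) e) = simplify e
simplify (Add e (Constant 0)) = simplify e
simplify (Add (Constant a) (Constant b)) = Constant (a + b)
simplify (Add e1 e2) = Add (simplify e1) (simplify e2)
simplify (Multiply (Constant 0) _) = Constant 0
simplify (Multiply _ (Constant 0)) = Constant 0
simplify (Multiply (Constant 1) e) = e
simplify (Multiply e (Constant 1)) = e
simplify (Multiply (Constant a) (Constant b)) = Constant (a * b)
simplify (Multiply e1 e2) = Multiply (simplify e1) (simplify e2)
simplify e = e

exampleExpr :: Expr
exampleExpr = Add
  (Multiply (Constant 1) (Variable "redrum"))
  (Multiply (Constant 0) (Constant 451))

main :: IO ()
main = do
  -- One pass: the children are simplified, but the parent Add was matched too early.
  print (simplify exampleExpr)              -- Add (Variable "redrum") (Constant 0)
  -- Second pass: now the Add e (Constant 0) pattern fires.
  print (simplify (simplify exampleExpr))   -- Variable "redrum"
```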
I'm not yet experienced enough to write this code effectively. So my question is: how can I rewrite this function to simplify a given expression from the leaves to the root (bottom to top)?
Firstly, you're missing a couple of recursive calls. In these lines:
simplify (Multiply (Constant 1) e) = e
simplify (Multiply e (Constant 1)) = e
You should replace the right-hand side with simplify e.
simplify (Multiply (Constant 1) e) = simplify e
simplify (Multiply e (Constant 1)) = simplify e
Now to rewrite the expression from the bottom up. The problem is that you're looking for simplification patterns on the left-hand side of your equations, i.e., before you simplify the children. You need to simplify the children first, and then look for the pattern.
simplify :: Expr -> Expr
simplify (Add x y) =
  case (simplify x, simplify y) of
    (Constant 0, e)          -> e
    (e, Constant 0)          -> e
    (Constant a, Constant b) -> Constant (a + b)
    (x1, y1)                 -> Add x1 y1
simplify (Multiply x y) =
  case (simplify x, simplify y) of
    (Constant 0, _)          -> Constant 0
    (_, Constant 0)          -> Constant 0
    (Constant 1, e)          -> e
    (e, Constant 1)          -> e
    (Constant a, Constant b) -> Constant (a * b)
    (x1, y1)                 -> Multiply x1 y1
simplify e = e
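A quick way to check that one call is now enough is to confirm that the result is already a fixed point, i.e. that simplify (simplify e) == simplify e. The sketch below is the bottom-up simplify from above plus a small harness; the deriving (Eq) clause, the nested example and main are assumptions added for this test, not part of the answer. Idempotence holds here because every case alternative returns either an already-simplified child, a Constant, or a node whose children are already simplified and did not match any rule.

```haskell
-- Eq is derived only so that results can be compared in main.
data Expr = Constant Int
          | Variable String
          | Add Expr Expr
          | Multiply Expr Expr
          deriving (Show, Eq)

-- Bottom-up: simplify the children first, then match on the results.
simplify :: Expr -> Expr
simplify (Add x y) =
  case (simplify x, simplify y) of
    (Constant 0, e)          -> e
    (e, Constant 0)          -> e
    (Constant a, Constant b) -> Constant (a + b)
    (x1, y1)                 -> Add x1 y1
simplify (Multiply x y) =
  case (simplify x, simplify y) of
    (Constant 0, _)          -> Constant 0
    (_, Constant 0)          -> Constant 0
    (Constant 1, e)          -> e
    (e, Constant 1)          -> e
    (Constant a, Constant b) -> Constant (a * b)
    (x1, y1)                 -> Multiply x1 y1
simplify e = e

exampleExpr :: Expr
exampleExpr = Add
  (Multiply (Constant 1) (Variable "redrum"))
  (Multiply (Constant 0) (Constant 451))

-- Hypothetical extra test input: (0 + 1) * (x + 2 * 3)
nested :: Expr
nested = Multiply (Add (Constant 0) (Constant 1))
                  (Add (Variable "x") (Multiply (Constant 2) (Constant 3)))

main :: IO ()
main = do
  print (simplify exampleExpr)  -- Variable "redrum"
  print (simplify nested)       -- Add (Variable "x") (Constant 6)
  -- A second pass changes nothing: one call reaches a fixed point.
  print (all (\e -> simplify (simplify e) == simplify e) [exampleExpr, nested])  -- True
```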
On the left-hand side of the equation, we find the children of the current node. On the right, we look for patterns in the simplified children. One way of improving this code is to separate the two responsibilities of finding-and-replacing children and of matching simplification patterns. Here's a general function to recursively replace every subtree of an Expr:
transform :: (Expr -> Expr) -> Expr -> Expr
transform f (Add x y) = f $ Add (transform f x) (transform f y)
transform f (Multiply x y) = f $ Multiply (transform f x) (transform f y)
transform f e = f e
transform takes a (non-recursive) transformation function which calculates a replacement for a single-node pattern, and recursively applies it to every node in the tree in a bottom-up manner. To write a transformation function, you just look for the interesting patterns and forget about recursively rewriting the children.
simplify = transform f
  where
    f (Add (Constant 0) e)                 = e
    f (Add e (Constant 0))                 = e
    f (Add (Constant a) (Constant b))      = Constant (a + b)
    f (Multiply (Constant 0) _)            = Constant 0
    f (Multiply _ (Constant 0))            = Constant 0
    f (Multiply (Constant 1) e)            = e
    f (Multiply e (Constant 1))            = e
    f (Multiply (Constant a) (Constant b)) = Constant (a * b)
    f e                                    = e
Since f's argument has already had its children rewritten by transform, we don't need to exhaustively match every possible pattern or explicitly recurse through the value. We look for the ones we care about, and nodes which don't need transforming fall through to the catch-all f e = e case.
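The payoff of separating transform from the rewrite rules is that the same traversal can drive other passes. As a sketch, the program below combines the answer's transform and simplify with a hypothetical substitute pass (not from the original answer) that replaces one variable with a constant; running simplify afterwards folds the result. The deriving (Eq) clause and main are also additions for this example.

```haskell
data Expr = Constant Int
          | Variable String
          | Add Expr Expr
          | Multiply Expr Expr
          deriving (Show, Eq)

-- Rewrite the children first, then apply f to the rebuilt node (bottom-up).
transform :: (Expr -> Expr) -> Expr -> Expr
transform f (Add x y)      = f $ Add (transform f x) (transform f y)
transform f (Multiply x y) = f $ Multiply (transform f x) (transform f y)
transform f e              = f e

-- The answer's rules; no explicit recursion is needed.
simplify :: Expr -> Expr
simplify = transform f
  where
    f (Add (Constant 0) e)                 = e
    f (Add e (Constant 0))                 = e
    f (Add (Constant a) (Constant b))      = Constant (a + b)
    f (Multiply (Constant 0) _)            = Constant 0
    f (Multiply _ (Constant 0))            = Constant 0
    f (Multiply (Constant 1) e)            = e
    f (Multiply e (Constant 1))            = e
    f (Multiply (Constant a) (Constant b)) = Constant (a * b)
    f e                                    = e

-- Hypothetical second pass: replace every occurrence of one variable with a constant.
substitute :: String -> Int -> Expr -> Expr
substitute name n = transform f
  where
    f (Variable v) | v == name = Constant n
    f e                        = e

exampleExpr :: Expr
exampleExpr = Add
  (Multiply (Constant 1) (Variable "redrum"))
  (Multiply (Constant 0) (Constant 451))

main :: IO ()
main = do
  print (simplify exampleExpr)                         -- Variable "redrum"
  print (simplify (substitute "redrum" 7 exampleExpr)) -- Constant 7
```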
Generic programming libraries like lens's Plated module take programming patterns like transform and make them universal. You (or the compiler) write a small amount of code characterising the shape of your datatype, and the library implements recursive higher-order functions like transform once and for all.
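As a sketch of what that looks like with lens (this assumes the lens package is installed), you write a Plated instance whose plate method traverses the immediate children of a node, and Control.Lens.Plated's transform then does the bottom-up recursion for you. The rewrite rules are the same as in the hand-written version; the instance, the deriving (Eq) clause and main are illustrative additions.

```haskell
import Control.Lens.Plated (Plated (..), transform)

data Expr = Constant Int
          | Variable String
          | Add Expr Expr
          | Multiply Expr Expr
          deriving (Show, Eq)

-- plate visits the immediate Expr children of a node; leaves have none.
instance Plated Expr where
  plate f (Add x y)      = Add <$> f x <*> f y
  plate f (Multiply x y) = Multiply <$> f x <*> f y
  plate _ e              = pure e

-- lens's transform rewrites children before applying f to the parent.
simplify :: Expr -> Expr
simplify = transform f
  where
    f (Add (Constant 0) e)                 = e
    f (Add e (Constant 0))                 = e
    f (Add (Constant a) (Constant b))      = Constant (a + b)
    f (Multiply (Constant 0) _)            = Constant 0
    f (Multiply _ (Constant 0))            = Constant 0
    f (Multiply (Constant 1) e)            = e
    f (Multiply e (Constant 1))            = e
    f (Multiply (Constant a) (Constant b)) = Constant (a * b)
    f e                                    = e

exampleExpr :: Expr
exampleExpr = Add
  (Multiply (Constant 1) (Variable "redrum"))
  (Multiply (Constant 0) (Constant 451))

main :: IO ()
main = print (simplify exampleExpr)  -- Variable "redrum"
```

Plated can also be given a default instance via a Data instance for Expr, which avoids writing plate by hand.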
Simplifying expression ASTs is a typical application of the recursion scheme called a catamorphism. Here is an example using Edward Kmett's recursion-schemes library:
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE TemplateHaskell #-}
module CataExprSimplify where
import Data.Functor.Foldable
import Data.Functor.Foldable.TH
data Expr = Constant Int
          | Variable String
          | Add Expr Expr
          | Multiply Expr Expr
          deriving (Show)
-- | Generate the base functor
makeBaseFunctor ''Expr
simplify :: Expr -> Expr
simplify = cata $ algSimplAdd . project . algSimplMult
-- | Simplify Addition
simplAdd :: Expr -> Expr
simplAdd = cata algSimplAdd
algSimplAdd :: ExprF Expr -> Expr
algSimplAdd (AddF (Constant 0) r) = r
algSimplAdd (AddF l (Constant 0)) = l
algSimplAdd (AddF (Constant l) (Constant r)) = Constant (l + r)
algSimplAdd x = embed x
-- | Simplify Multiplication
simplMult :: Expr -> Expr
simplMult = cata algSimplMult
algSimplMult :: ExprF Expr -> Expr
algSimplMult (MultiplyF (Constant 1) r) = r
algSimplMult (MultiplyF l (Constant 1)) = l
algSimplMult (MultiplyF (Constant 0) _) = Constant 0
algSimplMult (MultiplyF _ (Constant 0)) = Constant 0
algSimplMult (MultiplyF (Constant l) (Constant r)) = Constant (l * r)
algSimplMult x = embed x
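To show roughly what makeBaseFunctor and cata provide, here is a dependency-free sketch that writes the base functor ExprF, project, embed and a cata specialised to Expr by hand. This is an illustration under that simplification, not the library's actual implementation (recursion-schemes defines cata generically via the Recursive class). The algebras are the ones from the module above; the size algebra, deriving (Eq) and main are hypothetical additions showing that the same cata can drive an unrelated fold.

```haskell
{-# LANGUAGE DeriveFunctor #-}

data Expr = Constant Int
          | Variable String
          | Add Expr Expr
          | Multiply Expr Expr
          deriving (Show, Eq)

-- Base functor: Expr with each recursive position replaced by r.
data ExprF r = ConstantF Int
             | VariableF String
             | AddF r r
             | MultiplyF r r
             deriving (Show, Functor)

-- Unroll one layer of Expr into ExprF.
project :: Expr -> ExprF Expr
project (Constant n)   = ConstantF n
project (Variable v)   = VariableF v
project (Add x y)      = AddF x y
project (Multiply x y) = MultiplyF x y

-- Roll one layer of ExprF back into Expr.
embed :: ExprF Expr -> Expr
embed (ConstantF n)   = Constant n
embed (VariableF v)   = Variable v
embed (AddF x y)      = Add x y
embed (MultiplyF x y) = Multiply x y

-- Catamorphism: fold the children first, then apply the algebra to the layer.
cata :: (ExprF a -> a) -> Expr -> a
cata alg = alg . fmap (cata alg) . project

algSimplAdd :: ExprF Expr -> Expr
algSimplAdd (AddF (Constant 0) r)            = r
algSimplAdd (AddF l (Constant 0))            = l
algSimplAdd (AddF (Constant l) (Constant r)) = Constant (l + r)
algSimplAdd x                                = embed x

algSimplMult :: ExprF Expr -> Expr
algSimplMult (MultiplyF (Constant 1) r)            = r
algSimplMult (MultiplyF l (Constant 1))            = l
algSimplMult (MultiplyF (Constant 0) _)            = Constant 0
algSimplMult (MultiplyF _ (Constant 0))            = Constant 0
algSimplMult (MultiplyF (Constant l) (Constant r)) = Constant (l * r)
algSimplMult x                                     = embed x

simplify :: Expr -> Expr
simplify = cata (algSimplAdd . project . algSimplMult)

-- Hypothetical second algebra: count the nodes of an expression.
size :: Expr -> Int
size = cata alg
  where
    alg (AddF l r)      = 1 + l + r
    alg (MultiplyF l r) = 1 + l + r
    alg _               = 1

exampleExpr :: Expr
exampleExpr = Add
  (Multiply (Constant 1) (Variable "redrum"))
  (Multiply (Constant 0) (Constant 451))

main :: IO ()
main = do
  print (simplify exampleExpr)  -- Variable "redrum"
  print (size exampleExpr)      -- 7
```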
It has the following advantages over code that uses direct recursion calls:
If you want to read more about recursion schemes, read this blog post series.