 

Haskell: recursion from leaves to root

Tags:

haskell

I have an AST of elementary arithmetic expressions:

data Expr = Constant Int
          | Variable String
          | Add Expr Expr
          | Multiply Expr Expr
          deriving (Show)

I also have a really simple function that simplifies a given expression:

simplify :: Expr -> Expr
simplify (Add (Constant 0) e) = simplify e
simplify (Add e (Constant 0)) = simplify e
simplify (Add (Constant a) (Constant b)) = Constant (a + b)
simplify (Add e1 e2) = Add (simplify e1) (simplify e2)
simplify (Multiply (Constant 0) _) = Constant 0
simplify (Multiply _ (Constant 0)) = Constant 0
simplify (Multiply (Constant 1) e) = e
simplify (Multiply e (Constant 1)) = e
simplify (Multiply (Constant a) (Constant b)) = Constant (a * b)
simplify (Multiply e1 e2) = Multiply (simplify e1) (simplify e2)
simplify e = e

Unfortunately, this function is not very effective, because it simplifies the expression from the root to the leaves (top to bottom). Consider this expression:

exampleExpr :: Expr
exampleExpr = Add
                (Multiply (Constant 1) (Variable "redrum"))
                (Multiply (Constant 0) (Constant 451))

It takes two calls (simplify (simplify exampleExpr)) to reduce this expression to Variable "redrum". With a bottom-up approach, a single call should suffice.
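The following minimal sketch makes the two-pass behaviour concrete. It copies the `Expr` type and the top-down `simplify` above unchanged. The only change is that it also derives `Eq`, an addition that lets results be compared.

```haskell
-- The question's AST; deriving Eq is an addition so results can be compared.
data Expr = Constant Int
          | Variable String
          | Add Expr Expr
          | Multiply Expr Expr
          deriving (Show, Eq)

-- The question's top-down simplifier, copied unchanged.
-- Patterns are matched on the node *before* its children are simplified.
simplify :: Expr -> Expr
simplify (Add (Constant 0) e) = simplify e
simplify (Add e (Constant 0)) = simplify e
simplify (Add (Constant a) (Constant b)) = Constant (a + b)
simplify (Add e1 e2) = Add (simplify e1) (simplify e2)
simplify (Multiply (Constant 0) _) = Constant 0
simplify (Multiply _ (Constant 0)) = Constant 0
simplify (Multiply (Constant 1) e) = e
simplify (Multiply e (Constant 1)) = e
simplify (Multiply (Constant a) (Constant b)) = Constant (a * b)
simplify (Multiply e1 e2) = Multiply (simplify e1) (simplify e2)
simplify e = e

exampleExpr :: Expr
exampleExpr = Add
                (Multiply (Constant 1) (Variable "redrum"))
                (Multiply (Constant 0) (Constant 451))
```

Here `simplify exampleExpr` evaluates to `Add (Variable "redrum") (Constant 0)`. The outer `Add` was inspected before its right child had become `Constant 0`, so only the second call, `simplify (simplify exampleExpr)`, reaches `Variable "redrum"`.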

I'm not experienced enough yet to write this code effectively. So my question is: how can I rewrite this function to simplify a given expression from the leaves to the root (bottom to top)?

asked Mar 03 '26 12:03 by user1518183


2 Answers

Firstly, you're missing a couple of recursive calls. In these lines:

simplify (Multiply (Constant 1) e) = e
simplify (Multiply e (Constant 1)) = e

You should replace the right-hand side with simplify e.

simplify (Multiply (Constant 1) e) = simplify e
simplify (Multiply e (Constant 1)) = simplify e

Now, to simplify the expression from the bottom up: the problem is that you're looking for simplification patterns on the left-hand side of your equations, i.e., before you simplify the children. You need to simplify the children first, and then look for the pattern.

simplify :: Expr -> Expr
simplify (Add x y) =
    case (simplify x, simplify y) of
        (Constant 0, e) -> e
        (e, Constant 0) -> e
        (Constant a, Constant b) -> Constant (a + b)
        (x1, y1) -> Add x1 y1
simplify (Multiply x y) =
    case (simplify x, simplify y) of
        (Constant 0, _) -> Constant 0
        (_, Constant 0) -> Constant 0
        (Constant 1, e) -> e
        (e, Constant 1) -> e
        (Constant a, Constant b) -> Constant (a * b)
        (x1, y1) -> Multiply x1 y1
simplify e = e
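As a quick check that one call is now enough, the sketch below assumes the same `Expr` type with an added `Eq` instance. Its `simplify` is the bottom-up version above, copied unchanged. Because every child is already simplified by the time the node's patterns are matched, applying it a second time should change nothing.

```haskell
-- Same AST; Eq is added so results can be compared.
data Expr = Constant Int
          | Variable String
          | Add Expr Expr
          | Multiply Expr Expr
          deriving (Show, Eq)

-- Bottom-up simplifier from the answer: simplify children, then match on the results.
simplify :: Expr -> Expr
simplify (Add x y) =
    case (simplify x, simplify y) of
        (Constant 0, e) -> e
        (e, Constant 0) -> e
        (Constant a, Constant b) -> Constant (a + b)
        (x1, y1) -> Add x1 y1
simplify (Multiply x y) =
    case (simplify x, simplify y) of
        (Constant 0, _) -> Constant 0
        (_, Constant 0) -> Constant 0
        (Constant 1, e) -> e
        (e, Constant 1) -> e
        (Constant a, Constant b) -> Constant (a * b)
        (x1, y1) -> Multiply x1 y1
simplify e = e

exampleExpr :: Expr
exampleExpr = Add
                (Multiply (Constant 1) (Variable "redrum"))
                (Multiply (Constant 0) (Constant 451))
```

With this version, `simplify exampleExpr` returns `Variable "redrum"` directly. Nested constants fold in the same pass: `simplify (Add (Constant 2) (Multiply (Constant 3) (Constant 4)))` gives `Constant 14`.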

On the left-hand side of the equation, we find the children of the current node. On the right, we look for patterns in the simplified children. One way of improving this code is to separate the two responsibilities of finding-and-replacing children and of matching simplification patterns. Here's a general function to recursively replace every subtree of an Expr:

transform :: (Expr -> Expr) -> Expr -> Expr
transform f (Add x y) = f $ Add (transform f x) (transform f y)
transform f (Multiply x y) = f $ Multiply (transform f x) (transform f y)
transform f e = f e

transform takes a (non-recursive) transformation function which calculates a replacement for a single-node pattern, and recursively applies it to every node in the tree in a bottom-up manner. To write a transformation function, you just look for the interesting patterns and forget about recursively rewriting the children.

simplify = transform f
    where
        f (Add (Constant 0) e) = e
        f (Add e (Constant 0)) = e
        f (Add (Constant a) (Constant b)) = Constant (a + b)
        f (Multiply (Constant 0) _) = Constant 0
        f (Multiply _ (Constant 0)) = Constant 0
        f (Multiply (Constant 1) e) = e
        f (Multiply e (Constant 1)) = e
        f (Multiply (Constant a) (Constant b)) = Constant (a * b)
        f e = e

Since f's argument has already had its children rewritten by transform, we don't need to exhaustively match every possible pattern or explicitly recurse through the value. We look for the ones we care about, and nodes which don't need transforming fall through to the catch-all f e = e case.
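Because `transform` is independent of the rewrite rules, the same traversal can drive other single-node rewrites. The sketch below pairs the answer's `transform` and `simplify` with a hypothetical `substitute` helper (not part of the answer) that replaces a named variable with a constant. As before, deriving `Eq` is an addition for the checks.

```haskell
data Expr = Constant Int
          | Variable String
          | Add Expr Expr
          | Multiply Expr Expr
          deriving (Show, Eq)

-- Generic bottom-up traversal from the answer: rewrite the children, then the node.
transform :: (Expr -> Expr) -> Expr -> Expr
transform f (Add x y)      = f $ Add (transform f x) (transform f y)
transform f (Multiply x y) = f $ Multiply (transform f x) (transform f y)
transform f e              = f e

-- The answer's rules; each equation looks at a single node only.
simplify :: Expr -> Expr
simplify = transform f
  where
    f (Add (Constant 0) e)                 = e
    f (Add e (Constant 0))                 = e
    f (Add (Constant a) (Constant b))      = Constant (a + b)
    f (Multiply (Constant 0) _)            = Constant 0
    f (Multiply _ (Constant 0))            = Constant 0
    f (Multiply (Constant 1) e)            = e
    f (Multiply e (Constant 1))            = e
    f (Multiply (Constant a) (Constant b)) = Constant (a * b)
    f e                                    = e

-- Hypothetical helper: replace every occurrence of variable `name` with `Constant value`.
substitute :: String -> Int -> Expr -> Expr
substitute name value = transform f
  where
    f (Variable v) | v == name = Constant value
    f e                        = e
```

For example, `simplify (substitute "x" 2 (Add (Variable "x") (Multiply (Constant 3) (Constant 4))))` evaluates to `Constant 14`. Substitution and simplification each need only one bottom-up pass, and neither function contains any explicit recursion.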

Generic programming libraries like lens's Plated module take programming patterns like transform and make them universal. You (or the compiler) write a small amount of code characterising the shape of your datatype, and the library implements recursive higher-order functions like transform once and for all.
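Here is a sketch of that idea with lens. It assumes the `lens` package is available, and the `Plated` instance is written by hand. The only `Expr`-specific code is `plate`, which says how to reach a node's immediate children. `transform` from `Control.Lens.Plated` then does the bottom-up rewriting, and `universe` lists every subterm. The `variables` helper is hypothetical, for illustration only.

```haskell
import Control.Lens.Plated (Plated (..), transform, universe)

data Expr = Constant Int
          | Variable String
          | Add Expr Expr
          | Multiply Expr Expr
          deriving (Show, Eq)

-- Describe the immediate children of each constructor; leaves have none.
instance Plated Expr where
  plate f (Add x y)      = Add <$> f x <*> f y
  plate f (Multiply x y) = Multiply <$> f x <*> f y
  plate _ e              = pure e

-- The same single-node rules as the answer's f, driven by lens's transform.
simplify :: Expr -> Expr
simplify = transform step
  where
    step (Add (Constant 0) e)                 = e
    step (Add e (Constant 0))                 = e
    step (Add (Constant a) (Constant b))      = Constant (a + b)
    step (Multiply (Constant 0) _)            = Constant 0
    step (Multiply _ (Constant 0))            = Constant 0
    step (Multiply (Constant 1) e)            = e
    step (Multiply e (Constant 1))            = e
    step (Multiply (Constant a) (Constant b)) = Constant (a * b)
    step e                                    = e

-- Hypothetical helper: names of all variables occurring anywhere in an expression.
variables :: Expr -> [String]
variables e = [v | Variable v <- universe e]
```

With this instance, `simplify` reduces the question's `exampleExpr` to `Variable "redrum"` in one call, and `variables exampleExpr` is `["redrum"]`.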

answered Mar 05 '26 07:03 by Benjamin Hodgson


Simplifying expression ASTs is a typical application of the recursion scheme called a catamorphism. Here is an example using Edward Kmett's recursion-schemes library:

{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DeriveFunctor #-} 
{-# LANGUAGE DeriveFoldable #-} 
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE TemplateHaskell #-}

module CataExprSimplify where

import Data.Functor.Foldable
import Data.Functor.Foldable.TH

data Expr = Constant Int
          | Variable String
          | Add Expr Expr
          | Multiply Expr Expr
          deriving (Show)

-- | Generate the base functor
makeBaseFunctor ''Expr 

simplify :: Expr -> Expr 
simplify = cata $ algSimplAdd . project . algSimplMult

-- | Simplify Addition
simplZero :: Expr -> Expr 
simplZero = cata algSimplAdd

algSimplAdd :: ExprF Expr -> Expr 
algSimplAdd (AddF (Constant 0) r) = r 
algSimplAdd (AddF l (Constant 0)) = l 
algSimplAdd (AddF (Constant l) (Constant r)) = Constant (l + r)
algSimplAdd x  = embed x

-- | Simplify Multiplication
simplMult :: Expr -> Expr 
simplMult = cata algSimplMult

algSimplMult :: ExprF Expr -> Expr 
algSimplMult (MultiplyF (Constant 1) r) = r
algSimplMult (MultiplyF l (Constant 1)) = l 
algSimplMult (MultiplyF (Constant 0) _) = Constant 0
algSimplMult (MultiplyF _ (Constant 0)) = Constant 0 
algSimplMult (MultiplyF (Constant l) (Constant r)) = Constant (l * r)
algSimplMult x = embed x

It has the following advantages over code that makes direct recursive calls:

  • Recursion is abstracted out into the cata function instead of being intertwined with your simplification logic.
  • You will not forget to call simplify on subexpressions.
  • Catamorphisms work from bottom to top.
  • Simplification of addition and multiplication can be written in separate functions.
  • It's much easier to maintain your code if you have to extend your AST (e.g. by adding new constructors).
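To show what `cata` does without the library, here is a dependency-free sketch. `ExprF`, `project`, `embed` and `cata` are written by hand for this one type. In recursion-schemes, they come from `makeBaseFunctor` and the `Recursive`/`Corecursive` classes. The two algebras are the answer's, unchanged. `size` is a hypothetical extra algebra that shows new folds can reuse the same recursion.

```haskell
{-# LANGUAGE DeriveFunctor #-}

data Expr = Constant Int
          | Variable String
          | Add Expr Expr
          | Multiply Expr Expr
          deriving (Show, Eq)

-- Base functor: Expr with each recursive position replaced by r.
data ExprF r = ConstantF Int
             | VariableF String
             | AddF r r
             | MultiplyF r r
             deriving (Functor)

-- Unwrap one layer of an Expr.
project :: Expr -> ExprF Expr
project (Constant n)   = ConstantF n
project (Variable s)   = VariableF s
project (Add l r)      = AddF l r
project (Multiply l r) = MultiplyF l r

-- Wrap one layer back into an Expr.
embed :: ExprF Expr -> Expr
embed (ConstantF n)   = Constant n
embed (VariableF s)   = Variable s
embed (AddF l r)      = Add l r
embed (MultiplyF l r) = Multiply l r

-- Catamorphism: fold the children first (via fmap), then apply the algebra to this node.
cata :: (ExprF a -> a) -> Expr -> a
cata alg = alg . fmap (cata alg) . project

algSimplAdd :: ExprF Expr -> Expr
algSimplAdd (AddF (Constant 0) r) = r
algSimplAdd (AddF l (Constant 0)) = l
algSimplAdd (AddF (Constant l) (Constant r)) = Constant (l + r)
algSimplAdd x = embed x

algSimplMult :: ExprF Expr -> Expr
algSimplMult (MultiplyF (Constant 1) r) = r
algSimplMult (MultiplyF l (Constant 1)) = l
algSimplMult (MultiplyF (Constant 0) _) = Constant 0
algSimplMult (MultiplyF _ (Constant 0)) = Constant 0
algSimplMult (MultiplyF (Constant l) (Constant r)) = Constant (l * r)
algSimplMult x = embed x

-- Multiplication rules first, then re-project and apply the addition rules.
simplify :: Expr -> Expr
simplify = cata (algSimplAdd . project . algSimplMult)

-- Hypothetical extra algebra: count the nodes of an expression.
size :: Expr -> Int
size = cata alg
  where
    alg (AddF l r)      = 1 + l + r
    alg (MultiplyF l r) = 1 + l + r
    alg _               = 1
```

Here `simplify` reduces the question's `exampleExpr` to `Variable "redrum"` in a single call, and `size` of the unsimplified `exampleExpr` is 7. Neither `simplify` nor `size` recurses explicitly. All the recursion lives in `cata`.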

If you want to read more about recursion schemes, read this blog post series.

answered Mar 05 '26 06:03 by Jogger


