Learn You a Haskell presents the Prob newtype:
import Data.Ratio ((%))
newtype Prob a = Prob { getProb :: [(a,Rational)] } deriving Show
Here are Prob's instance definitions:
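instance Functor Prob where
    fmap f (Prob xs) = Prob $ map (\(x,p) -> (f x,p)) xs
-- GHC 7.10+ requires Applicative as a superclass of Monad; LYAH's return is provided as pure
instance Applicative Prob where
    pure x = Prob [(x, 1%1)]
    Prob fs <*> Prob xs = Prob [(f x, p*q) | (f,p) <- fs, (x,q) <- xs]
instance Monad Prob where
    p >>= f = flatten (fmap f p)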
And then the supporting functions:
flatten :: Prob (Prob a) -> Prob a
flatten = Prob . convert . getProb
convert :: [(Prob a, Rational)] -> [(a, Rational)]
convert = concat . (map f)
f :: (Prob a, Rational) -> [(a, Rational)]
f (p, r) = map (mult r) (getProb p)
mult :: Rational -> (a, Rational) -> (a, Rational)
mult r (x, y) = (x, r*y)
I wrote the flatten, convert, f, and mult functions, so I'm comfortable with them.
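To see what flatten does on its own, here is a minimal runnable sketch, assuming a current GHC (which is why it adds an Applicative instance, with (<*>) defined via Control.Monad's ap). The helpers are the ones above; the value nested is a hypothetical example, not from LYAH.

```haskell
import Control.Monad (ap)
import Data.Ratio ((%))

newtype Prob a = Prob { getProb :: [(a, Rational)] } deriving Show

instance Functor Prob where
  fmap fn (Prob xs) = Prob $ map (\(x, p) -> (fn x, p)) xs

-- Current GHC requires Applicative as a superclass of Monad;
-- (<*>) is derived from (>>=) via `ap`.
instance Applicative Prob where
  pure x = Prob [(x, 1 % 1)]
  (<*>)  = ap

instance Monad Prob where
  p >>= fn = flatten (fmap fn p)

-- The helper functions from the question, unchanged
flatten :: Prob (Prob a) -> Prob a
flatten = Prob . convert . getProb

convert :: [(Prob a, Rational)] -> [(a, Rational)]
convert = concat . map f

f :: (Prob a, Rational) -> [(a, Rational)]
f (p, r) = map (mult r) (getProb p)

mult :: Rational -> (a, Rational) -> (a, Rational)
mult r (x, y) = (x, r * y)

-- Hypothetical nested distribution: with probability 1/4 we get a fair
-- choice between 'a' and 'b', with probability 3/4 we get 'c' for certain.
nested :: Prob (Prob Char)
nested = Prob [ (Prob [('a', 1 % 2), ('b', 1 % 2)], 1 % 4)
              , (Prob [('c', 1 % 1)], 3 % 4) ]

main :: IO ()
main = print (flatten nested)
```

Running it prints Prob {getProb = [('a',1 % 8),('b',1 % 8),('c',3 % 4)]}: each inner probability is multiplied by the probability of the distribution it came from.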
Then we apply >>= to the following example, involving a data type, Coin:
data Coin = Heads | Tails deriving (Show, Eq)
coin :: Prob Coin
coin = Prob [(Heads, 1%2), (Tails, 1%2)]
loadedCoin :: Prob Coin
loadedCoin = Prob [(Heads, 1%10), (Tails, 9%10)]
LYAH asks: "If we throw all the coins at once, what are the odds of all of them landing tails?"
flipTwo :: Prob Bool
flipTwo = do
    a <- coin        -- a has type `Coin`
    b <- loadedCoin  -- similarly
    return (all (== Tails) [a,b])
Calling flipTwo returns:
Prob {getProb = [(False,1 % 20),(False,9 % 20),(False,1 % 20),(True,9 % 20)]}
flipTwo can be re-written with >>=:
flipTwoBind' :: Prob Bool
flipTwoBind' = coin >>=
    \x -> loadedCoin >>=
    \y -> return (all (== Tails) [x,y])
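A self-contained sketch that puts both versions side by side, assuming a current GHC. For brevity it condenses flatten into a single list comprehension, which should produce the same list as the flatten/convert/f/mult chain above; that condensed form is my rewrite, not LYAH's.

```haskell
import Control.Monad (ap)
import Data.Ratio ((%))

newtype Prob a = Prob { getProb :: [(a, Rational)] } deriving Show

instance Functor Prob where
  fmap fn (Prob xs) = Prob [(fn x, p) | (x, p) <- xs]

-- Applicative is a required superclass of Monad in current GHC
instance Applicative Prob where
  pure x = Prob [(x, 1 % 1)]
  (<*>)  = ap

instance Monad Prob where
  m >>= k = flatten (fmap k m)

-- same result as flatten/convert/f/mult, written as one comprehension
flatten :: Prob (Prob a) -> Prob a
flatten (Prob xs) = Prob [(x, r * p) | (Prob inner, r) <- xs, (x, p) <- inner]

data Coin = Heads | Tails deriving (Show, Eq)

coin, loadedCoin :: Prob Coin
coin       = Prob [(Heads, 1 % 2), (Tails, 1 % 2)]
loadedCoin = Prob [(Heads, 1 % 10), (Tails, 9 % 10)]

-- do-notation version
flipTwo :: Prob Bool
flipTwo = do
  a <- coin
  b <- loadedCoin
  return (all (== Tails) [a, b])

-- the same computation with explicit (>>=) and lambdas
flipTwoBind' :: Prob Bool
flipTwoBind' =
  coin >>= \x ->
  loadedCoin >>= \y ->
  return (all (== Tails) [x, y])

main :: IO ()
main = do
  print flipTwo
  print flipTwoBind'
  print (getProb flipTwo == getProb flipTwoBind')
```

Both print Prob {getProb = [(False,1 % 20),(False,9 % 20),(False,1 % 20),(True,9 % 20)]}, and the final line prints True, since do-notation desugars to exactly this chain of (>>=) and lambdas.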
I don't understand the type of return (all (== Tails) [x,y]). Since it's on the right-hand side of >>=, its type must be a -> m b (where Monad m).
My understanding is that (all (==Tails) [x,y]) returns True or False, but how does return lead to the above result:
Prob {getProb = [(False,1 % 20),(False,9 % 20),(False,1 % 20),(True,9 % 20)]}?
Note that the RHS of the >>= operator is a lambda expression, not the application of return:
\y -> return (all (== Tails) [x,y])
This lambda has type (Monad m) => a -> m b as expected.
Let's build up the type from the bottom:
As you say, all (== Tails) [x,y] returns True or False. In other words, its type is Bool.
Now, checking the type of return in GHCi, we see that it is:
Prelude> :t return
return :: Monad m => a -> m a
So return (all (== Tails) [x,y]) has type Monad m => m Bool.
Wrapping this in a lambda then gives the type (Monad m) => a -> m Bool.
(Note that somewhere along the way, the compiler will deduce that the concrete monad type is Prob.)
You should think of return as taking a regular value and wrapping it into a Monad.
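A small sketch of return at the Prob type, assuming the same Prob instances as in the question (with pure standing in for LYAH's return, as current GHC requires). The name rhs is hypothetical: it is the innermost lambda with x fixed to Tails, given the type signature Coin -> Prob Bool that GHC infers for it in context.

```haskell
import Control.Monad (ap)
import Data.Ratio ((%))

newtype Prob a = Prob { getProb :: [(a, Rational)] } deriving Show

instance Functor Prob where
  fmap fn (Prob xs) = Prob [(fn x, p) | (x, p) <- xs]

instance Applicative Prob where
  pure x = Prob [(x, 1 % 1)]   -- return defaults to this
  (<*>)  = ap

instance Monad Prob where
  m >>= k = flatten (fmap k m)

flatten :: Prob (Prob a) -> Prob a
flatten (Prob xs) = Prob [(x, r * p) | (Prob inner, r) <- xs, (x, p) <- inner]

data Coin = Heads | Tails deriving (Show, Eq)

-- return wraps a plain Bool into a one-outcome distribution
certainTrue :: Prob Bool
certainTrue = return True

-- hypothetical name: the inner lambda with x = Tails; here a ~ Coin, m ~ Prob, b ~ Bool
rhs :: Coin -> Prob Bool
rhs = \y -> return (all (== Tails) [Tails, y])

main :: IO ()
main = do
  print certainTrue
  print (rhs Heads)
  print (rhs Tails)
```

This prints Prob {getProb = [(True,1 % 1)]}, then the same shape with False and True: return alone only ever produces a single outcome with probability 1. The four weighted outcomes in the final result come from the two (>>=) applications, which multiply these certain results by the coins' probabilities.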
Addendum:
Let's analyze the type of
flipTwoBind' = coin >>=
    \x -> loadedCoin >>=
    \y -> return (all (== Tails) [x,y])
We start by noting that the outermost expression here is an application of (>>=) which has type:
Prelude> :t (>>=)
(>>=) :: Monad m => m a -> (a -> m b) -> m b
The LHS is coin which has type Prob Coin, so we immediately deduce that m is Prob and a is Coin. This means that the RHS must have type Coin -> Prob b for some type b. So let's look at the RHS now:
\x -> loadedCoin >>= \y -> return (all (== Tails) [x,y])
Here we have a lambda that returns the result of an application of (>>=), so the lambda has type
(Monad m) => a -> m b
This matches the expected type for the application of the first (>>=), so a here is Coin and m is Prob.
Now analyzing the inner application of (>>=), we see that its type is deduced to be
(>>=) :: Prob Coin -> (Coin -> Prob b) -> Prob b
We already analyzed the RHS of the second (>>=), and so b is deduced to be Bool.
(Note, this may not be the exact order that the compiler uses to deduce the types. It just happens to be the order which my thoughts followed as I analyzed the types for this answer.)
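One way to check this deduction is to ask GHC to accept (>>=) at the specialised type. The sketch below assumes the question's Prob instances (condensed, with an Applicative instance for current GHC); bindCoin and flipTwoTyped are hypothetical names. If the deduced type were wrong, the signature of bindCoin would be rejected. Note that both applications of (>>=) in flipTwoBind' end up at this same type.

```haskell
import Control.Monad (ap)
import Data.Ratio ((%))

newtype Prob a = Prob { getProb :: [(a, Rational)] } deriving Show

instance Functor Prob where
  fmap fn (Prob xs) = Prob [(fn x, p) | (x, p) <- xs]

instance Applicative Prob where
  pure x = Prob [(x, 1 % 1)]
  (<*>)  = ap

instance Monad Prob where
  m >>= k = flatten (fmap k m)

flatten :: Prob (Prob a) -> Prob a
flatten (Prob xs) = Prob [(x, r * p) | (Prob inner, r) <- xs, (x, p) <- inner]

data Coin = Heads | Tails deriving (Show, Eq)

coin, loadedCoin :: Prob Coin
coin       = Prob [(Heads, 1 % 2), (Tails, 1 % 2)]
loadedCoin = Prob [(Heads, 1 % 10), (Tails, 9 % 10)]

-- (>>=) instantiated with m ~ Prob, a ~ Coin, b ~ Bool
bindCoin :: Prob Coin -> (Coin -> Prob Bool) -> Prob Bool
bindCoin = (>>=)

-- flipTwoBind' with both binds forced to that monomorphic type
flipTwoTyped :: Prob Bool
flipTwoTyped =
  coin `bindCoin` \x ->
  loadedCoin `bindCoin` \y ->
  return (all (== Tails) [x, y])

main :: IO ()
main = print flipTwoTyped
```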
(I'll call your coin fairCoin.) You have:
flipTwoBind' :: Prob Bool
flipTwoBind' = fairCoin >>= g where
    g x = loadedCoin >>= h where
        h y = return z where
            z = all (== Tails) [x,y]
From the type of (>>=) we get:
fairCoin :: Prob Coin
(>>=) :: Monad m => m a -> (a -> m b) -> m b | m ~ Prob, a ~ Coin
fairCoin >>= g :: m b | g :: Coin -> Prob b
flipTwoBind' :: Prob Bool | m ~ Prob, b ~ Bool
so that g :: Coin -> Prob Bool and g x :: Prob Bool provided that x :: Coin.
Since g x = loadedCoin >>= h, we have
loadedCoin :: Prob Coin
(>>=) :: Monad m => m a -> (a -> m b) -> m b
loadedCoin >>= h :: Prob Bool
So, h :: Coin -> Prob Bool, z :: Bool and return z :: Prob Bool:
all :: (a -> Bool) -> [a] -> Bool
all (== Tails) [x,y] :: Bool
return :: (Monad m) => a -> m a
z :: Bool
return z :: m Bool | m ~ Prob so return z :: Prob Bool
Since Prob a is essentially a tagged association list pairing outcomes of type a with their probabilities, Prob Bool is a list pairing Bool outcomes with their probabilities.
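The types derived above can be written directly onto the where-bound helpers, and GHC should accept them. This sketch assumes the question's Prob instances (condensed, with an Applicative instance for current GHC); only the signatures on g, h and z are added to the code above.

```haskell
import Control.Monad (ap)
import Data.Ratio ((%))

newtype Prob a = Prob { getProb :: [(a, Rational)] } deriving Show

instance Functor Prob where
  fmap fn (Prob xs) = Prob [(fn x, p) | (x, p) <- xs]

instance Applicative Prob where
  pure x = Prob [(x, 1 % 1)]
  (<*>)  = ap

instance Monad Prob where
  m >>= k = flatten (fmap k m)

flatten :: Prob (Prob a) -> Prob a
flatten (Prob xs) = Prob [(x, r * p) | (Prob inner, r) <- xs, (x, p) <- inner]

data Coin = Heads | Tails deriving (Show, Eq)

fairCoin, loadedCoin :: Prob Coin
fairCoin   = Prob [(Heads, 1 % 2), (Tails, 1 % 2)]
loadedCoin = Prob [(Heads, 1 % 10), (Tails, 9 % 10)]

flipTwoBind' :: Prob Bool
flipTwoBind' = fairCoin >>= g
  where
    g :: Coin -> Prob Bool          -- m ~ Prob, a ~ Coin, b ~ Bool
    g x = loadedCoin >>= h
      where
        h :: Coin -> Prob Bool
        h y = return z
          where
            z :: Bool
            z = all (== Tails) [x, y]

main :: IO ()
main = mapM_ print (getProb flipTwoBind')   -- one (outcome, probability) pair per line
```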
Inlining the Prob-specific definitions of (>>=), flatten and the helper functions, flipTwoBind' becomes
flipTwoBind' = fairCoin >>= g
    = flatten (fmap g fairCoin)
    = Prob . convert . getProb $
          Prob $ map (\(x,p) -> (g x,p)) $ getProb fairCoin
    = Prob . concat . map (\(x,p) -> map (\(v,q) -> (v, p*q)) $ getProb x)
          . map (\(x,p) -> (g x,p)) $ getProb fairCoin
(see how nicely the Prob and getProb cancel each other there on the inside...).
Switching to plain list-based code (with gL x = getProb (g x), fairCoinL = getProb fairCoin, etc., and second from Data.Bifunctor), it is equivalent to
    = concat . map (\(x,p) -> map (second (p*)) x)
          . map (\(x,p) -> (gL x,p)) $ fairCoinL
    = concat . map (\(x,p) -> map (second (p*)) $ gL x) $ fairCoinL
    = [(v,p*q) | (x,p) <- fairCoinL, (v,q) <- gL x]
    = ....
    = [(z,r) | (x,p) <- [(Heads, 1%2), (Tails, 1%2)],               -- do a <- fairCoin
               (y,q) <- [(Heads, p*(1%10)), (Tails, p*(9%10))],     --    b <- loadedCoin
               (z,r) <- [(all (== Tails) [x,y], q*(1%1))] ]         --    return ... all ...
    = [(False,1 % 20),(False,9 % 20),(False,1 % 20),(True,9 % 20)]
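The list-level steps of this derivation can be checked directly on plain lists. In this sketch fairCoinL and loadedCoinL stand for getProb fairCoin and getProb loadedCoin, and gL is written out on lists as an assumed equivalent of getProb (g x) (return contributes a factor of 1, so each inner probability is kept). Note the parentheses around 1 % 10 and friends: (%) and (*) are both infixl 7, so p*1%10 would parse as (p*1)%10 and fail to typecheck for a Rational p.

```haskell
import Data.Bifunctor (second)
import Data.Ratio ((%))

data Coin = Heads | Tails deriving (Show, Eq)

-- list-level stand-ins for getProb fairCoin and getProb loadedCoin
fairCoinL, loadedCoinL :: [(Coin, Rational)]
fairCoinL   = [(Heads, 1 % 2), (Tails, 1 % 2)]
loadedCoinL = [(Heads, 1 % 10), (Tails, 9 % 10)]

-- assumed equivalent of getProb (g x): one (outcome, probability) per loaded-coin side
gL :: Coin -> [(Bool, Rational)]
gL x = [(all (== Tails) [x, y], q) | (y, q) <- loadedCoinL]

-- concat . map form, scaling each inner probability by the outer one
viaMap :: [(Bool, Rational)]
viaMap = concat . map (\(x, p) -> map (second (p *)) (gL x)) $ fairCoinL

-- list comprehension form
viaComprehension :: [(Bool, Rational)]
viaComprehension = [(v, p * q) | (x, p) <- fairCoinL, (v, q) <- gL x]

-- fully expanded form, mirroring do a <- fairCoin; b <- loadedCoin; return ...
expanded :: [(Bool, Rational)]
expanded =
  [ (z, r) | (x, p) <- fairCoinL
           , (y, q) <- [(Heads, p * (1 % 10)), (Tails, p * (9 % 10))]
           , (z, r) <- [(all (== Tails) [x, y], q * (1 % 1))] ]

main :: IO ()
main = do
  print viaMap
  print (viaMap == viaComprehension && viaComprehension == expanded)
```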
Of course, the second-to-last line of the derivation above could equally be written as
    = [(all (== Tails) [x,y], q)                              -- ... all ... <$>
        | (x,p) <- [(Heads, 1%2), (Tails, 1%2)],              -- fairCoin <*>
          (y,q) <- [(Heads, p*(1%10)), (Tails, p*(9%10))] ]   -- loadedCoin
because (>>= return . f) === fmap f.
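This equivalence can be checked on the example. The sketch assumes the question's Prob instances (condensed, with (<*>) defined via ap for current GHC); bothTails is a hypothetical helper naming the "both tails?" test. The three definitions replace the innermost (>>= return . f) first with fmap and then with the Applicative combinators.

```haskell
import Control.Monad (ap)
import Data.Ratio ((%))

newtype Prob a = Prob { getProb :: [(a, Rational)] } deriving Show

instance Functor Prob where
  fmap fn (Prob xs) = Prob [(fn x, p) | (x, p) <- xs]

instance Applicative Prob where
  pure x = Prob [(x, 1 % 1)]
  (<*>)  = ap

instance Monad Prob where
  m >>= k = flatten (fmap k m)

flatten :: Prob (Prob a) -> Prob a
flatten (Prob xs) = Prob [(x, r * p) | (Prob inner, r) <- xs, (x, p) <- inner]

data Coin = Heads | Tails deriving (Show, Eq)

fairCoin, loadedCoin :: Prob Coin
fairCoin   = Prob [(Heads, 1 % 2), (Tails, 1 % 2)]
loadedCoin = Prob [(Heads, 1 % 10), (Tails, 9 % 10)]

-- hypothetical helper: are both coins tails?
bothTails :: Coin -> Coin -> Bool
bothTails x y = all (== Tails) [x, y]

-- inner step written as (>>= return . f)
viaBind :: Prob Bool
viaBind = fairCoin >>= \x -> loadedCoin >>= return . bothTails x

-- the same inner step written as fmap f
viaFmap :: Prob Bool
viaFmap = fairCoin >>= \x -> fmap (bothTails x) loadedCoin

-- the whole computation in Applicative style
viaApplicative :: Prob Bool
viaApplicative = bothTails <$> fairCoin <*> loadedCoin

main :: IO ()
main = mapM_ (print . getProb) [viaBind, viaFmap, viaApplicative]
```

All three should print the same list. The only difference in the fmap version is that return's factor of 1 is never multiplied in, which does not change any probability.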