This example, which follows an example from Church, demonstrates how the abstract types of the Chinese Restaurant Process can be used to program an infinite relational model.
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ExtendedDefaultRules #-}
module IrmDemo where
import LazyPPL
import LazyPPL.Distributions
import LazyPPL.Distributions.DirichletP
import LazyPPL.Distributions.Memoization
import Data.List
import Graphics.Matplotlib
We have six people, and we know that some of them talk to each other, and some of them don’t talk to each other. We want to infer the social groups. This is non-parametric in that we don’t assume a fixed number of social groups.
We set up a data type inhabited by the people of interest.
data Person = Tom | Fred | Jim | Mary | Sue | Ann deriving (Show , Eq, Ord, Enum, Bounded)
instance MonadMemo Prob Person
The model is set up by building a Chinese Restaurant and placing the people at tables in it. Here we are using our Chinese Restaurant Process interface (LazyPPL.Distributions.DirichletP). It involves abstract types Restaurant and Table, and provides two functions:
newRestaurant :: Double -> Prob Restaurant, which provides a new restaurant;
newCustomer :: Restaurant -> Prob Table, which says which table a new customer will sit at.
The model describes an unnormalized probability measure on functions assigning tables to people.
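The following plain-Haskell sketch computes the standard Chinese Restaurant Process seating rule, which is what newCustomer samples from. A new customer joins an existing table with probability proportional to the number of people already seated there. Otherwise the customer opens a new table, with probability proportional to the concentration parameter α. The function crpSeating is hypothetical and not part of LazyPPL. We assume only that newCustomer follows this same predictive distribution. The library itself works by lazy stick-breaking rather than by evaluating this formula.

```haskell
-- Hypothetical helper, not part of LazyPPL: the textbook CRP seating rule.
-- Given the concentration alpha and the occupancy of each existing table,
-- return the probability of joining each existing table, followed by the
-- probability of opening a new table. The results sum to 1.
crpSeating :: Double -> [Int] -> [Double]
crpSeating alpha occupancy = map existing occupancy ++ [alpha / total]
  where
    n          = fromIntegral (sum occupancy)   -- customers seated so far
    total      = n + alpha                      -- normalizing constant
    existing k = fromIntegral k / total         -- table with k customers
```

Suppose α = 1 and there are two occupied tables, one seating two people and one seating one person. Then crpSeating 1.0 [2, 1] returns [0.5, 0.25, 0.25], and the last entry is the chance of a new table.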
model :: Meas (Person -> Table)
model = do
We first set up a new restaurant. Behind the scenes, this initiates lazy stick-breaking.
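The stick-breaking construction can be sketched lazily in a few lines. We start from a stick of length 1. Each step breaks off a fraction v of what remains, and that piece becomes the next table's weight. In this sketch the fractions are passed in as a list, which keeps it deterministic. In a Dirichlet process with concentration α they would be independent Beta(1, α) draws, which are uniform when α = 1. The name stickWeights is hypothetical, and this is not LazyPPL's code.

```haskell
-- Hypothetical helper, not LazyPPL's implementation.
-- Lazy stick-breaking: from break fractions v1, v2, ... in (0, 1),
-- produce weights w_i = v_i * product [1 - v_j | j < i].
-- Each weight is emitted before recursing, so an infinite input is fine
-- and only the weights that are demanded are ever computed.
stickWeights :: [Double] -> [Double]
stickWeights = go 1
  where
    -- 'remaining' is the length of stick not yet broken off
    go _ [] = []
    go remaining (v : vs) = remaining * v : go (remaining * (1 - v)) vs
```

Each weight is produced before the recursive call. As a result, take 3 (stickWeights (repeat 0.5)) terminates even though its input is infinite, and it gives [0.5, 0.25, 0.125].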
  r :: Restaurant <- sample $ newRestaurant 1.0
We define two memoized functions. The first, table, assigns a table to each person. Memoization is defined using laziness.
  table :: (Person -> Table) <- sample $ memoize $ \person -> newCustomer r
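Memoization matters here because table is random. Without it we would have a function of type Person -> Prob Table, and every call would draw a new table for the same person. The deterministic core of memoization by laziness can be sketched for any small Bounded, Enum domain such as Person. The results for every argument are kept in one lazily built list, so each result is computed at most once and then shared. memoEnum is a hypothetical helper, not LazyPPL's memoize. Judging from its use here, memoize has a type closer to (a -> Prob b) -> Prob (a -> b).

```haskell
-- Hypothetical helper, not LazyPPL's memoize.
-- Memoization by laziness over a finite enumerable domain: 'results' is
-- built once per call of memoEnum and shared by every later application,
-- so (in GHC) each f x is evaluated at most once, and only on demand.
memoEnum :: (Bounded a, Enum a) => (a -> b) -> a -> b
memoEnum f = \x -> results !! (fromEnum x - fromEnum (minBound `asTypeOf` x))
  where
    results = map f [minBound .. maxBound]
```

The index offset by fromEnum minBound keeps the sketch correct for enumerations that do not start at 0. The list lookup is linear, which is fine for six people.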
The second memoized function, near, assigns to each pair of tables the chance that people on those tables will talk to each other.
  near :: ((Table, Table) -> Double) <- sample $ memoize $ \(tableA, tableB) -> beta 0.5 0.5
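The choice beta 0.5 0.5 gives a U-shaped prior. Beta(0.5, 0.5) is the arcsine distribution, with density 1 / (pi * sqrt (x * (1 - x))) on (0, 1). A priori, then, the talking probability between two tables tends to be close to 0 or close to 1. This plausibly suits social groups whose members either talk or do not. Here is a minimal sketch of the density. The function name is hypothetical.

```haskell
-- Hypothetical helper: density of Beta(0.5, 0.5), the arcsine distribution,
-- p(x) = 1 / (pi * sqrt (x * (1 - x))) for 0 < x < 1.
-- Returns 0 outside the open interval (0, 1).
betaHalfHalfDensity :: Double -> Double
betaHalfHalfDensity x
  | x <= 0 || x >= 1 = 0
  | otherwise        = 1 / (pi * sqrt (x * (1 - x)))
```

The density reaches its minimum at x = 0.5, where betaHalfHalfDensity 0.5 equals 2 / pi (about 0.64). It grows without bound toward 0 and 1.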
We define two helper functions, talks and nottalks, which we then map over the observations about various people talking or not talking to each other.
  let talks :: (Person, Person) -> Meas () = \(personA, personB) ->
        score $ near (table personA, table personB)
  let nottalks :: (Person, Person) -> Meas () = \(personA, personB) ->
        score $ 1 - near (table personA, table personB)
The data set:
  mapM_ talks [(Tom, Fred), (Tom, Jim), (Jim, Fred), (Mary, Sue), (Mary, Ann), (Ann, Sue)]
  mapM_ nottalks [(Mary, Fred), (Mary, Jim), (Sue, Fred), (Sue, Tom), (Ann, Jim), (Ann, Tom)]
Finally we return the assignment of tables to people.
  return table
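Each score call multiplies the weight of a run by its argument. So for a fixed table assignment and fixed values of near, the weight contributed by the data set is a product of Bernoulli likelihoods. Every pair that talks contributes p, and every pair that does not talk contributes 1 - p. The following sketch computes that product directly, with tables represented as Int. likelihood is a hypothetical name, and the sketch leaves out the prior over assignments and over near.

```haskell
-- Hypothetical helper, not part of LazyPPL.
-- Weight contributed by the observations for one fixed configuration:
--   table  : which table each person sits at (tables numbered as Int)
--   near   : talk probability for each ordered pair of tables
--   talk   : pairs observed talking     (each contributes p)
--   noTalk : pairs observed not talking (each contributes 1 - p)
likelihood :: (person -> Int) -> ((Int, Int) -> Double)
           -> [(person, person)] -> [(person, person)] -> Double
likelihood table near talk noTalk =
  product [ pNear a b | (a, b) <- talk ]
    * product [ 1 - pNear a b | (a, b) <- noTalk ]
  where
    pNear a b = near (table a, table b)
```

For example, suppose Tom and Fred share table 0, Mary sits at table 1, and near is 0.75 within a table and 0.25 across tables. Then observing that Tom talks to Fred and that Mary does not talk to Fred gives 0.75 * 0.75 = 0.5625. Twelve factors are few enough to multiply directly. For much larger data sets one would sum logarithms instead, to avoid underflow.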
We sample from this unnormalized measure using a Metropolis-Hastings simulation. We estimate the probabilities of Tom/Fred and Tom/Mary sitting together, and also write a graph of the MAP sample to a Graphviz DOT file.
main :: IO ()
main = do
  tws <- take 10000 <$> mh 0.2 model
  plotHistogram "images/irm-tom-fred.svg" $ map (\(t,_) -> t Tom == t Fred) $ tws
  plotHistogram "images/irm-tom-mary.svg" $ map (\(t,_) -> t Tom == t Mary) $ tws
  writeFile "images/irm-tables.dot" $ show $ maxap tws
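The mh call returns a list of samples paired with weights, and we keep the first 10000. For background, here is a sketch of the acceptance rule of a generic Metropolis-Hastings step. It assumes that proposals are reversible with respect to the prior, as when some random choices are redrawn from their prior distributions. Under that assumption the prior and proposal terms cancel, leaving only the ratio of likelihood weights. This is a textbook sketch with a hypothetical name, not LazyPPL's implementation.

```haskell
-- Hypothetical helper, not LazyPPL's implementation.
-- Probability of accepting a move from the current state to a proposed
-- state, given their likelihood weights. Assumes the proposal is reversible
-- with respect to the prior, so only the ratio of weights matters.
acceptanceProbability :: Double -> Double -> Double
acceptanceProbability currentWeight proposedWeight
  | currentWeight <= 0 = 1  -- convention: always move away from a zero-weight state
  | otherwise          = min 1 (proposedWeight / currentWeight)
```

A proposal with a quarter of the current weight is accepted a quarter of the time (acceptanceProbability 4 1 is 0.25). Any proposal with at least the current weight is always accepted.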
The example at Probmods actually gives different histograms from the ones here. We suspect that this is an issue with the mh-query parameters in that example, because webchurch’s rejection sampling agrees with our histograms.
-- Maximum a posteriori (MAP): the sample with the largest weight
maxap :: Ord w => [(a, w)] -> a
maxap xws =
  let maxw = maximum $ map snd xws in
  let (Just x) = Data.List.lookup maxw $
                 map (\(z, w) -> (w, z)) xws in
  x
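maxap searches the list twice, once for the largest weight and once to look that weight up again, and it fails at runtime on an empty list. The sketch below is a hypothetical total variant. It uses maximumBy from Data.List and comparing from Data.Ord.

```haskell
import Data.List (maximumBy)
import Data.Ord (comparing)

-- Hypothetical variant of maxap, not part of the original example.
-- Return the value with the largest weight, or Nothing if there are no samples.
maxapSafe :: Ord w => [(a, w)] -> Maybe a
maxapSafe []  = Nothing
maxapSafe xws = Just (fst (maximumBy (comparing snd) xws))
```

The Maybe result makes the empty case explicit to callers. With tied weights, maximumBy and the lookup in maxap may pick different samples.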
-- Graphviz DOT graph: an edge joins each pair of people at the same table
tableToDot :: (Person -> Table) -> String
tableToDot f = "graph tables {" ++ concat [ show a ++ " -- " ++ show b ++ "; " | a <- people, b <- people, a < b, f a == f b ] ++ "}"
  where people = [minBound .. maxBound]
instance Show (Person -> Table) where show f = tableToDot f
plotHistogram :: (Show a , Eq a) => String -> [a] -> IO ()
plotHistogram filename xs = do
  putStrLn $ "Generating " ++ filename ++ "..."
  let categories = nub xs
  let counts = map (\c -> length $ filter (==c) xs) categories
  file filename $ bar (map show categories) $ map (\n -> (fromIntegral n)/(fromIntegral $ length xs)) counts
  putStrLn $ "Done."
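The counting step in plotHistogram can be pulled out as a pure function, which makes it easy to check without generating a plot. The name frequencies is hypothetical. Applied to the list of Tom == Fred booleans, its entry for "True" is the estimated probability that Tom and Fred sit at the same table.

```haskell
import Data.List (nub)

-- Hypothetical helper extracted from plotHistogram's counting logic.
-- Pair each distinct value (in order of first appearance, shown as a
-- String) with the fraction of samples equal to it.
frequencies :: (Show a, Eq a) => [a] -> [(String, Double)]
frequencies xs = [ (show c, fromIntegral (count c) / total) | c <- nub xs ]
  where
    count c = length (filter (== c) xs)
    total   = fromIntegral (length xs)
```

This makes one pass over the samples per category, which is cheap for the two-valued histograms here.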