We discuss Bayesian linear regression and piecewise linear regression. Our piecewise linear regression uses an infinite Poisson process as the set of change points. The laziness of Haskell effectively truncates the infinite process as needed. The examples also demonstrate that higher-order functions (such as regress and splice) are very useful.
module RegressionDemo where
import LazyPPL
import LazyPPL.Distributions
import Data.Colour
import Data.Colour.Names
import Control.Monad
import Graphics.Matplotlib hiding (density)
Regression is about finding a fitting function to some data. Bayesian regression is about finding a posterior distribution on functions, given the data.
We start with a random linear function:
linear :: Prob (Double -> Double)
linear = do
  a <- normal 0 3
  b <- normal 0 3
  let f = \x -> a * x + b
  return f
plotLinearPrior =
  do fs' <- mh 1 (sample linear)
     let fs = map fst $ take 1000 $ fs'
     plotFuns "images/regression-linear-prior.svg" [] fs 0.1
dataset :: [(Double, Double)]
dataset = [(0,0.6), (1, 0.7), (2,1.2), (3,3.2), (4,6.8), (5, 8.2), (6,8.4)]
plotDataset =
  do plotFuns "images/regression-dataset.svg" dataset [] 0.1
Our regression here is noisy: no straight line passes exactly through this data set, because the points are not collinear.
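The snippet below is a small self-contained check of the non-collinearity, and it is not part of the original demo. It fits an ordinary least-squares line to the same data and computes the sum of squared residuals. Ordinary least squares is a classical technique swapped in purely for illustration; it is not the Bayesian regression developed below. The helper names leastSquares and residualSS are hypothetical.

```haskell
-- Same data as the demo's dataset, restated so this snippet stands alone.
dataset :: [(Double, Double)]
dataset = [(0,0.6), (1, 0.7), (2,1.2), (3,3.2), (4,6.8), (5, 8.2), (6,8.4)]

-- | Closed-form ordinary least-squares line, returned as (slope, intercept).
-- Hypothetical helper; assumes the x values are not all equal.
leastSquares :: [(Double, Double)] -> (Double, Double)
leastSquares pts = (slope, intercept)
  where
    n         = fromIntegral (length pts)
    meanX     = sum (map fst pts) / n
    meanY     = sum (map snd pts) / n
    sxy       = sum [ (x - meanX) * (y - meanY) | (x, y) <- pts ]
    sxx       = sum [ (x - meanX) ^ (2 :: Int) | (x, _) <- pts ]
    slope     = sxy / sxx
    intercept = meanY - slope * meanX

-- | Sum of squared residuals around the least-squares line.
-- It is zero exactly when the points lie on one non-vertical line.
residualSS :: [(Double, Double)] -> Double
residualSS pts = sum [ (y - (a * x + b)) ^ (2 :: Int) | (x, y) <- pts ]
  where (a, b) = leastSquares pts
```

For this data set the least-squares slope is 44/28 = 11/7 ≈ 1.57, and the residual sum of squares is clearly positive. So no single line fits all seven points, which is why the observations have to be modelled as noisy.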
Our generic regression function takes three arguments: a random function prior, some input/output observations dataset, and a noise level sigma. The observations are assumed to be noisy according to sigma. It returns a conditioned (unnormalized) random function.
regress :: Double -> Prob (a -> Double) -> [(a, Double)] -> Meas (a -> Double)
regress sigma prior dataset =
  do f <- sample prior
     -- weight f by the Gaussian likelihood of each observed output
     forM_ dataset (\(x, y) -> score $ normalPdf (f x) sigma y)
     return f
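Each call to score multiplies the running weight. The unnormalized weight that regress gives a fixed function f is therefore a product of Gaussian densities, one per observation. The plain-Haskell sketch below computes the logarithm of that weight. It defines its own normalPdf and assumes the same argument order (mean, standard deviation, point) as the normalPdf used in regress. The name logWeight is hypothetical.

```haskell
-- | Gaussian density with mean mu and standard deviation sigma, evaluated at y.
-- Assumed to mirror the argument order of the normalPdf used in regress.
normalPdf :: Double -> Double -> Double -> Double
normalPdf mu sigma y = exp (-(z * z) / 2) / (sigma * sqrt (2 * pi))
  where z = (y - mu) / sigma

-- | Log of the weight regress would accumulate for a fixed function f:
-- the sum over observations of the log-density of y around f x.
logWeight :: Double -> (a -> Double) -> [(a, Double)] -> Double
logWeight sigma f dataset =
  sum [ log (normalPdf (f x) sigma y) | (x, y) <- dataset ]
```

A function that passes through every observation receives the largest possible weight for a given sigma. Functions that miss the data are penalized quadratically in the miss distance, in log space.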
plotLinReg =
  do fs' <- mh 0.5 (regress 0.5 linear dataset)
     let fs = map fst $ take 1000 $ every 50 $ drop 100 fs'
     plotFuns "images/regression-linear-reg.svg" dataset fs 0.01
For piecewise regression we define splice. It splices together different draws from a random function at a random selection of change points. NB if the point process is infinite then the resulting function has an infinite number of pieces, but this is all done lazily, so it’s not a problem.
splice :: Prob [Double] -> Prob (Double -> Double) -> Prob (Double -> Double)
splice pointProcess randomFun =
  do xs <- pointProcess                  -- change points (possibly infinitely many)
     fs <- mapM (const randomFun) xs     -- one random piece per change point
     default_f <- randomFun              -- used if we run past the last change point
     let h :: [(Double, Double -> Double)] -> Double -> Double
         h [] x = default_f x
         h ((a, f) : xfs) x | x <= a = f x
         h ((a, f) : xfs) x | x > a = h xfs x
     return (h (zip xs fs))
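The laziness of splice comes from the helper h. It walks the list of (change point, piece) pairs only until it reaches the first change point that is at least x. The stand-alone sketch below restates that lookup under the hypothetical name spliceLookup and applies it to an infinite list of change points. Evaluation terminates because only a finite prefix of the list is ever forced. The list infinitePieces is a made-up example.

```haskell
-- | The lookup performed by h inside splice: use the first piece whose
-- change point is >= x, or the default piece if the list runs out.
spliceLookup :: (Double -> Double) -> [(Double, Double -> Double)] -> Double -> Double
spliceLookup defaultF [] x = defaultF x
spliceLookup defaultF ((a, f) : rest) x
  | x <= a    = f x
  | otherwise = spliceLookup defaultF rest x

-- | Hypothetical example: change points 1, 2, 3, ... (infinitely many),
-- with the constant piece 10 * k on the interval ending at k.
infinitePieces :: [(Double, Double -> Double)]
infinitePieces = zip [1 ..] [ const (10 * k) | k <- [1 ..] ]
```

For example, spliceLookup (const 0) infinitePieces 2.5 inspects only the change points 1, 2 and 3 and returns 30.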
For the random function we use linear. For the point process we use the following Poisson point process, poissonPP. It generates an infinite random list of points whose gaps are exponentially distributed.
poissonPP :: Double -> Double -> Prob [Double]
poissonPP lower rate =
  do step <- exponential rate   -- exponentially distributed gap
     let x = lower + step
     xs <- poissonPP x rate     -- the rest of the infinite list, built lazily
     return (x : xs)
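Because the gaps are exponentially distributed, poissonPP can be mimicked outside the Prob monad. The idea is to feed in a stream of uniform numbers and apply the inverse CDF, -log u / rate, to get each gap. This sketch assumes that exponential is parameterized by its rate, so the mean gap is 1/rate. The caller supplies the uniforms here; in the demo, LazyPPL draws them. The name poissonFromUniforms is hypothetical.

```haskell
-- | Pure sketch of a Poisson point process on (lower, infinity):
-- each uniform u in (0,1) becomes an exponential gap -log u / rate,
-- and the gaps are accumulated starting from lower. Works lazily on
-- an infinite stream of uniforms.
poissonFromUniforms :: Double -> Double -> [Double] -> [Double]
poissonFromUniforms lower rate us = drop 1 (scanl (+) lower gaps)
  where gaps = [ -log u / rate | u <- us ]
```

With every uniform equal to 0.5 and rate 0.1, each gap is 10 log 2 ≈ 6.93. So takeWhile (< 20) keeps exactly two points from the infinite list. This is the same kind of truncation that takeWhile (20>) performs in plotPoissonPP.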
plotPoissonPP =
  do pws <- mh 1 $ sample $ poissonPP 0 0.1
     let ps = map (takeWhile (20>)) $ map fst $ take 5 $ pws
     let filename = "images/regression-poissonpp.svg"
     putStrLn $ "Plotting " ++ filename ++ "..."
     let myscatter mpl i = mpl % setSubplot i % scatter (ps !! i) (map (const (0::Double)) (ps !! i)) @@ [o2 "s" (10::Int),o2 "c" "black"] % xlim (0::Int) (20::Int) % ylim (-1 :: Int) (1::Int) % mp # "ax.yaxis.set_major_formatter(mticker.NullFormatter())"
     let myscatteraxes mpl i = if i < (length ps - 1) then myscatter mpl i % mp # "ax.xaxis.set_major_formatter(mticker.NullFormatter())" else myscatter mpl i
     file filename $ foldl myscatteraxes (subplots @@ [o2 "nrows" (length ps),o2 "ncols" (1::Int)]) [0..(length ps - 1)]
     putStrLn $ "Done."
We can now obtain a random piecewise linear function by calling splice (poissonPP 0 0.1) linear. Here are ten draws from this distribution. The viewport is bounded, so laziness takes care of truncating the point process that we passed to splice.
plotPiecewisePrior =
  do fs' <- mh 1 $ sample $ splice (poissonPP 0 0.1) linear
     let fs = map fst $ take 10 $ fs'
     plotFuns "images/regression-piecewise-prior.svg" [] fs 1
Piecewise linear regression conditions this prior on the data set using regress 0.1 (splice (poissonPP 0 0.1) linear). We can then sample from the resulting unnormalized distribution using Metropolis-Hastings.
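A Metropolis-Hastings sampler compares the weight of a proposed state with the weight of the current one. The sketch below shows the generic accept/reject rule in log space for one special case: proposals drawn from the prior. In that case the prior and proposal terms cancel, and only the likelihood ratio remains. This illustrates the textbook rule under that assumption and does not reproduce the internals of LazyPPL's mh or mhirreducible. The names acceptProposal and mhStep are hypothetical.

```haskell
-- | Accept with probability min 1 (wProposed / wCurrent), given both
-- weights in log space and a uniform draw u in [0, 1).
-- Assumption: the proposal is drawn from the prior, so the prior and
-- proposal densities cancel and only the likelihood weights matter.
acceptProposal :: Double -> Double -> Double -> Bool
acceptProposal logWCurrent logWProposed u = u < exp (logWProposed - logWCurrent)

-- | One hypothetical Metropolis-Hastings step: move to the proposal if
-- it is accepted, otherwise stay at the current state.
mhStep :: (s -> Double) -> s -> s -> Double -> s
mhStep logW current proposed u
  | acceptProposal (logW current) (logW proposed) u = proposed
  | otherwise                                       = current
```

A proposal with a higher weight is always accepted. A lower-weight proposal is accepted only with probability equal to the weight ratio, which is what lets the chain explore the posterior instead of only climbing uphill.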
plotPiecewiseReg =
  do fs' <- mhirreducible 0.2 0.1 (regress 0.1 (splice (poissonPP 0 0.1) linear) dataset)
     let fs = map fst $ take 1000 $ every 1000 $ drop 10000 fs'
     plotFuns "images/regression-piecewise-reg.svg" dataset fs 0.01
We can also perform piecewise constant regression. For this we use randConst, a random linear function with slope 0.
randConst :: Prob (Double -> Double)
randConst = do
  b <- normal 0 3
  let f = \_ -> b
  return f
We then form splice (poissonPP 0 0.1) randConst and perform inference on it. The result is an unnormalized distribution of piecewise constant functions.
plotPiecewiseConst =
  do fs' <- mhirreducible 0.2 0.1 (regress 0.1 (splice (poissonPP 0 0.1) randConst) dataset)
     let fs = map fst $ take 1000 $ every 1000 $ drop 10000 fs'
     plotFuns "images/regression-piecewise-const.svg" dataset fs 0.01
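-- Helpers for plotting the functions drawn from the (weighted) samples.
-- interestingPoints f lower upper epsilon delta acc adaptively chooses
-- x-coordinates at which to evaluate f: it bisects [lower, upper] wherever
-- f at the midpoint differs from the straight line between the endpoints.
-- epsilon: smallest y axis difference to worry about
-- delta: smallest x axis difference to worry about
interestingPoints :: (Double -> Double) -> Double -> Double -> Double -> Double -> [Double] -> [Double]
interestingPoints f lower upper epsilon delta acc =
  if abs(upper - lower) < delta then acc
  else
    let mid = (upper - lower) / 2 + lower in
    if abs((f(upper) - f(lower)) / 2 + f(lower) - f(mid)) < epsilon
    then acc
    else interestingPoints f lower mid epsilon delta (mid : (interestingPoints f mid upper epsilon delta acc))

-- Sample f at the interesting points of [-0.25, 6.2], slightly wider than the plot window.
-- A uniform grid would also work: [ (x, f x) | x <- [(-0.25),(-0.25+0.1)..6.2]]
sampleFun f =
  let xs = ((-0.25) : (interestingPoints f (-0.25) 6.2 0.3 0.001 [6.2])) in
  map (\x -> (x,f x)) xs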
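The self-contained snippet below restates interestingPoints so that its behaviour can be checked directly. It runs the function on two hypothetical functions over [0, 6] with epsilon = 0.3 and delta = 0.001. For a straight line the midpoint always agrees with linear interpolation, so no interior points are added. For a function with a kink at 2.5, bisection adds points only in intervals that contain the kink.

```haskell
-- Restated from the helper above so that this snippet stands alone.
interestingPoints :: (Double -> Double) -> Double -> Double -> Double -> Double -> [Double] -> [Double]
interestingPoints f lower upper epsilon delta acc =
  if abs (upper - lower) < delta then acc
  else
    let mid = (upper - lower) / 2 + lower in
    -- compare f at the midpoint with linear interpolation between the endpoints
    if abs ((f upper - f lower) / 2 + f lower - f mid) < epsilon
      then acc
      else interestingPoints f lower mid epsilon delta
             (mid : interestingPoints f mid upper epsilon delta acc)

-- | A hypothetical test function with a single kink at x = 2.5.
kinked :: Double -> Double
kinked x = abs (x - 2.5)
```

Here interestingPoints (\x -> 2 * x + 1) 0 6 0.3 0.001 [6] returns [6.0], while interestingPoints kinked 0 6 0.3 0.001 [6] returns [1.5,2.25,3.0,6.0]. Straight segments are therefore drawn with few points, and detail is added only where a sampled function bends, for example at a change point produced by splice.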
plotFuns :: String -> [(Double,Double)] -> [Double -> Double] -> Double -> IO ()
plotFuns filename dataset funs alpha =
  do putStrLn $ "Plotting " ++ filename ++ "..."
     file filename $ foldl addFun base funs
     putStrLn "Done."
     return ()
  where
    -- the data set as black dots, on fixed axes
    base = scatter (map fst dataset) (map snd dataset) @@ [o2 "c" "black"]
             % xlim (0 :: Int) (6 :: Int) % ylim (-2 :: Int) (10 :: Int)
    -- overlay one sampled function as a thin, translucent green line
    addFun a f = let xfs = sampleFun f
                 in a % plot (map fst xfs) (map snd xfs)
                        @@ [o1 "go-", o2 "linewidth" (0.5 :: Double), o2 "alpha" alpha, o2 "ms" (0 :: Int)]
main :: IO ()
main = do {plotLinearPrior ; plotDataset ; plotLinReg ; plotPiecewisePrior ; plotPoissonPP ; plotPiecewiseReg ; plotPiecewiseConst }