Save PyMC3 traceplot subplots to an image file

I am simply trying to save the subplots generated by the PyMC3 traceplot function (see here) to a file.

The function returns a 2-D numpy.ndarray of subplots (matplotlib Axes objects).
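A minimal sketch of that container, assuming (as stated above) that traceplot returns a 2-D ndarray of matplotlib Axes. It uses plain plt.subplots(3, 2) as a stand-in, so no PyMC3 trace is needed; the 3x2 shape (one row per model variable, two panels per row) is an assumption chosen to match the grid the question builds by hand. The point it shows is that the array only holds references to Axes, and each Axes already belongs to a Figure.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, as in the question
import matplotlib.pyplot as plt
import numpy as np

# Stand-in for traceplot(tr): a 3x2 grid of Axes
# (assumed layout: one row per variable, two panels per row).
fig, trarr = plt.subplots(3, 2)

# The container is an ordinary NumPy object array...
print(isinstance(trarr, np.ndarray), trarr.shape, trarr.dtype)  # True (3, 2) object
# ...whose elements are references to matplotlib Axes objects.
print(isinstance(trarr[0, 0], plt.Axes))  # True
# Each Axes already belongs to a Figure, reachable via its .figure attribute.
print(all(ax.figure is fig for ax in trarr.flat))  # True
```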

I need to move or copy these subplots into a matplotlib Figure so that I can save the image file. Everything I can find shows how to create the figure's subplots first and then populate them.

As a minimal example, I took the sample PyMC3 code from here and added just a few lines to it in an attempt to handle the subplots.

from pymc3 import *
import theano.tensor as tt
from theano import as_op
from numpy import arange, array, empty

### Added these three lines relative to source #######################
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

__all__ = ['disasters_data', 'switchpoint', 'early_mean', 'late_mean', 'rate', 'disasters']

# Time series of recorded coal mining disasters in the UK from 1851 to 1962
disasters_data = array([4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
                        3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
                        2, 2, 3, 4, 2, 1, 3, 2, 2, 1, 1, 1, 1, 3, 0, 0,
                        1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
                        0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
                        3, 3, 1, 1, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
                        0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1])
years = len(disasters_data)

@as_op(itypes=[tt.lscalar, tt.dscalar, tt.dscalar], otypes=[tt.dvector])
def rateFunc(switchpoint, early_mean, late_mean):
    out = empty(years)
    out[:switchpoint] = early_mean
    out[switchpoint:] = late_mean
    return out


with Model() as model:

    # Prior for distribution of switchpoint location
    switchpoint = DiscreteUniform('switchpoint', lower=0, upper=years)
    # Priors for pre- and post-switch mean number of disasters
    early_mean = Exponential('early_mean', lam=1.)
    late_mean = Exponential('late_mean', lam=1.)

    # Allocate appropriate Poisson rates to years before and after current switchpoint location
    rate = rateFunc(switchpoint, early_mean, late_mean)

    # Data likelihood
    disasters = Poisson('disasters', rate, observed=disasters_data)

    # Initial values for stochastic nodes
    start = {'early_mean': 2., 'late_mean': 3.}

    # Use slice sampler for means
    step1 = Slice([early_mean, late_mean])
    # Use Metropolis for switchpoint, since it accommodates discrete variables
    step2 = Metropolis([switchpoint])

    # njobs>1 works only with the most recent (mid August 2014) Theano version:
    # https://github.com/Theano/Theano/pull/2021
    tr = sample(1000, tune=500, start=start, step=[step1, step2], njobs=1)

    ### gnashing of teeth starts here ################################
    fig, axarr = plt.subplots(3,2)

    # This gives a KeyError
    # axarr = traceplot(tr, axarr)

    # This finishes without error
    trarr = traceplot(tr)

    # doesn't work
    # axarr[0, 0] = trarr[0, 0]

    fig.savefig("disaster.png")

I've tried a few variations along the lines of subplot() and add_subplot(), to no avail: every error suggests that a figure's subplots must first be created empty on that figure, and that pre-existing subplots cannot simply be assigned to it.
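A small sketch of why the commented-out line axarr[0, 0] = trarr[0, 0] has no effect on the saved image. It assumes traceplot draws into its own, separate figure (consistent with the answer below retrieving that figure via plt.gcf()); the second plt.subplots call and the name other_fig are hypothetical stand-ins for that. Assigning into the NumPy array only rebinds one slot of the array; the figure's own list of Axes is untouched.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# The empty grid created before calling traceplot (as in the question).
fig, axarr = plt.subplots(3, 2)

# Hypothetical stand-in for traceplot(tr): it draws into a different figure.
other_fig, trarr = plt.subplots(3, 2)
trarr[0, 0].plot([0, 1, 2], [1, 0, 1])

# Element assignment only rebinds a slot in the NumPy array...
axarr[0, 0] = trarr[0, 0]

# ...so fig still owns its original six (empty) Axes, and the
# assigned Axes still belongs to other_fig.
print(len(fig.axes))                    # 6
print(axarr[0, 0].figure is other_fig)  # True
print(trarr[0, 0] in fig.axes)          # False
```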

A different example (see here, about 80% of the way down, beginning with the line "### Mysterious code to be explained in Chapter 3.") avoids the traceplot utility altogether and builds the subplots manually, so maybe there is no good answer to this. Is the pymc3.traceplot output really an orphaned ndarray of subplots that can't be used?

asked Nov 15 '25 22:11 by GoneAsync


1 Answer

I ran into the same problem. I am working with pymc3 3.5 and matplotlib 2.1.2.

I found that the traceplot can be exported like this:

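import matplotlib.pyplot as plt
from pymc3 import traceplot

trarr = traceplot(tr)  # tr: the trace returned by sample() in the question

fig = plt.gcf()  # get the current figure, i.e. the one traceplot drew into...
fig.savefig("disaster.png")  # ...and save it directly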
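An alternative to plt.gcf(), sketched under the assumption that traceplot returns a 2-D ndarray of Axes: every matplotlib Axes keeps a reference to its parent Figure in its .figure attribute (Axes.get_figure() is the method form), so the figure can be saved without relying on pyplot's notion of the "current" figure. The fake_traceplot helper below is hypothetical; it only mimics the shape of traceplot's output so the snippet runs without a PyMC3 trace.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt
import numpy as np


def fake_traceplot(n_vars=3):
    """Hypothetical stand-in for pymc3.traceplot: returns a 2-D ndarray of Axes."""
    _, axarr = plt.subplots(n_vars, 2, figsize=(8, 2 * n_vars))
    rng = np.random.default_rng(0)
    for row in axarr:
        draws = rng.normal(size=500)
        row[0].hist(draws, bins=30)  # left panel: distribution of the draws
        row[1].plot(draws)           # right panel: the draws in order
    return axarr


trarr = fake_traceplot()

# Every Axes knows its parent Figure, so there is no need to rely on
# plt.gcf() still pointing at the traceplot figure.
fig = trarr[0, 0].figure
out_path = os.path.join(tempfile.gettempdir(), "disaster.png")
fig.savefig(out_path)
plt.close(fig)  # free the figure once it has been written
```

With a real trace this would reduce to traceplot(tr)[0, 0].figure.savefig("disaster.png") (not tested here against PyMC3). The difference from plt.gcf() only matters if another figure is created between the traceplot call and the save, because gcf() would then return that newer figure.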

answered Nov 18 '25 13:11 by Xiaoyu Lu