Source code for pymcmcstat.plotting.MCMCPlotting

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 31 12:54:16 2018

@author: prmiles
"""

# import required packages
from __future__ import division
import matplotlib.pyplot as plt
from pylab import hist
from .utilities import generate_names, setup_plot_features, make_x_grid
from .utilities import setup_subsample
import warnings
from deprecated import deprecated


try:
    from statsmodels.nonparametric.kernel_density import KDEMultivariate
except ImportError as e:
    warnings.warn(str("Exception raised importing statsmodels.nonparametric.kernel_density \
                      - plot_density_panel will not work. {}".format(e)))


def deprecation(message, stacklevel=2):
    '''
    Emit a :class:`DeprecationWarning` carrying the given message.

    Args:
        * **message** (:py:class:`str`): Text of the warning.
        * **stacklevel** (:py:class:`int`): Forwarded to :func:`warnings.warn`.
          The default of 2 attributes the warning to the code that called
          :func:`deprecation`, rather than to this helper's own line.
    '''
    warnings.warn(message, DeprecationWarning, stacklevel=stacklevel)
# --------------------------------------------
@deprecated(version='1.9.0', reason='New function: "from pymcmcstat.mcmcplot import plot_density_panel"')
def plot_density_panel(chains, names=None, hist_on=False, figsizeinches=None, return_kde=False):
    '''
    Plot marginal posterior densities

    One subplot is drawn per parameter (column of ``chains``), showing a
    kernel density estimate and, optionally, a normalized histogram.

    Args:
        * **chains** (:class:`~numpy.ndarray`): Sampling chain for each parameter
          (rows are samples, columns are parameters)
        * **names** (:py:class:`list`): List of strings - name of each parameter
        * **hist_on** (:py:class:`bool`): Flag to include histogram on density plot
        * **figsizeinches** (:py:class:`list`): Specify figure size in inches [Width, Height]
        * **return_kde** (:py:class:`bool`): Flag to also return the kernel density estimates

    Returns:
        * The figure handle, or ``(figure, kde_list)`` when ``return_kde`` is
          truthy, where ``kde_list`` holds one ``KDEMultivariate`` object per
          parameter, in column order.
    '''
    deprecation('Recommend using pymcmcstat.mcmcplot.plot_density_panel')
    nsimu, nparam = chains.shape  # number of rows, number of columns
    # ns1 x ns2 is the subplot grid used below in plt.subplot
    ns1, ns2, names, figsizeinches = setup_plot_features(
            nparam=nparam, names=names, figsizeinches=figsizeinches)
    f = plt.figure(dpi=100, figsize=(figsizeinches))  # initialize figure
    kdehandle = []
    for ii in range(nparam):
        # column ii reshaped to an (nsimu, 1) array for the KDE
        chain = chains[:, ii].reshape(nsimu, 1)
        chain_grid = make_x_grid(chain)  # evaluation points for the density curve
        kde = KDEMultivariate(chain, bw='normal_reference', var_type='c')
        plt.subplot(ns1, ns2, ii+1)
        # truthiness (not ``is True``) so numpy booleans and ints also work
        if hist_on:
            hist(chain, density=True)
        plt.plot(chain_grid, kde.pdf(chain_grid), 'k')
        # format subplot
        plt.xlabel(names[ii])
        plt.ylabel(str('$\\pi$({}$|M^{}$)'.format(names[ii], '{data}')))
        kdehandle.append(kde)
    # lay out once, after every panel and its labels exist
    plt.tight_layout(rect=[0, 0.03, 1, 0.95], h_pad=1.0)  # adjust spacing
    if return_kde:
        return f, kdehandle
    return f
# --------------------------------------------
@deprecated(version='1.9.0', reason='New function: "from pymcmcstat.mcmcplot import plot_histogram_panel"')
def plot_histogram_panel(chains, names=None, figsizeinches=None):
    """
    Plot histogram from each parameter's sampling history

    One subplot per parameter (column of ``chains``) holding a normalized
    histogram of that parameter's samples; y tick labels are hidden.

    Args:
        * **chains** (:class:`~numpy.ndarray`): Sampling chain for each parameter
          (rows are samples, columns are parameters)
        * **names** (:py:class:`list`): List of strings - name of each parameter
        * **figsizeinches** (:py:class:`list`): Specify figure size in inches [Width, Height]

    Returns:
        * Handle of the created figure.
    """
    nsimu, nparam = chains.shape  # number of rows, number of columns
    # ns1 x ns2 is the subplot grid used below in plt.subplot
    ns1, ns2, names, figsizeinches = setup_plot_features(
            nparam=nparam, names=names, figsizeinches=figsizeinches)
    f = plt.figure(dpi=100, figsize=(figsizeinches))  # initialize figure
    for ii in range(nparam):
        # column ii reshaped to an (nsimu, 1) array
        chain = chains[:, ii].reshape(nsimu, 1)
        # plot normalized histogram on subplot
        ax = plt.subplot(ns1, ns2, ii+1)
        hist(chain, density=True)
        # format subplot
        plt.xlabel(names[ii])
        ax.set_yticklabels([])
    # lay out once, after every panel and its labels exist
    plt.tight_layout(rect=[0, 0.03, 1, 0.95], h_pad=1.0)  # adjust spacing
    return f
# --------------------------------------------
@deprecated(version='1.9.0', reason='New function:"from pymcmcstat.mcmcplot import plot_chain_panel"')
def plot_chain_panel(chains, names=None, figsizeinches=None, skip=1, maxpoints=500):
    """
    Plot sampling chain for each parameter

    One subplot per parameter showing the (sub-sampled) chain values against
    their iteration index. Only panels with no panel directly below them get
    the "Iteration" x-axis label.

    Args:
        * **chains** (:class:`~numpy.ndarray`): Sampling chain for each parameter
          (rows are samples, columns are parameters)
        * **names** (:py:class:`list`): List of strings - name of each parameter
        * **figsizeinches** (:py:class:`list`): Specify figure size in inches [Width, Height]
        * **skip** (:py:class:`int`): Indicates step size to be used when plotting elements from the chain
        * **maxpoints** (:py:class:`int`): Max number of display points - keeps scatter plot from becoming overcrowded

    Returns:
        * Handle of the created figure.
    """
    nsimu, nparam = chains.shape  # number of rows, number of columns
    # ns1 x ns2 is the subplot grid used below in plt.subplot
    ns1, ns2, names, figsizeinches = setup_plot_features(
            nparam=nparam, names=names, figsizeinches=figsizeinches)
    inds = setup_subsample(skip, maxpoints, nsimu)  # row indices to display
    f = plt.figure(dpi=100, figsize=(figsizeinches))  # initialize figure
    for ii in range(nparam):
        chain = chains[inds, ii]
        # plot chain on subplot
        plt.subplot(ns1, ns2, ii+1)
        plt.plot(inds, chain, '.b')
        # format subplot
        plt.ylabel(str('{}'.format(names[ii])))
        # Subplot ii+1+ns2 is the slot directly below (row-major grid). Only
        # hide the x label when a panel actually occupies that slot, so the
        # last panel of each column keeps its label even if the bottom grid
        # row is only partially filled.
        if ii + 1 + ns2 <= nparam:
            plt.xlabel('')
        else:
            plt.xlabel('Iteration')
    # lay out once, after every panel and its labels exist
    plt.tight_layout(rect=[0, 0.03, 1, 0.95], h_pad=1.0)  # adjust spacing
    return f
# --------------------------------------------
@deprecated(version='1.9.0', reason='New function: "from pymcmcstat.mcmcplot import plot_pairwise_correlation_panel"')
def plot_pairwise_correlation_panel(chains, names=None, figsizeinches=None, skip=1, maxpoints=500):
    """
    Plot pairwise correlation for each parameter

    Draws the lower triangle of an ``(nparam-1) x (nparam-1)`` grid: the
    panel in grid row ``jj-2``, column ``ii-1`` plots parameter ``ii-1``
    (x) against parameter ``jj-1`` (y). With fewer than two parameters no
    panels are drawn and an empty figure is returned.

    Args:
        * **chains** (:class:`~numpy.ndarray`): Sampling chain for each parameter
          (rows are samples, columns are parameters)
        * **names** (:py:class:`list`): List of strings - name of each parameter
        * **figsizeinches** (:py:class:`list`): Specify figure size in inches [Width, Height]
        * **skip** (:py:class:`int`): Indicates step size to be used when plotting elements from the chain
        * **maxpoints** (:py:class:`int`): Maximum allowable number of points in plot.

    Returns:
        * Handle of the created figure.
    """
    nsimu, nparam = chains.shape  # number of rows, number of columns
    inds = setup_subsample(skip, maxpoints, nsimu)  # row indices to display
    names = generate_names(nparam=nparam, names=names)
    if figsizeinches is None:
        figsizeinches = [7, 5]
    f = plt.figure(dpi=100, figsize=(figsizeinches))  # initialize figure
    ngrid = nparam - 1  # grid is ngrid x ngrid
    for jj in range(2, nparam + 1):
        for ii in range(1, jj):
            chain1 = chains[inds, ii - 1]
            chain2 = chains[inds, jj - 1]
            # scatter of parameter ii-1 (x) vs parameter jj-1 (y)
            ax = plt.subplot(ngrid, ngrid, (jj - 2)*ngrid + ii)
            plt.plot(chain1, chain2, '.b')
            # only the bottom row keeps x tick labels
            if jj != nparam:
                ax.set_xticklabels([])
            # only the first column keeps y tick labels and gets a y label
            if ii != 1:
                ax.set_yticklabels([])
            else:
                plt.ylabel(str('{}'.format(names[jj - 1])))
            # the top panel of each column names that column's x parameter
            if ii == jj - 1:
                if nparam == 2:
                    # single panel: use an x label instead of a title
                    plt.xlabel(str('{}'.format(names[ii - 1])))
                else:
                    plt.title(str('{}'.format(names[ii - 1])))
    # adjust figure margins
    plt.tight_layout(rect=[0, 0.03, 1, 0.95], h_pad=1.0)  # adjust spacing
    return f
# --------------------------------------------
@deprecated(version='1.9.0', reason='New function: "from pymcmcstat.mcmcplot import plot_chain_metrics"')
def plot_chain_metrics(chain, name=None, figsizeinches=None):
    '''
    Plot chain metrics for individual chain

    - Trace of the chain versus iteration (line with point markers)
    - Histogram of chain

    Args:
        * **chain** (:class:`~numpy.ndarray`): Sampling chain for specific parameter
        * **name** (:py:class:`str`): Name of the parameter
        * **figsizeinches** (:py:class:`list`): Specify figure size in inches [Width, Height]

    Returns:
        * Handle of the created figure.
    '''
    name = generate_names(nparam=1, names=name)
    # NOTE(review): generate_names is written for lists of names; if it hands
    # back a one-element list/tuple, unwrap it so titles and labels show the
    # bare name instead of a list repr.
    if isinstance(name, (list, tuple)) and len(name) == 1:
        name = name[0]
    if figsizeinches is None:
        figsizeinches = [7, 5]
    f = plt.figure(dpi=100, figsize=(figsizeinches))  # initialize figure
    plt.suptitle('Chain metrics for {}'.format(name), fontsize=12)
    # top panel: chain value at each iteration
    plt.subplot(2, 1, 1)
    plt.plot(range(len(chain)), chain, marker='.')
    plt.xlabel('Iterations')
    plt.ylabel(str('{}-chain'.format(name)))
    # bottom panel: histogram of the chain values
    plt.subplot(2, 1, 2)
    hist(chain)
    plt.xlabel(name)
    plt.ylabel(str('Histogram of {}-chain'.format(name)))
    plt.tight_layout(rect=[0, 0.03, 1, 0.95], h_pad=1.0)  # adjust spacing
    return f
class Plot:
    '''
    Convenience container exposing the MCMC chain plotting routines.

    Each routine is the corresponding module-level function, attached to the
    instance under the same name.

    Attributes:
        - :meth:`~plot_density_panel`
        - :meth:`~plot_chain_panel`
        - :meth:`~plot_pairwise_correlation_panel`
        - :meth:`~plot_histogram_panel`
        - :meth:`~plot_chain_metrics`
    '''

    def __init__(self):
        # (attribute name, module-level function) pairs, bound in this order
        routine_table = (
            ('plot_density_panel', plot_density_panel),
            ('plot_chain_panel', plot_chain_panel),
            ('plot_pairwise_correlation_panel', plot_pairwise_correlation_panel),
            ('plot_histogram_panel', plot_histogram_panel),
            ('plot_chain_metrics', plot_chain_metrics),
        )
        for attr_name, routine in routine_table:
            setattr(self, attr_name, routine)