#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 31 12:54:16 2018
@author: prmiles
"""
# import required packages
from __future__ import division
import math
import matplotlib.pyplot as plt
from pylab import hist
from .utilities import generate_names, setup_plot_features, make_x_grid
import warnings
try:
from statsmodels.nonparametric.kernel_density import KDEMultivariate
except ImportError as e:
warnings.warn(str("Exception raised importing statsmodels.nonparametric.kernel_density \
- plot_density_panel will not work. {}".format(e)))
# --------------------------------------------
def plot_density_panel(chains, names=None, hist_on=False, figsizeinches=None):
    '''
    Plot marginal posterior densities

    Draws one subplot per column of `chains`. Each subplot shows a kernel
    density estimate of that column and, when requested, a normalized
    histogram of the same samples underneath the density curve.

    Args:
        * **chains** (:class:`~numpy.ndarray`): Sampling chain for each parameter \
        (2-D; rows are samples, columns are parameters)
        * **names** (:py:class:`list`): List of strings - name of each parameter
        * **hist_on** (:py:class:`bool`): Flag to include histogram on density plot. \
        Any truthy value (``True``, ``1``, ``numpy.bool_(True)``) enables it.
        * **figsizeinches** (:py:class:`list`): Specify figure size in inches [Width, Height]

    Returns:
        * Figure handle returned by :func:`matplotlib.pyplot.figure`

    Note:
        `KDEMultivariate` is imported inside a ``try`` block at module load; if
        that import failed (a warning was issued), the name is undefined and
        this function cannot run.
    '''
    nsimu, nparam = chains.shape  # number of rows, number of columns
    # Subplot grid dimensions, parameter labels and figure size come from the
    # shared helper so all panel plots are laid out the same way.
    ns1, ns2, names, figsizeinches = setup_plot_features(
        nparam=nparam, names=names, figsizeinches=figsizeinches)
    f = plt.figure(dpi=100, figsize=(figsizeinches))  # initialize figure
    for ii in range(nparam):
        # Column vector for the current parameter; the KDE is fit to this 2-D form.
        chain = chains[:, ii].reshape(nsimu, 1)
        # Points at which the density estimate is evaluated
        chain_grid = make_x_grid(chain)
        # Compute kernel density estimate (continuous variable, normal-reference bandwidth)
        kde = KDEMultivariate(chain, bw='normal_reference', var_type='c')
        # plot density on subplot
        plt.subplot(ns1, ns2, ii+1)
        # Truthiness test rather than `is True`, so numpy booleans and 1 are honoured.
        if hist_on:
            hist(chain, density=True)
        plt.plot(chain_grid, kde.pdf(chain_grid), 'k')
        # format figure
        plt.xlabel(names[ii])
        plt.ylabel('$\\pi$({}$|M^{}$)'.format(names[ii], '{data}'))
        plt.tight_layout(rect=[0, 0.03, 1, 0.95], h_pad=1.0)  # adjust spacing
    return f
# --------------------------------------------
def plot_histogram_panel(chains, names=None, figsizeinches=None):
    """
    Plot histogram from each parameter's sampling history

    One subplot is created per column of `chains`, containing a normalized
    histogram of that column. Y tick labels are blanked on every subplot.

    Args:
        * **chains** (:class:`~numpy.ndarray`): Sampling chain for each parameter \
        (2-D; rows are samples, columns are parameters)
        * **names** (:py:class:`list`): List of strings - name of each parameter
        * **figsizeinches** (:py:class:`list`): Specify figure size in inches [Width, Height]

    Returns:
        * Figure handle returned by :func:`matplotlib.pyplot.figure`
    """
    n_samples, n_params = chains.shape
    # Grid shape, labels and figure size are resolved by the shared helper.
    n_rows, n_cols, names, figsizeinches = setup_plot_features(
        nparam=n_params, names=names, figsizeinches=figsizeinches)
    fig = plt.figure(dpi=100, figsize=(figsizeinches))
    for col in range(n_params):
        # Samples of the current parameter as a column vector
        samples = chains[:, col].reshape(n_samples, 1)
        axes = plt.subplot(n_rows, n_cols, col + 1)
        hist(samples, density=True)
        # Label the x axis with the parameter name; hide y tick labels
        plt.xlabel(names[col])
        axes.set_yticklabels([])
        plt.tight_layout(rect=[0, 0.03, 1, 0.95], h_pad=1.0)  # adjust spacing
    return fig
# --------------------------------------------
def plot_chain_panel(chains, names=None, figsizeinches=None, maxpoints=500):
    """
    Plot sampling chain for each parameter

    Each column of `chains` is drawn in its own subplot as blue dots against
    the iteration index. When the chain has more than `maxpoints` rows, only
    every ``floor(nsimu/maxpoints)``-th sample is displayed.

    Args:
        * **chains** (:class:`~numpy.ndarray`): Sampling chain for each parameter \
        (2-D; rows are samples, columns are parameters)
        * **names** (:py:class:`list`): List of strings - name of each parameter
        * **figsizeinches** (:py:class:`list`): Specify figure size in inches [Width, Height]
        * **maxpoints** (:py:class:`int`): Max number of display points - keeps \
        scatter plot from becoming overcrowded

    Returns:
        * Figure handle returned by :func:`matplotlib.pyplot.figure`
    """
    n_samples, n_params = chains.shape
    n_rows, n_cols, names, figsizeinches = setup_plot_features(
        nparam=n_params, names=names, figsizeinches=figsizeinches)
    # Thinning stride: 1 unless the chain is longer than the display budget
    stride = int(math.floor(n_samples/maxpoints)) if n_samples > maxpoints else 1
    shown = range(0, n_samples, stride)  # iteration indices actually plotted
    # Subplot positions up to this value are above the bottom row of the grid
    last_upper_slot = n_rows*n_cols - n_cols
    fig = plt.figure(dpi=100, figsize=(figsizeinches))
    for col in range(n_params):
        slot = col + 1
        plt.subplot(n_rows, n_cols, slot)
        plt.plot(shown, chains[shown, col], '.b')
        # Only the bottom row keeps the 'Iteration' axis label
        plt.xlabel('' if slot <= last_upper_slot else 'Iteration')
        plt.ylabel('{}'.format(names[col]))
        plt.tight_layout(rect=[0, 0.03, 1, 0.95], h_pad=1.0)  # adjust spacing
    return fig
# --------------------------------------------
def plot_pairwise_correlation_panel(chains, names=None, figsizeinches=None, skip=1):
    """
    Plot pairwise correlation for each parameter

    Builds a lower-triangular grid of ``(nparam-1) x (nparam-1)`` subplots.
    The cell in grid row ``jj-2`` and grid column ``ii-1`` is a scatter plot
    of parameter ``ii`` (x axis) against parameter ``jj`` (y axis), for every
    pair with ``ii < jj``.

    Args:
        * **chains** (:class:`~numpy.ndarray`): Sampling chain for each parameter \
        (2-D; rows are samples, columns are parameters)
        * **names** (:py:class:`list`): List of strings - name of each parameter
        * **figsizeinches** (:py:class:`list`): Specify figure size in inches \
        [Width, Height]. Defaults to ``[7, 5]``.
        * **skip** (:py:class:`int`): Indicates step size to be used when \
        plotting elements from the chain

    Returns:
        * Figure handle returned by :func:`matplotlib.pyplot.figure`
    """
    nsimu, nparam = chains.shape  # number of rows, number of columns
    names = generate_names(nparam=nparam, names=names)
    if figsizeinches is None:
        figsizeinches = [7, 5]
    # Thin the chain once with a strided slice. The thinned columns are used
    # as-is: reshaping them to (nsimu, 1) fails as soon as skip > 1, because
    # the thinned length is no longer nsimu.
    thinned = chains[::skip, :]
    ngrid = nparam - 1  # the panel is an ngrid-by-ngrid array of subplots
    f = plt.figure(dpi=100, figsize=(figsizeinches))  # initialize figure
    for jj in range(2, nparam+1):
        for ii in range(1, jj):
            # plot parameter ii against parameter jj on its subplot
            ax = plt.subplot(ngrid, ngrid, (jj-2)*ngrid + ii)
            plt.plot(thinned[:, ii-1], thinned[:, jj-1], '.b')
            # format figure
            if jj != nparam:  # keep x tick labels on the bottom row only
                ax.set_xticklabels([])
            if ii != 1:  # keep y tick labels on the first column only
                ax.set_yticklabels([])
            if ii == 1:  # first column carries the y-axis parameter name
                plt.ylabel('{}'.format(names[jj-1]))
            if ii == jj - 1:  # diagonal cell names the x-axis parameter
                if nparam == 2:  # single subplot: use an x label
                    plt.xlabel('{}'.format(names[ii-1]))
                else:  # otherwise use the subplot title
                    plt.title('{}'.format(names[ii-1]))
    # adjust figure margins
    plt.tight_layout(rect=[0, 0.03, 1, 0.95], h_pad=1.0)  # adjust spacing
    return f
# --------------------------------------------
def plot_chain_metrics(chain, name=None, figsizeinches=None):
    '''
    Plot chain metrics for individual chain

    The figure has two stacked subplots:

    - Top: the chain values plotted against their iteration index
    - Bottom: histogram of the chain values

    Args:
        * **chain** (:class:`~numpy.ndarray`): Sampling chain for specific parameter
        * **name** (:py:class:`str`): Name of the parameter
        * **figsizeinches** (:py:class:`list`): Specify figure size in inches \
        [Width, Height]. Defaults to ``[7, 5]``.

    Returns:
        * Figure handle returned by :func:`matplotlib.pyplot.figure`
    '''
    # NOTE(review): the value returned by generate_names is interpolated
    # directly into the title and axis labels; if it is a list rather than a
    # string, the labels will show the list repr - confirm against utilities.
    name = generate_names(nparam=1, names=name)
    size = [7, 5] if figsizeinches is None else figsizeinches
    fig = plt.figure(dpi=100, figsize=(size))
    plt.suptitle('Chain metrics for {}'.format(name), fontsize='12')

    # Top panel: trace of the chain
    plt.subplot(2, 1, 1)
    iterations = range(0, len(chain))
    plt.plot(iterations, chain, marker='.')
    plt.xlabel('Iterations')
    plt.ylabel('{}-chain'.format(name))

    # Bottom panel: histogram of the chain
    plt.subplot(2, 1, 2)
    hist(chain)
    plt.xlabel(name)
    plt.ylabel('Histogram of {}-chain'.format(name))

    plt.tight_layout(rect=[0, 0.03, 1, 0.95], h_pad=1.0)  # adjust spacing
    return fig
class Plot:
    '''
    Plotting routines for analyzing sampling chains from MCMC process.

    Each instance exposes the module-level plotting functions as attributes
    of the same name, so they can be called as ``Plot().plot_chain_panel(...)``.
    The functions are stored as plain attributes (not bound methods), so no
    ``self`` argument is passed to them.

    Attributes:
        - :meth:`~plot_density_panel`
        - :meth:`~plot_chain_panel`
        - :meth:`~plot_pairwise_correlation_panel`
        - :meth:`~plot_histogram_panel`
        - :meth:`~plot_chain_metrics`
    '''
    def __init__(self):
        # Attach every plotting routine to the instance under its own name.
        routines = (
            plot_density_panel,
            plot_chain_panel,
            plot_pairwise_correlation_panel,
            plot_histogram_panel,
            plot_chain_metrics,
        )
        for routine in routines:
            setattr(self, routine.__name__, routine)