Source code for pymcmcstat.propagation

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed Nov  8 12:00:11 2017

@author: prmiles
"""

import numpy as np
import sys
from .utilities.progressbar import progress_bar
from .utilities.general import check_settings
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib import colors as mplcolor
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from scipy.interpolate import interp1d


def calculate_intervals(chain, results, data, model, s2chain=None,
                        nsample=500, waitbar=True, sstype=0):
    '''
    Calculate distribution of model response to form propagation intervals

    Samples values from chain, performs forward model evaluation, and
    tabulates credible and prediction intervals (if obs. error var. included).

    Args:
        * **chain** (:class:`~numpy.ndarray`): Parameter chains, expect
          shape=(nsimu, npar).
        * **results** (:py:class:`dict`): Results dictionary generated by
          pymcmcstat. Only `parind` and `theta` are read; `theta` is copied,
          so the dictionary is not modified.
        * **data** (:class:`~.DataStructure`): Data
        * **model**: User defined function, called as ``model(q, data)``.
          Note, if your model outputs multiple quantities of interest (QoI)
          at the same time in a multi-dimensional array, then make sure it
          is returned as a (N, p) array where N is the number of evaluation
          points and p is the number of QoI.

    Kwargs:
        * **s2chain** (:py:class:`float`, :class:`~numpy.ndarray`, or None):
          Observation error variance chain.
        * **nsample** (:py:class:`int`): No. of samples drawn from posteriors.
        * **waitbar** (:py:class:`bool`): Flag to display progress bar.
        * **sstype** (:py:class:`int`): Sum-of-squares type. Can be 0 (normal),
          1 (sqrt), or 2 (log).

    Returns:
        * :py:class:`dict` with two elements: 1) `credible` and
          2) `prediction` (None when `s2chain` is None). If the model returns
          multiple QoI, a :py:class:`list` with one such dictionary per QoI.
    '''
    parind = results['parind']
    # Copy so that evaluating the samples does not overwrite the caller's
    # results['theta'] with the last sampled parameter set.
    q = np.array(results['theta'])
    nsimu, _ = chain.shape
    s2chain = check_s2chain(s2chain, nsimu)
    iisample, nsample = define_sample_points(nsample, nsimu)
    if waitbar is True:
        wbar = progress_bar(iters=int(nsample))
    ci = []
    pi = []
    cis = None
    pis = None
    multiple = False
    ncol = 1
    for kk, isa in enumerate(iisample):
        # progress bar
        if waitbar is True:
            wbar.update(kk)
        # extract the sampled parameter set: row `isa` of the chain
        # (using `kk` would ignore the random sample indices entirely)
        q[parind] = chain[isa, :]
        # evaluate model
        y = model(q, data)
        # a 2-D output that is not a row/column vector holds several QoI
        if y.ndim == 2:
            nrow, ncol = y.shape
            if nrow != y.size and ncol != y.size:
                multiple = True
        if multiple is False:
            # store model prediction in credible intervals
            ci.append(y.reshape(y.size,))
            if s2chain is not None:
                # estimate prediction intervals
                obs = observation_sample(s2chain[isa], y, sstype)
                pi.append(obs.reshape(obs.size,))
        else:
            # Model output contains multiple QoI; expect ncol = No. of QoI
            if cis is None:
                cis = [[] for _ in range(ncol)]
                pis = [[] for _ in range(ncol)]
            for jj in range(ncol):
                # store model prediction in credible intervals
                cis[jj].append(y[:, jj])
                if s2chain is None:
                    continue
                # one variance per QoI if s2chain has a column per QoI,
                # otherwise the whole row for this sample is used
                if s2chain.ndim == 2 and s2chain.shape[1] == ncol:
                    s2 = s2chain[isa, jj]
                else:
                    s2 = s2chain[isa]
                obs = observation_sample(s2, y[:, jj], sstype)
                pis[jj].append(obs.reshape(obs.size,))
    if multiple is False:
        # Setup output
        credible = np.array(ci)
        prediction = None if s2chain is None else np.array(pi)
        return dict(credible=credible, prediction=prediction)
    # Setup output for multiple QoI
    out = []
    for jj in range(ncol):
        credible = np.array(cis[jj])
        prediction = None if s2chain is None else np.array(pis[jj])
        out.append(dict(credible=credible, prediction=prediction))
    return out
# --------------------------------------------
def plot_intervals(intervals, time, ydata=None, xdata=None, limits=[95],
                   adddata=None, addmodel=True, addlegend=True,
                   addcredible=True, addprediction=True,
                   data_display={}, model_display={}, interval_display={},
                   fig=None, figsize=None, legloc='upper left',
                   ciset=None, piset=None, return_settings=False):
    '''
    Plot propagation intervals in 2-D

    This routine takes the model distributions generated using the
    :func:`~calculate_intervals` method and then plots specific
    quantiles. The user can plot just the intervals, or also include the
    median model response and/or observations. Specific settings for
    credible intervals are controlled by defining the `ciset` dictionary.
    Likewise, for prediction intervals, settings are defined using `piset`.

    The setting options available for each interval are as follows:
        - `limits`: This should be a list of numbers between 0 and 100,
          e.g., `limits=[50, 90]` will result in 50% and 90% intervals.
        - `cmap`: The program is designed to "try" to choose colors that
          are visually distinct. The user can specify the colormap to
          choose from.
        - `colors`: The user can specify the color they would like for
          each interval in a list, e.g., ['r', 'g', 'b']. This list should
          have the same number of elements as `limits` or the code will
          revert back to its default behavior.

    Args:
        * **intervals** (:py:class:`dict`): Interval dictionary generated
          using :meth:`calculate_intervals` method.
        * **time** (:class:`~numpy.ndarray`): Independent variable, i.e.,
          x-axis of plot

    Kwargs:
        * **ydata** (:class:`~numpy.ndarray` or None): Observations, expect
          1-D array if defined.
        * **xdata** (:class:`~numpy.ndarray` or None): Independent values
          corresponding to observations. This is required if the
          observations do not align with your times of generating the
          model response.
        * **limits** (:py:class:`list`): Quantile limits that correspond to
          percentage size of desired intervals. Note, this is the default
          limits, but specific limits can be defined using the `ciset` and
          `piset` dictionaries. The list passed in is not modified.
        * **adddata** (:py:class:`bool`): Flag to include data
        * **addmodel** (:py:class:`bool`): Flag to include median model
          response
        * **addlegend** (:py:class:`bool`): Flag to include legend
        * **addcredible** (:py:class:`bool`): Flag to include credible
          intervals
        * **addprediction** (:py:class:`bool`): Flag to include prediction
          intervals. Ignored when `intervals['prediction']` is None.
        * **model_display** (:py:class:`dict`): Display settings for median
          model response
        * **data_display** (:py:class:`dict`): Display settings for data
        * **interval_display** (:py:class:`dict`): General display settings
          for intervals.
        * **fig**: Handle of previously created figure object
        * **figsize** (:py:class:`tuple`): (width, height) in inches
        * **legloc** (:py:class:`str`): Legend location - matplotlib help
          for details.
        * **ciset** (:py:class:`dict`): Settings for credible intervals
        * **piset** (:py:class:`dict`): Settings for prediction intervals
        * **return_settings** (:py:class:`bool`): Flag to return ciset and
          piset along with fig and ax.

    Returns:
        * (:py:class:`tuple`) with elements
            1) Figure handle
            2) Axes handle
            3) Dictionary with `ciset` and `piset` inside (only
               outputted if `return_settings=True`)
    '''
    # unpack dictionary
    credible = intervals['credible']
    prediction = intervals['prediction']
    # Private copy: `limits` is a shared mutable default and is passed on to
    # _check_limits, which (in this module) sorts its argument in place.
    limits = list(limits)
    # Check user-defined settings
    ciset = __setup_iset(ciset, default_iset=dict(
            limits=limits, cmap=None, colors=None))
    piset = __setup_iset(piset, default_iset=dict(
            limits=limits, cmap=None, colors=None))
    # Check limits
    ciset['limits'] = _check_limits(ciset['limits'], limits)
    piset['limits'] = _check_limits(piset['limits'], limits)
    # convert limits to ranges
    ciset['quantiles'] = _convert_limits(ciset['limits'])
    piset['quantiles'] = _convert_limits(piset['limits'])
    # setup display settings
    interval_display, model_display, data_display = setup_display_settings(
            interval_display, model_display, data_display)
    # Define colors
    ciset['colors'] = setup_interval_colors(ciset, inttype='ci')
    piset['colors'] = setup_interval_colors(piset, inttype='pi')
    # Define labels
    ciset['labels'] = _setup_labels(ciset['limits'], inttype='CI')
    piset['labels'] = _setup_labels(piset['limits'], inttype='PI')
    if fig is None:
        fig = plt.figure(figsize=figsize)
    ax = fig.gca()
    time = time.reshape(time.size,)
    # add prediction intervals; calculate_intervals stores prediction=None
    # when no observation error chain was given, so there is nothing to draw
    if addprediction is True and prediction is not None:
        for ii, quantile in enumerate(piset['quantiles']):
            pi = generate_quantiles(prediction, np.array(quantile))
            ax.fill_between(time, pi[0], pi[1],
                            facecolor=piset['colors'][ii],
                            label=piset['labels'][ii], **interval_display)
    # add credible intervals
    if addcredible is True:
        for ii, quantile in enumerate(ciset['quantiles']):
            ci = generate_quantiles(credible, np.array(quantile))
            ax.fill_between(time, ci[0], ci[1],
                            facecolor=ciset['colors'][ii],
                            label=ciset['labels'][ii], **interval_display)
    # add model (median model response)
    if addmodel is True:
        ci = generate_quantiles(credible, np.array(0.5))
        ax.plot(time, ci, **model_display)
    # add data to plot (shown by default whenever observations are given)
    if ydata is not None and adddata is None:
        adddata = True
    if adddata is True and ydata is not None:
        if xdata is None:
            ax.plot(time, ydata, **data_display)
        else:
            ax.plot(xdata, ydata, **data_display)
    # add legend
    if addlegend is True:
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles, labels, loc=legloc)
    if return_settings is True:
        return fig, ax, dict(ciset=ciset, piset=piset)
    return fig, ax
# --------------------------------------------
def plot_3d_intervals(intervals, time, ydata=None, xdata=None, limits=[95],
                      adddata=False, addlegend=True, addmodel=True,
                      figsize=None, model_display={}, data_display={},
                      interval_display={}, addcredible=True,
                      addprediction=True, fig=None, legloc='upper left',
                      ciset=None, piset=None, return_settings=False):
    '''
    Plot propagation intervals in 3-D

    This routine takes the model distributions generated using the
    :func:`~calculate_intervals` method and then plots specific
    quantiles. The user can plot just the intervals, or also include the
    median model response and/or observations. Specific settings for
    credible intervals are controlled by defining the `ciset` dictionary.
    Likewise, for prediction intervals, settings are defined using `piset`.

    The setting options available for each interval are as follows:
        - `limits`: This should be a list of numbers between 0 and 100,
          e.g., `limits=[50, 90]` will result in 50% and 90% intervals.
        - `cmap`: The program is designed to "try" to choose colors that
          are visually distinct. The user can specify the colormap to
          choose from.
        - `colors`: The user can specify the color they would like for
          each interval in a list, e.g., ['r', 'g', 'b']. This list should
          have the same number of elements as `limits` or the code will
          revert back to its default behavior.

    Args:
        * **intervals** (:py:class:`dict`): Interval dictionary generated
          using :meth:`calculate_intervals` method.
        * **time** (:class:`~numpy.ndarray`): Independent variable, i.e.,
          x- and y-axes of plot. Note, it must be a 2-D array with
          shape=(N, 2), where N is the number of evaluation points.

    Kwargs:
        * **ydata** (:class:`~numpy.ndarray` or None): Observations, expect
          1-D array if defined.
        * **xdata** (:class:`~numpy.ndarray` or None): Independent values
          corresponding to observations, shape=(M, 2). This is required if
          the observations do not align with your times of generating the
          model response.
        * **limits** (:py:class:`list`): Quantile limits that correspond to
          percentage size of desired intervals. Note, this is the default
          limits, but specific limits can be defined using the `ciset` and
          `piset` dictionaries. The list passed in is not modified.
        * **adddata** (:py:class:`bool`): Flag to include data
        * **addmodel** (:py:class:`bool`): Flag to include median model
          response
        * **addlegend** (:py:class:`bool`): Flag to include legend
        * **addcredible** (:py:class:`bool`): Flag to include credible
          intervals
        * **addprediction** (:py:class:`bool`): Flag to include prediction
          intervals. Ignored when `intervals['prediction']` is None.
        * **model_display** (:py:class:`dict`): Display settings for median
          model response
        * **data_display** (:py:class:`dict`): Display settings for data
        * **interval_display** (:py:class:`dict`): General display settings
          for intervals.
        * **fig**: Handle of previously created figure object (its current
          axes must be 3-D).
        * **figsize** (:py:class:`tuple`): (width, height) in inches
        * **legloc** (:py:class:`str`): Legend location - matplotlib help
          for details.
        * **ciset** (:py:class:`dict`): Settings for credible intervals
        * **piset** (:py:class:`dict`): Settings for prediction intervals
        * **return_settings** (:py:class:`bool`): Flag to return ciset and
          piset along with fig and ax.

    Returns:
        * (:py:class:`tuple`) with elements
            1) Figure handle
            2) Axes handle
            3) Dictionary with `ciset` and `piset` inside (only
               outputted if `return_settings=True`)
    '''
    # unpack dictionary
    credible = intervals['credible']
    prediction = intervals['prediction']
    # Private copy: `limits` is a shared mutable default and is passed on to
    # _check_limits, which (in this module) sorts its argument in place.
    limits = list(limits)
    # Check user-defined settings
    ciset = __setup_iset(ciset, default_iset=dict(
            limits=limits, cmap=None, colors=None))
    piset = __setup_iset(piset, default_iset=dict(
            limits=limits, cmap=None, colors=None))
    # Check limits
    ciset['limits'] = _check_limits(ciset['limits'], limits)
    piset['limits'] = _check_limits(piset['limits'], limits)
    # convert limits to ranges
    ciset['quantiles'] = _convert_limits(ciset['limits'])
    piset['quantiles'] = _convert_limits(piset['limits'])
    # setup display settings
    interval_display, model_display, data_display = setup_display_settings(
            interval_display, model_display, data_display)
    # Define colors
    ciset['colors'] = setup_interval_colors(ciset, inttype='ci')
    piset['colors'] = setup_interval_colors(piset, inttype='pi')
    # Define labels
    ciset['labels'] = _setup_labels(ciset['limits'], inttype='CI')
    piset['labels'] = _setup_labels(piset['limits'], inttype='PI')
    if fig is None:
        fig = plt.figure(figsize=figsize)
        # Explicitly create a 3-D axes; `Axes3D(fig)` followed by
        # `fig.gca()` does not attach a 3-D axes on newer matplotlib.
        ax = fig.add_subplot(111, projection='3d')
    else:
        ax = fig.gca()
    time1 = time[:, 0]
    time2 = time[:, 1]
    # add prediction intervals; calculate_intervals stores prediction=None
    # when no observation error chain was given, so there is nothing to draw
    if addprediction is True and prediction is not None:
        for ii, quantile in enumerate(piset['quantiles']):
            pi = generate_quantiles(prediction, np.array(quantile))
            ax.add_collection3d(_interval_polygon_3d(
                    time1, time2, pi[0], pi[1],
                    piset['colors'][ii], piset['labels'][ii]))
    # add credible intervals
    if addcredible is True:
        for ii, quantile in enumerate(ciset['quantiles']):
            ci = generate_quantiles(credible, np.array(quantile))
            ax.add_collection3d(_interval_polygon_3d(
                    time1, time2, ci[0], ci[1],
                    ciset['colors'][ii], ciset['labels'][ii]))
    # add model (median model response)
    if addmodel is True:
        ci = generate_quantiles(credible, np.array(0.5))
        ax.plot(time1, time2, ci, **model_display)
    # add data to plot
    if ydata is not None and adddata is None:
        adddata = True
    if adddata is True and ydata is not None:
        if xdata is None:
            ax.plot(time1, time2, ydata.reshape(time1.shape),
                    **data_display)
        else:
            # User provided xdata array for observation points; the
            # observations follow the xdata grid, not the model grid
            xobs1 = xdata[:, 0]
            ax.plot(xobs1, xdata[:, 1], ydata.reshape(xobs1.shape),
                    **data_display)
    # add legend
    if addlegend is True:
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles, labels, loc=legloc)
    if return_settings is True:
        return fig, ax, dict(ciset=ciset, piset=piset)
    return fig, ax


def _interval_polygon_3d(time1, time2, lower, upper, color, label):
    '''
    Build a closed 3-D polygon covering the band between `lower` and `upper`.

    The outline runs forward along (time1, time2, lower) and back along
    (time1, time2, upper), the 3-D analogue of `fill_between`.
    '''
    x = np.concatenate((time1, time1[::-1]))
    y = np.concatenate((time2, time2[::-1]))
    z = np.concatenate((lower, upper[::-1]))
    surf = Poly3DCollection([list(zip(x, y, z))], color=color, label=label)
    _copy_3d_colors_to_2d(surf)
    return surf


def _copy_3d_colors_to_2d(surf):
    '''
    Legend-compatibility fix: copy the 3-D face/edge colors of `surf` into
    the 2-D color attributes so the polygon can be listed in a legend.

    NOTE(review): the 3-D attribute names differ between matplotlib
    releases (`_facecolors3d` vs `_facecolor3d`); whichever exists is used
    and the copy is skipped if neither does — confirm against the
    matplotlib version in use.
    '''
    for name2d, names3d in (('_facecolors2d', ('_facecolors3d', '_facecolor3d')),
                            ('_edgecolors2d', ('_edgecolors3d', '_edgecolor3d'))):
        for name3d in names3d:
            if hasattr(surf, name3d):
                setattr(surf, name2d, getattr(surf, name3d))
                break
def check_s2chain(s2chain, nsimu):
    '''
    Check size of s2chain

    Args:
        * **s2chain** (:py:class:`float`, :class:`~numpy.ndarray`, or `None`):
          Observation error variance chain or value. Any scalar (Python
          int/float or NumPy scalar) and any array-like are accepted.
        * **nsimu** (:py:class:`int`): No. of elements in chain

    Returns:
        * **s2chain** (:class:`~numpy.ndarray` or `None`): `None` if no
          variance was given; shape=(nsimu,) for a scalar; otherwise the
          input, broadcast to shape=(nsimu, nqoi) when it does not already
          hold one row per simulation.
    '''
    if s2chain is None:
        return None
    if np.ndim(s2chain) == 0:
        # single variance value -> repeat it for every simulation
        # (ints and 0-d arrays previously failed on `.ndim`/`.size`)
        return np.ones((nsimu,)) * float(s2chain)
    s2chain = np.asarray(s2chain)
    if s2chain.ndim == 2:
        if s2chain.shape[0] != nsimu:
            s2chain = s2chain * np.ones((nsimu, s2chain.size))
    elif s2chain.size != nsimu:
        # scalars provided for multiple QoI
        s2chain = s2chain * np.ones((nsimu, s2chain.size))
    return s2chain
# --------------------------------------------
def observation_sample(s2, y, sstype):
    '''
    Calculate model response with observation errors.

    Args:
        * **s2** (:class:`~numpy.ndarray`): Observation error(s).
        * **y** (:class:`~numpy.ndarray`): Model responses.
        * **sstype** (:py:class:`int`): Flag to specify sstype:
          0 (additive normal error), 1 (error added on the sqrt scale),
          2 (multiplicative log-normal error).

    Returns:
        * **opred** (:class:`~numpy.ndarray`): Model responses with
          observation errors.

    Raises:
        * :py:class:`ValueError`: if `sstype` is not 0, 1, or 2.
    '''
    noise = np.random.standard_normal(y.shape) * np.sqrt(s2)
    if sstype == 0:
        opred = y + noise
    elif sstype == 1:  # sqrt
        opred = (np.sqrt(y) + noise)**2
    elif sstype == 2:  # log
        opred = y*np.exp(noise)
    else:
        # Library code must not terminate the interpreter (sys.exit);
        # report the bad argument to the caller instead.
        raise ValueError('Unknown sstype')
    return opred
# --------------------------------------------
def define_sample_points(nsample, nsimu):
    '''
    Define indices to sample from posteriors.

    Args:
        * **nsample** (:py:class:`int`): Number of samples to draw from
          posterior. Non-integer values are truncated to int.
        * **nsimu** (:py:class:`int`): Number of MCMC simulations.

    Returns:
        * **iisample** (:class:`~numpy.ndarray`): Array of indices in
          posterior set.
        * **nsample** (:py:class:`int`): Number of samples to draw from
          posterior.
    '''
    nsample = int(nsample)
    if nsample >= nsimu:
        # sample all points from chain
        return np.arange(nsimu), nsimu
    # randomly sample (with replacement) from chain
    iisample = np.ceil(np.random.rand(nsample)*nsimu) - 1
    return iisample.astype(int), nsample
# --------------------------------------------
def generate_quantiles(x, p=np.array([0.25, 0.5, 0.75])):
    '''
    Calculate empirical quantiles.

    Quantiles are taken along the first axis (over samples) by linear
    interpolation between the sorted observations, i.e. the value at
    fractional position `(n - 1)*p` of the sorted data.

    Args:
        * **x** (:class:`~numpy.ndarray`): Observations from which to
          generate quantile.
        * **p** (:class:`~numpy.ndarray`): Quantile limits in [0, 1].

    Returns:
        * (:class:`~numpy.ndarray`): Interpolated quantiles, with shape
          `p.shape + x.shape[1:]`.
    '''
    # np.quantile's default linear method is exactly the sorted-data
    # interpolation previously done with interp1d, but it also works when
    # x holds a single sample (interp1d needs at least two points).
    return np.quantile(np.asarray(x), p, axis=0)
def setup_display_settings(interval_display, model_display, data_display):
    '''
    Merge user display settings for intervals, model and data with the
    module defaults.

    Each user dictionary is combined with its default dictionary via
    `check_settings`, in the order interval, model, data.

    Args:
        * **interval_display** (:py:class:`dict`): User defined settings for
          interval display.
        * **model_display** (:py:class:`dict`): User defined settings for
          model display.
        * **data_display** (:py:class:`dict`): User defined settings for
          data display.

    Returns:
        * **interval_display** (:py:class:`dict`): Settings for interval
          display.
        * **model_display** (:py:class:`dict`): Settings for model display.
        * **data_display** (:py:class:`dict`): Settings for data display.
    '''
    interval_defaults = dict(linestyle=':', linewidth=1, alpha=1.0,
                             edgecolor='k')
    model_defaults = dict(linestyle='-', color='k', marker='', linewidth=2,
                          markersize=5, label='Model')
    data_defaults = dict(linestyle='', color='b', marker='.', linewidth=1,
                         markersize=5, label='Data')
    pairs = ((interval_defaults, interval_display),
             (model_defaults, model_display),
             (data_defaults, data_display))
    return tuple(check_settings(default, user) for default, user in pairs)
def setup_interval_colors(iset, inttype='CI'):
    '''
    Setup colors for empirical intervals

    This routine attempts to distribute the color of the UQ intervals
    based on a normalize color map. Or, it will assign user-defined
    colors; however, this only happens if the correct number of colors
    are specified.

    Args:
        * **iset** (:py:class:`dict`): This dictionary should contain the
          following keys - `limits`, `cmap`, and `colors`.

    Kwargs:
        * **inttype** (:py:class:`str`): Type of uncertainty interval

    Returns:
        * **ic** (:py:class:`list`): List containing color for each interval
    '''
    limits, cmap, colors = iset['limits'], iset['cmap'], iset['colors']
    norm = __setup_cmap_norm(limits)
    cmap = __setup_default_cmap(cmap, inttype)
    if colors is not None:
        if len(colors) == len(limits):
            # correct number of colors defined by user
            return list(colors)
        # User defined the wrong number of colors
        print('Note, user-defined colors were ignored. Using color map. '
              + 'Expected a list of length {}, but received {}'.format(
                      len(limits), len(colors)))
    # one color-map color per interval; the loop variable no longer
    # shadows the `limits` list being iterated
    return [cmap(norm(limit)) for limit in limits]
# --------------------------------------------
def _setup_labels(limits, inttype='CI'):
    '''
    Setup labels for prediction/credible intervals, e.g. '95% CI'.
    '''
    return ['{}% {}'.format(limit, inttype) for limit in limits]


def _check_limits(limits, default_limits):
    '''
    Return the interval limits sorted in descending order.

    Falls back to `default_limits` when `limits` is None. A new list is
    returned; neither argument is modified (the previous in-place sort
    reordered the caller's list and shared mutable defaults).
    '''
    if limits is None:
        limits = default_limits
    return sorted(limits, reverse=True)


def _convert_limits(limits):
    '''
    Convert percentage limits into [lower, upper] quantile pairs centred
    on the median, e.g. 95 -> [0.025, 0.975].
    '''
    rng = []
    for limit in limits:
        # float division so integer limits also work under Python 2
        frac = limit/100.0
        rng.append([0.5 - frac/2, 0.5 + frac/2])
    return rng


def __setup_iset(iset, default_iset):
    '''
    Setup interval settings by comparing user input to default
    '''
    if iset is None:
        iset = {}
    return check_settings(default_iset, iset)


def __setup_cmap_norm(limits):
    '''
    Color-map normalization for the interval limits: fixed 0-100 range for
    a single limit, otherwise spanning min(limits)..max(limits).
    '''
    if len(limits) == 1:
        return mplcolor.Normalize(vmin=0, vmax=100)
    return mplcolor.Normalize(vmin=min(limits), vmax=max(limits))


def __setup_default_cmap(cmap, inttype):
    '''
    Return `cmap` if given; otherwise `cm.autumn` for credible ('CI')
    intervals and `cm.winter` for any other interval type.
    '''
    if cmap is None:
        if inttype.upper() == 'CI':
            cmap = cm.autumn
        else:
            cmap = cm.winter
    return cmap