#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed Nov 8 12:00:11 2017
@author: prmiles
"""
import numpy as np
import sys
from .utilities.progressbar import progress_bar
from .utilities.general import check_settings
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib import colors as mplcolor
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from scipy.interpolate import interp1d
def calculate_intervals(chain, results, data, model, s2chain=None,
                        nsample=500, waitbar=True, sstype=0):
    '''
    Calculate distribution of model response to form propagation intervals

    Samples values from chain, performs forward model evaluation, and
    tabulates credible and prediction intervals (if obs. error var. included).

    Args:
        * **chain** (:class:`~numpy.ndarray`): Parameter chains, expect
          shape=(nsimu, npar).
        * **results** (:py:class:`dict`): Results dictionary generated by
          pymcmcstat. Only `parind` and `theta` are read; `theta` is copied
          and is not modified.
        * **data** (:class:`~.DataStructure`): Data
        * **model**: User defined function called as `model(q, data)`.
          Note, if your model outputs multiple quantities of interest (QoI)
          at the same time in a multi-dimensional array, then make sure it
          is returned as a (N, p) array where N is the number of evaluation
          points and p is the number of QoI.

    Kwargs:
        * **s2chain** (:py:class:`float`, :class:`~numpy.ndarray`, or None):
          Observation error variance chain.
        * **nsample** (:py:class:`int`): No. of samples drawn from posteriors.
        * **waitbar** (:py:class:`bool`): Flag to display progress bar.
        * **sstype** (:py:class:`int`): Sum-of-squares type. Can be 0 (normal),
          1 (sqrt), or 2 (log).

    Returns:
        * :py:class:`dict` with two elements: 1) `credible` and 2) `prediction`
          (`prediction` is None when `s2chain` is None). If the model returns
          multiple QoI, a :py:class:`list` with one such dictionary per QoI
          is returned instead.
    '''
    parind = results['parind']
    # Work on a copy so the caller's results['theta'] is not overwritten
    # with the last sampled parameter set.
    q = np.array(results['theta'], copy=True)
    nsimu, _ = chain.shape  # unpacking also enforces a 2-D chain
    s2chain = check_s2chain(s2chain, nsimu)
    iisample, nsample = define_sample_points(nsample, nsimu)
    if waitbar is True:
        wbarstatus = progress_bar(iters=int(nsample))
    ci = []
    pi = []
    cis = []
    pis = []
    multiple = False
    ncol = 1
    for kk, isa in enumerate(iisample):
        # progress bar
        if waitbar is True:
            wbarstatus.update(kk)
        # extract the sampled chain row: `isa` is the drawn index, `kk` only
        # counts iterations
        q[parind] = chain[isa, :]
        # evaluate model
        y = model(q, data)
        # check model output: a genuinely 2-D (N, p) result means multiple QoI
        if y.ndim == 2:
            nrow, ncol = y.shape
            if nrow != y.size and ncol != y.size:
                multiple = True
        if multiple is False:
            # store model prediction in credible intervals
            ci.append(y.reshape(y.size,))
            if s2chain is not None:
                # estimate prediction intervals
                obs = observation_sample(s2chain[isa], y, sstype)
                pi.append(obs.reshape(obs.size,))
        else:
            # Model output contains multiple QoI; expect ncol = No. of QoI.
            # Initialize lazily so detection on any iteration is handled.
            if not cis:
                cis = [[] for _ in range(ncol)]
                pis = [[] for _ in range(ncol)]
            for jj in range(ncol):
                # store model prediction in credible intervals
                cis[jj].append(y[:, jj])
                if s2chain is None:
                    continue
                # one variance per QoI when s2chain has ncol columns,
                # otherwise the whole row is used
                if s2chain.ndim == 2 and s2chain.shape[1] == ncol:
                    s2 = s2chain[isa, jj]
                else:
                    s2 = s2chain[isa]
                obs = observation_sample(s2, y[:, jj], sstype)
                pis[jj].append(obs.reshape(obs.size,))
    if multiple is False:
        # Setup output
        prediction = None if s2chain is None else np.array(pi)
        return dict(credible=np.array(ci),
                    prediction=prediction)
    # Setup output for multiple QoI
    out = []
    for jj in range(ncol):
        prediction = None if s2chain is None else np.array(pis[jj])
        out.append(dict(credible=np.array(cis[jj]),
                        prediction=prediction))
    return out
# --------------------------------------------
def plot_intervals(intervals, time, ydata=None, xdata=None,
                   limits=[95],
                   adddata=None, addmodel=True, addlegend=True,
                   addcredible=True, addprediction=True,
                   data_display={}, model_display={}, interval_display={},
                   fig=None, figsize=None, legloc='upper left',
                   ciset=None, piset=None,
                   return_settings=False):
    '''
    Plot propagation intervals in 2-D

    This routine takes the model distributions generated using the
    :func:`~calculate_intervals` method and then plots specific
    quantiles. The user can plot just the intervals, or also include the
    median model response and/or observations. Specific settings for
    credible intervals are controlled by defining the `ciset` dictionary.
    Likewise, for prediction intervals, settings are defined using `piset`.

    The setting options available for each interval are as follows:
        - `limits`: This should be a list of numbers between 0 and 100, e.g.,
          `limits=[50, 90]` will result in 50% and 90% intervals.
        - `cmap`: The program is designed to "try" to choose colors that
          are visually distinct. The user can specify the colormap to choose
          from.
        - `colors`: The user can specify the color they would like for each
          interval in a list, e.g., ['r', 'g', 'b']. This list should have
          the same number of elements as `limits` or the code will revert
          back to its default behavior.

    Prediction intervals are skipped when ``intervals['prediction']`` is
    None (e.g., :func:`~calculate_intervals` was called without `s2chain`),
    even if `addprediction=True`. The `limits` list and display
    dictionaries passed in are copied rather than modified.

    Args:
        * **intervals** (:py:class:`dict`): Interval dictionary generated
          using :meth:`calculate_intervals` method.
        * **time** (:class:`~numpy.ndarray`): Independent variable, i.e.,
          x-axis of plot

    Kwargs:
        * **ydata** (:class:`~numpy.ndarray` or None): Observations, expect
          1-D array if defined.
        * **xdata** (:class:`~numpy.ndarray` or None): Independent values
          corresponding to observations. This is required if the observations
          do not align with your times of generating the model response.
        * **limits** (:py:class:`list`): Quantile limits that correspond to
          percentage size of desired intervals. Note, this is the default
          limits, but specific limits can be defined using the `ciset` and
          `piset` dictionaries.
        * **adddata** (:py:class:`bool`): Flag to include data. When left
          as None it is switched on whenever `ydata` is given.
        * **addmodel** (:py:class:`bool`): Flag to include median model
          response
        * **addlegend** (:py:class:`bool`): Flag to include legend
        * **addcredible** (:py:class:`bool`): Flag to include credible
          intervals
        * **addprediction** (:py:class:`bool`): Flag to include prediction
          intervals
        * **model_display** (:py:class:`dict`): Display settings for median
          model response
        * **data_display** (:py:class:`dict`): Display settings for data
        * **interval_display** (:py:class:`dict`): General display settings
          for intervals.
        * **fig**: Handle of previously created figure object
        * **figsize** (:py:class:`tuple`): (width, height) in inches
        * **legloc** (:py:class:`str`): Legend location - matplotlib help for
          details.
        * **ciset** (:py:class:`dict`): Settings for credible intervals
        * **piset** (:py:class:`dict`): Settings for prediction intervals
        * **return_settings** (:py:class:`bool`): Flag to return ciset and
          piset along with fig and ax.

    Returns:
        * (:py:class:`tuple`) with elements
            1) Figure handle
            2) Axes handle
            3) Dictionary with `ciset` and `piset` inside (only
               outputted if `return_settings=True`)
    '''
    # unpack dictionary
    credible = intervals['credible']
    prediction = intervals['prediction']
    # Copy mutable inputs so neither the caller's objects nor the shared
    # default arguments are modified by the in-place processing below.
    limits = list(limits)
    data_display = {} if data_display is None else dict(data_display)
    model_display = {} if model_display is None else dict(model_display)
    interval_display = ({} if interval_display is None
                        else dict(interval_display))
    # Check user-defined settings
    ciset = __setup_iset(ciset,
                         default_iset=dict(
                                 limits=limits,
                                 cmap=None,
                                 colors=None))
    piset = __setup_iset(piset,
                         default_iset=dict(
                                 limits=limits,
                                 cmap=None,
                                 colors=None))
    # Check limits (copy user lists so sorting cannot reorder them in place)
    ciset['limits'] = _check_limits(
        None if ciset['limits'] is None else list(ciset['limits']), limits)
    piset['limits'] = _check_limits(
        None if piset['limits'] is None else list(piset['limits']), limits)
    # convert limits to ranges
    ciset['quantiles'] = _convert_limits(ciset['limits'])
    piset['quantiles'] = _convert_limits(piset['limits'])
    # setup display settings
    interval_display, model_display, data_display = setup_display_settings(
            interval_display, model_display, data_display)
    # Define colors
    ciset['colors'] = setup_interval_colors(ciset, inttype='ci')
    piset['colors'] = setup_interval_colors(piset, inttype='pi')
    # Define labels
    ciset['labels'] = _setup_labels(ciset['limits'], inttype='CI')
    piset['labels'] = _setup_labels(piset['limits'], inttype='PI')
    if fig is None:
        fig = plt.figure(figsize=figsize)
    ax = fig.gca()
    time = time.reshape(time.size,)
    # add prediction intervals (only available if observation error
    # variance was supplied when the intervals were calculated)
    if addprediction is True and prediction is not None:
        for ii, quantile in enumerate(piset['quantiles']):
            pi = generate_quantiles(prediction, np.array(quantile))
            ax.fill_between(time, pi[0], pi[1], facecolor=piset['colors'][ii],
                            label=piset['labels'][ii], **interval_display)
    # add credible intervals
    if addcredible is True:
        for ii, quantile in enumerate(ciset['quantiles']):
            ci = generate_quantiles(credible, np.array(quantile))
            ax.fill_between(time, ci[0], ci[1], facecolor=ciset['colors'][ii],
                            label=ciset['labels'][ii], **interval_display)
    # add model (median model response)
    if addmodel is True:
        ci = generate_quantiles(credible, np.array(0.5))
        ax.plot(time, ci, **model_display)
    # add data to plot
    if ydata is not None and adddata is None:
        adddata = True
    if adddata is True and ydata is not None:
        if xdata is None:
            ax.plot(time, ydata, **data_display)
        else:
            ax.plot(xdata, ydata, **data_display)
    # add legend
    if addlegend is True:
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles, labels, loc=legloc)
    if return_settings is True:
        return fig, ax, dict(ciset=ciset, piset=piset)
    return fig, ax
# --------------------------------------------
def plot_3d_intervals(intervals, time, ydata=None, xdata=None,
                      limits=[95],
                      adddata=False, addlegend=True,
                      addmodel=True, figsize=None, model_display={},
                      data_display={}, interval_display={},
                      addcredible=True, addprediction=True,
                      fig=None, legloc='upper left',
                      ciset=None, piset=None,
                      return_settings=False):
    '''
    Plot propagation intervals in 3-D

    This routine takes the model distributions generated using the
    :func:`~calculate_intervals` method and then plots specific
    quantiles. The user can plot just the intervals, or also include the
    median model response and/or observations. Specific settings for
    credible intervals are controlled by defining the `ciset` dictionary.
    Likewise, for prediction intervals, settings are defined using `piset`.

    The setting options available for each interval are as follows:
        - `limits`: This should be a list of numbers between 0 and 100, e.g.,
          `limits=[50, 90]` will result in 50% and 90% intervals.
        - `cmap`: The program is designed to "try" to choose colors that
          are visually distinct. The user can specify the colormap to choose
          from.
        - `colors`: The user can specify the color they would like for each
          interval in a list, e.g., ['r', 'g', 'b']. This list should have
          the same number of elements as `limits` or the code will revert
          back to its default behavior.

    Prediction intervals are skipped when ``intervals['prediction']`` is
    None, even if `addprediction=True`. Data are only plotted when `ydata`
    is given. The `limits` list and display dictionaries passed in are
    copied rather than modified.

    Args:
        * **intervals** (:py:class:`dict`): Interval dictionary generated
          using :meth:`calculate_intervals` method.
        * **time** (:class:`~numpy.ndarray`): Independent variable, i.e.,
          x- and y-axes of plot. Note, it must be a 2-D array with
          shape=(N, 2), where N is the number of evaluation points.

    Kwargs:
        * **ydata** (:class:`~numpy.ndarray` or None): Observations, expect
          1-D array if defined. It is reshaped to match the number of points
          in `xdata` (if given) or in `time`.
        * **xdata** (:class:`~numpy.ndarray` or None): Independent values
          corresponding to observations, shape=(M, 2). This is required if
          the observations do not align with your times of generating the
          model response.
        * **limits** (:py:class:`list`): Quantile limits that correspond to
          percentage size of desired intervals. Note, this is the default
          limits, but specific limits can be defined using the `ciset` and
          `piset` dictionaries.
        * **adddata** (:py:class:`bool`): Flag to include data. If set to
          None it is switched on whenever `ydata` is given.
        * **addmodel** (:py:class:`bool`): Flag to include median model
          response
        * **addlegend** (:py:class:`bool`): Flag to include legend
        * **addcredible** (:py:class:`bool`): Flag to include credible
          intervals
        * **addprediction** (:py:class:`bool`): Flag to include prediction
          intervals
        * **model_display** (:py:class:`dict`): Display settings for median
          model response
        * **data_display** (:py:class:`dict`): Display settings for data
        * **interval_display** (:py:class:`dict`): General display settings
          for intervals.
        * **fig**: Handle of previously created figure object
        * **figsize** (:py:class:`tuple`): (width, height) in inches
        * **legloc** (:py:class:`str`): Legend location - matplotlib help for
          details.
        * **ciset** (:py:class:`dict`): Settings for credible intervals
        * **piset** (:py:class:`dict`): Settings for prediction intervals
        * **return_settings** (:py:class:`bool`): Flag to return ciset and
          piset along with fig and ax.

    Returns:
        * (:py:class:`tuple`) with elements
            1) Figure handle
            2) Axes handle
            3) Dictionary with `ciset` and `piset` inside (only
               outputted if `return_settings=True`)
    '''
    # unpack dictionary
    credible = intervals['credible']
    prediction = intervals['prediction']
    # Copy mutable inputs so neither the caller's objects nor the shared
    # default arguments are modified by the in-place processing below.
    limits = list(limits)
    data_display = {} if data_display is None else dict(data_display)
    model_display = {} if model_display is None else dict(model_display)
    interval_display = ({} if interval_display is None
                        else dict(interval_display))
    # Check user-defined settings
    ciset = __setup_iset(ciset,
                         default_iset=dict(
                                 limits=limits,
                                 cmap=None,
                                 colors=None))
    piset = __setup_iset(piset,
                         default_iset=dict(
                                 limits=limits,
                                 cmap=None,
                                 colors=None))
    # Check limits (copy user lists so sorting cannot reorder them in place)
    ciset['limits'] = _check_limits(
        None if ciset['limits'] is None else list(ciset['limits']), limits)
    piset['limits'] = _check_limits(
        None if piset['limits'] is None else list(piset['limits']), limits)
    # convert limits to ranges
    ciset['quantiles'] = _convert_limits(ciset['limits'])
    piset['quantiles'] = _convert_limits(piset['limits'])
    # setup display settings
    interval_display, model_display, data_display = setup_display_settings(
            interval_display, model_display, data_display)
    # Define colors
    ciset['colors'] = setup_interval_colors(ciset, inttype='ci')
    piset['colors'] = setup_interval_colors(piset, inttype='pi')
    # Define labels
    ciset['labels'] = _setup_labels(ciset['limits'], inttype='CI')
    piset['labels'] = _setup_labels(piset['limits'], inttype='PI')
    if fig is None:
        fig = plt.figure(figsize=figsize)
        # add_subplot(projection='3d') attaches a 3-D axes on both old and
        # new matplotlib; Axes3D(fig) no longer adds itself to the figure in
        # recent releases, which made fig.gca() return a 2-D axes.
        ax = fig.add_subplot(111, projection='3d')
    else:
        ax = fig.gca()
    time1 = time[:, 0]
    time2 = time[:, 1]
    # add prediction intervals (only available if observation error
    # variance was supplied when the intervals were calculated)
    if addprediction is True and prediction is not None:
        for ii, quantile in enumerate(piset['quantiles']):
            pi = generate_quantiles(prediction, np.array(quantile))
            _add_3d_band(ax, time1, time2, pi[0], pi[1],
                         color=piset['colors'][ii],
                         label=piset['labels'][ii])
    # add credible intervals
    if addcredible is True:
        for ii, quantile in enumerate(ciset['quantiles']):
            ci = generate_quantiles(credible, np.array(quantile))
            _add_3d_band(ax, time1, time2, ci[0], ci[1],
                         color=ciset['colors'][ii],
                         label=ciset['labels'][ii])
    # add model (median model response)
    if addmodel is True:
        ci = generate_quantiles(credible, np.array(0.5))
        ax.plot(time1, time2, ci, **model_display)
    # add data to plot
    if ydata is not None and adddata is None:
        adddata = True
    if adddata is True and ydata is not None:
        if xdata is None:
            ax.plot(time1, time2, ydata.reshape(time1.shape), **data_display)
        else:  # User provided xdata array for observation points
            xobs1 = xdata[:, 0]
            xobs2 = xdata[:, 1]
            ax.plot(xobs1, xobs2, ydata.reshape(xobs1.shape), **data_display)
    # add legend
    if addlegend is True:
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles, labels, loc=legloc)
    if return_settings is True:
        return fig, ax, dict(ciset=ciset, piset=piset)
    return fig, ax


def _add_3d_band(ax, time1, time2, lower, upper, color, label):
    '''
    Draw the region between `lower` and `upper` along the (time1, time2)
    path as one closed polygon (3-D substitute for ``fill_between``).
    Returns the added :class:`Poly3DCollection`.
    '''
    # walk forward along the lower bound, then back along the upper bound
    rev = np.arange(time1.size - 1, -1, -1)
    x = np.concatenate((time1, time1[rev]))
    y = np.concatenate((time2, time2[rev]))
    z = np.concatenate((lower, upper[rev]))
    verts = [list(zip(x, y, z))]
    surf = Poly3DCollection(verts, color=color, label=label)
    _fix_poly3d_legend(surf)
    ax.add_collection3d(surf)
    return surf


def _fix_poly3d_legend(surf):
    '''
    Copy the 3-D face/edge colors onto the 2-D attributes so a legend can be
    built before the first draw. The private attribute names differ between
    matplotlib versions (`_facecolors3d` in older, `_facecolor3d` in newer).
    '''
    pairs = (('_facecolors2d', ('_facecolors3d', '_facecolor3d')),
             ('_edgecolors2d', ('_edgecolors3d', '_edgecolor3d')))
    for name2d, names3d in pairs:
        for name3d in names3d:
            if hasattr(surf, name3d):
                setattr(surf, name2d, getattr(surf, name3d))
                break
def check_s2chain(s2chain, nsimu):
    '''
    Check size of s2chain

    Args:
        * **s2chain** (:py:class:`float`, :py:class:`int`,
          :py:class:`list`, :class:`~numpy.ndarray`, or `None`):
          Observation error variance chain or value
        * **nsimu** (:py:class:`int`): No. of elements in chain

    Returns:
        * **s2chain** (:class:`~numpy.ndarray` or `None`): `None` if no
          variance was given; shape (nsimu,) for a scalar; otherwise the
          input, tiled to (nsimu, size) when it does not already have
          `nsimu` rows/elements.
    '''
    if s2chain is None:
        return None
    # Accept any numeric scalar (int as well as float), lists and arrays.
    s2chain = np.asarray(s2chain)
    if s2chain.ndim == 0:
        # single variance value: repeat it for every chain row
        return np.ones((nsimu,))*s2chain
    if s2chain.ndim == 2:
        if s2chain.shape[0] != nsimu:
            s2chain = s2chain * np.ones((nsimu, s2chain.size))
    else:
        if s2chain.size != nsimu:  # scalars provided for multiple QoI
            s2chain = s2chain * np.ones((nsimu, s2chain.size))
    return s2chain
# --------------------------------------------
def observation_sample(s2, y, sstype):
    '''
    Calculate model response with observation errors.

    A standard normal draw with the same shape as `y`, scaled by
    sqrt(`s2`), perturbs the response:

        - sstype 0: y + noise
        - sstype 1: (sqrt(y) + noise)**2
        - sstype 2: y*exp(noise)

    Args:
        * **s2** (:class:`~numpy.ndarray`): Observation error(s).
        * **y** (:class:`~numpy.ndarray`): Model responses.
        * **sstype** (:py:class:`int`): Flag to specify sstype.

    Returns:
        * **opred** (:class:`~numpy.ndarray`): Model responses with
          observation errors.

    Any other `sstype` terminates via :func:`sys.exit` before any random
    numbers are drawn.
    '''
    if sstype not in (0, 1, 2):
        sys.exit('Unknown sstype')
    noise = np.random.standard_normal(y.shape) * np.sqrt(s2)
    if sstype == 1:  # sqrt
        return (np.sqrt(y) + noise)**2
    if sstype == 2:  # log
        return y*np.exp(noise)
    return y + noise
# --------------------------------------------
def define_sample_points(nsample, nsimu):
    '''
    Define indices to sample from posteriors.

    Args:
        * **nsample** (:py:class:`int`): Number of samples to draw from
          posterior.
        * **nsimu** (:py:class:`int`): Number of MCMC simulations.

    Returns:
        * **iisample**: Indices into the posterior set. This is
          ``range(nsimu)`` when `nsample` >= `nsimu`, otherwise an integer
          :class:`~numpy.ndarray` of `nsample` indices drawn uniformly with
          replacement.
        * **nsample** (:py:class:`int`): Number of samples actually used.
    '''
    if nsample >= nsimu:
        # cannot draw more than the chain holds: use every row
        return range(nsimu), nsimu
    scaled = np.random.rand(nsample)*nsimu
    picks = (np.ceil(scaled) - 1).astype(int)
    return picks, nsample
# --------------------------------------------
def generate_quantiles(x, p=np.array([0.25, 0.5, 0.75])):
    '''
    Calculate empirical quantiles.

    Sorts `x` along axis 0 and linearly interpolates between order
    statistics at positions ``(n - 1)*p``.

    Args:
        * **x** (:class:`~numpy.ndarray`): Observations from which to
          generate quantile; samples run along axis 0.
        * **p** (:class:`~numpy.ndarray` or sequence): Quantile limits in
          [0, 1]. Lists are converted to arrays.

    Returns:
        * (:class:`~numpy.ndarray`): Interpolated quantiles with shape
          ``p.shape + x.shape[1:]``.

    Raises:
        * :py:class:`ValueError`: If a quantile limit lies outside [0, 1].
    '''
    x = np.asarray(x)
    # Without this, a list `p` would make `(n - 1)*p` repeat the list.
    p = np.asarray(p, dtype=float)
    # extract number of rows from np.array
    n = x.shape[0]
    xsorted = np.sort(x, 0)
    if n == 1:
        # interp1d needs at least two points; with a single sample every
        # quantile equals that sample.
        if np.any((p < 0) | (p > 1)):
            raise ValueError('Quantile limits must lie within [0, 1].')
        return np.array(np.broadcast_to(xsorted[0], p.shape + x.shape[1:]),
                        dtype=float)
    # define vector valued interpolation function
    xpoints = np.arange(0, n, 1)
    interpfun = interp1d(xpoints, xsorted, axis=0)
    # evaluation points
    itpoints = (n - 1)*p
    return interpfun(itpoints)
def setup_display_settings(interval_display, model_display, data_display):
    '''
    Compare user defined display settings with defaults and merge.

    Each user dictionary is merged with its built-in defaults through
    `check_settings`, in the order interval, model, data.

    Args:
        * **interval_display** (:py:class:`dict`): User defined settings for
          interval display.
        * **model_display** (:py:class:`dict`): User defined settings for
          model display.
        * **data_display** (:py:class:`dict`): User defined settings for data
          display.

    Returns:
        * **interval_display** (:py:class:`dict`): Settings for interval
          display.
        * **model_display** (:py:class:`dict`): Settings for model display.
        * **data_display** (:py:class:`dict`): Settings for data display.
    '''
    defaults = dict(
        interval=dict(linestyle=':', linewidth=1, alpha=1.0, edgecolor='k'),
        model=dict(linestyle='-', color='k', marker='', linewidth=2,
                   markersize=5, label='Model'),
        data=dict(linestyle='', color='b', marker='.', linewidth=1,
                  markersize=5, label='Data'),
    )
    merged_interval = check_settings(defaults['interval'], interval_display)
    merged_model = check_settings(defaults['model'], model_display)
    merged_data = check_settings(defaults['data'], data_display)
    return merged_interval, merged_model, merged_data
def setup_interval_colors(iset, inttype='CI'):
    '''
    Setup colors for empirical intervals

    Uses the user-defined `colors` when exactly one color per interval
    limit is given. Otherwise, each limit is mapped through a normalized
    colormap; a note is printed if user colors were supplied but had the
    wrong length.

    Args:
        * **iset** (:py:class:`dict`): This dictionary should contain the
          following keys - `limits`, `cmap`, and `colors`.

    Kwargs:
        * **inttype** (:py:class:`str`): Type of uncertainty interval; it
          selects the default colormap when `cmap` is None.

    Returns:
        * **ic** (:py:class:`list`): List containing color for each interval
    '''
    limits = iset['limits']
    cmap = iset['cmap']
    colors = iset['colors']
    norm = __setup_cmap_norm(limits)
    cmap = __setup_default_cmap(cmap, inttype)
    if colors is not None:
        if len(colors) == len(limits):
            # correct number of user-defined colors
            return list(colors)
        print('Note, user-defined colors were ignored. Using color map. '
              + 'Expected a list of length {}, but received {}'.format(
                      len(limits), len(colors)))
    return [cmap(norm(limit)) for limit in limits]
# --------------------------------------------
def _setup_labels(limits, inttype='CI'):
'''
Setup labels for prediction/credible intervals.
'''
labels = []
for limit in limits:
labels.append(str('{}% {}'.format(limit, inttype)))
return labels
def _check_limits(limits, default_limits):
if limits is None:
limits = default_limits
limits.sort(reverse=True)
return limits
def _convert_limits(limits):
rng = []
for limit in limits:
limit = limit/100
rng.append([0.5 - limit/2, 0.5 + limit/2])
return rng
def __setup_iset(iset, default_iset):
    '''
    Setup interval settings by comparing user input to default.

    A missing (None) user dictionary is treated as empty before being
    merged with `default_iset` by `check_settings`.
    '''
    return check_settings(default_iset, {} if iset is None else iset)
def __setup_cmap_norm(limits):
    '''
    Build the color normalization used to map interval limits to colors.

    With a single limit the normalization spans 0-100; otherwise it spans
    min(limits) to max(limits).
    '''
    if len(limits) != 1:
        return mplcolor.Normalize(vmin=min(limits), vmax=max(limits))
    # one limit would give min == max (a degenerate range), so use 0-100
    return mplcolor.Normalize(vmin=0, vmax=100)
def __setup_default_cmap(cmap, inttype):
    '''
    Return `cmap`, or a default colormap when it is None.

    The defaults are `cm.autumn` for credible intervals (`inttype` equal to
    'CI', case-insensitive) and `cm.winter` for any other interval type.
    '''
    if cmap is not None:
        return cmap
    return cm.autumn if inttype.upper() == 'CI' else cm.winter