"""This module includes common functions for visualization."""
from collections import OrderedDict
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as ss
from elfi.model.elfi_model import Constant, ElfiModel, NodeReference
def nx_draw(G, internal=False, param_names=False, filename=None, format=None):
    """Draw the `ElfiModel`.

    Parameters
    ----------
    G : nx.DiGraph or ElfiModel
        Graph or model to draw. A `NodeReference` is also accepted, in which
        case the graph of its model is drawn.
    internal : boolean, optional
        Whether to draw internal nodes. When False, nodes whose name starts with
        an underscore and whose `_class` attribute is `Constant` are left out,
        together with the edges leaving them.
    param_names : bool, optional
        Show param names on edges
    filename : str, optional
        If given, the graph is rendered with ``dot.render(filename)``.
    format : str, optional
        format of the file

    Notes
    -----
    Requires the optional 'graphviz' library.

    Returns
    -------
    dot
        A GraphViz dot representation of the model.

    Raises
    ------
    ImportError
        If the graphviz library cannot be imported.

    """
    try:
        from graphviz import Digraph
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError("The graphviz library is required for this feature.") from err

    if isinstance(G, ElfiModel):
        G = G.source_net
    elif isinstance(G, NodeReference):
        G = G.model.source_net

    dot = Digraph(format=format)

    hidden = set()

    for n, state in G.nodes(data=True):
        # startswith() is safe for an empty node name, unlike n[0].
        if not internal and n.startswith('_') and state['attr_dict'].get('_class') == Constant:
            hidden.add(n)
            continue
        _format = {'shape': 'circle', 'fillcolor': 'gray80', 'style': 'solid'}
        if state['attr_dict'].get('_observable'):
            _format['style'] = 'filled'
        dot.node(n, **_format)

    # add edges to graph
    for u, v, label in G.edges(data='param', default=''):
        if not internal and u in hidden:
            continue
        label = label if param_names else ''
        dot.edge(u, v, str(label))

    if filename is not None:
        dot.render(filename)

    return dot
def _create_axes(axes, shape, **kwargs):
"""Check the axes and create them if necessary.
Parameters
----------
axes : plt.Axes or arraylike of plt.Axes
shape : tuple of int
(x,) or (x,y)
kwargs
Returns
-------
axes : np.array of plt.Axes
kwargs : dict
Input kwargs without items related to creating a figure.
"""
fig_kwargs = {}
kwargs['figsize'] = kwargs.get('figsize', (4 * shape[1], 4 * shape[0]))
for k in ['figsize', 'sharex', 'sharey', 'dpi', 'num']:
if k in kwargs.keys():
fig_kwargs[k] = kwargs.pop(k)
if axes is not None:
axes = np.atleast_2d(axes)
else:
fig, axes = plt.subplots(ncols=shape[1], nrows=shape[0], **fig_kwargs)
axes = np.reshape(axes, shape)
fig.tight_layout(pad=2.0, h_pad=1.08, w_pad=1.08)
fig.subplots_adjust(wspace=0.2, hspace=0.2)
return axes, kwargs
def _limit_params(samples, selector=None):
"""Pick only the selected parameters from all samples.
Parameters
----------
samples : OrderedDict of np.arrays
selector : iterable of ints or strings, optional
Indices or keys to use from samples. Default to all.
Returns
-------
selected : OrderedDict of np.arrays
"""
if selector is None:
return samples
else:
selected = OrderedDict()
for ii, k in enumerate(samples):
if ii in selector or k in selector:
selected[k] = samples[k]
return selected
def plot_marginals(samples, selector=None, bins=20, axes=None,
                   reference_value=None, **kwargs):
    """Plot marginal distributions for parameters.

    Parameters
    ----------
    samples : OrderedDict of np.arrays
    selector : iterable of ints or strings, optional
        Indices or keys to use from samples. Default to all.
    bins : int, optional
        Number of bins in histogram.
    axes : one or an iterable of plt.Axes, optional
    reference_value : dict, optional
        Reference values indexed by parameter name; each is marked with a red
        cross at height zero.
    kwargs
        'ncols' sets the number of subplot columns (default: number of
        parameters, at most 5). A truthy 'kde' draws a Gaussian kernel density
        estimate instead of a histogram. Figure-related items are used when
        creating the axes and the rest are passed to the histogram call.

    Returns
    -------
    axes : np.array of plt.Axes

    """
    ncols = min(len(samples), 5)
    ncols = kwargs.pop('ncols', ncols)
    # Pop the flag so that kde=False really selects the histogram and the key
    # never leaks into the histogram keyword arguments.
    use_kde = kwargs.pop('kde', False)

    samples = _limit_params(samples, selector)
    # Ceiling division gives the number of rows needed for all parameters.
    shape = (-(len(samples) // -ncols), min(len(samples), ncols))
    axes, kwargs = _create_axes(axes, shape, **kwargs)
    axes = axes.ravel()

    for idx, key in enumerate(samples):
        if reference_value is not None:
            axes[idx].plot(reference_value[key], 0,
                           color='red',
                           alpha=1.0,
                           linewidth=2,
                           marker='X',
                           clip_on=False,
                           markersize=12)
        if use_kde:
            kde = ss.gaussian_kde(samples[key])
            xs = np.linspace(min(samples[key]), max(samples[key]))
            axes[idx].plot(xs, kde(xs))
        else:
            axes[idx].hist(samples[key], bins=bins, **kwargs)
        axes[idx].set_xlabel(key)

    # Hide the grid cells that have no parameter assigned.
    for idx in range(len(samples), len(axes)):
        axes[idx].set_axis_off()

    return axes
def plot_pairs(samples,
               selector=None,
               bins=20,
               reference_value=None,
               axes=None,
               draw_upper_triagonal=False,
               **kwargs):
    """Draw a matrix of pairwise scatter plots with histograms on the diagonal.

    The diagonal histograms are drawn with ``density=True`` and without
    y tick labels.

    Parameters
    ----------
    samples : OrderedDict of np.arrays
    selector : iterable of ints or strings, optional
        Indices or keys to use from samples. Default to all.
    bins : int, optional
        Number of bins in histograms.
    reference_value: dict, optional
        Dictionary containing reference values for parameters.
    axes : one or an iterable of plt.Axes, optional
    draw_upper_triagonal: boolean, optional
        Boolean indicating whether to draw symmetric upper triagonal part.
    kwargs
        'edgecolor' (default 'black') and 's' (default 2) style the scatter
        markers. Figure-related items are used when creating the axes and the
        rest are passed to both the histogram and the scatter calls.

    Returns
    -------
    axes : np.array of plt.Axes

    """
    samples = _limit_params(samples, selector)
    n_params = len(samples)
    edgecolor = kwargs.pop('edgecolor', 'black')
    dot_size = kwargs.pop('s', 2)
    axes, kwargs = _create_axes(axes, (n_params, n_params), **kwargs)
    show_reference = reference_value is not None

    for row, row_key in enumerate(samples):
        row_values = samples[row_key]
        row_min = row_values.min()
        row_max = row_values.max()

        for col, col_key in enumerate(samples):
            ax = axes[row, col]

            if row == col:
                # Diagonal: marginal histogram of the parameter.
                ax.hist(row_values, bins=bins, density=True, **kwargs)
                if show_reference:
                    ax.plot(reference_value[row_key], 0,
                            color='red',
                            alpha=1.0,
                            linewidth=2,
                            marker='X',
                            clip_on=False,
                            markersize=12)
                ax.get_yaxis().set_ticklabels([])
                ax.set(xlim=(row_min, row_max))
            elif row > col or draw_upper_triagonal:
                # Off-diagonal: scatter of the column parameter vs. the row parameter.
                col_values = samples[col_key]
                col_min = col_values.min()
                col_max = col_values.max()
                ax.plot(col_values,
                        row_values,
                        linestyle='',
                        marker='o',
                        alpha=0.6,
                        clip_on=False,
                        markersize=dot_size,
                        markeredgecolor=edgecolor,
                        **kwargs)
                if show_reference:
                    # Cross-hair through the reference point.
                    ax.plot([col_min, col_max],
                            [reference_value[row_key], reference_value[row_key]],
                            color='red', alpha=0.8, linewidth=2)
                    ax.plot([reference_value[col_key], reference_value[col_key]],
                            [row_min, row_max],
                            color='red', alpha=0.8, linewidth=2)
                ax.axis([col_min, col_max, row_min, row_max])
            else:
                # Only upper-triangle cells reach this branch.
                ax.axis('off')

        axes[row, 0].set_ylabel(row_key)
        axes[-1, row].set_xlabel(row_key)

    return axes
def plot_traces(result, selector=None, axes=None, **kwargs):
    """Draw MCMC trace plots, one row per parameter and one column per chain.

    A black vertical line is drawn at `result.warmup` in every subplot.

    Parameters
    ----------
    result : Result_BOLFI
    selector : iterable of ints or strings, optional
        Indices or keys to use from samples. Default to all.
    axes : one or an iterable of plt.Axes, optional
    kwargs
        'sharex' and 'sharey' are always overridden with 'all' and 'row'.
        Remaining non-figure items are passed to the line plot call.

    Returns
    -------
    axes : np.array of plt.Axes

    """
    chosen = _limit_params(result.samples, selector)
    grid_shape = (len(chosen), result.n_chains)
    kwargs.update(sharex='all', sharey='row')
    axes, kwargs = _create_axes(axes, grid_shape, **kwargs)

    row = 0
    # Enumerate all samples so that param_idx addresses the last axis of
    # result.chains even when only a subset is plotted.
    for param_idx, name in enumerate(result.samples):
        if name not in chosen:
            continue
        for chain in range(result.n_chains):
            ax = axes[row, chain]
            ax.plot(result.chains[chain, :, param_idx], **kwargs)
            ax.axvline(result.warmup, color='black')
        axes[row, 0].set_ylabel(name)
        row += 1

    for chain in range(result.n_chains):
        axes[-1, chain].set_xlabel('Iterations in Chain {}'.format(chain))

    return axes
def plot_params_vs_node(node, n_samples=100, func=None, seed=None, axes=None, **kwargs):
    """Plot some realizations of parameters vs. `node`.

    Useful e.g. for exploring how a summary statistic varies with parameters.
    Currently only nodes with scalar output are supported, though a function `func` can
    be given to reduce node output. This allows giving the simulator as the `node` and
    applying a summarizing function without incorporating it into the ELFI graph.

    If `node` is one of the model parameters, its histogram is plotted.

    Parameters
    ----------
    node : elfi.NodeReference
        The node which to evaluate. Its output must be scalar (shape=(batch_size,1)).
    n_samples : int, optional
        How many samples to plot.
    func : callable, optional
        A function to apply to node output.
    seed : int, optional
    axes : one or an iterable of plt.Axes, optional
    kwargs
        'bins' (histogram case, default 20); 'ncols', 'edgecolor' (default
        'none') and 's' (default 20) in the scatter case. 'sharey' defaults to
        True. Remaining non-figure items are passed to the scatter call.

    Returns
    -------
    axes : np.array of plt.Axes

    Raises
    ------
    NotImplementedError
        If the plotted quantity does not have shape (n_samples,).

    """
    model = node.model
    parameters = model.parameter_names
    node_name = node.name

    if node_name in parameters:
        outputs = [node_name]
        shape = (1, 1)
        bins = kwargs.pop('bins', 20)
    else:
        outputs = parameters + [node_name]
        n_params = len(parameters)
        ncols = n_params if n_params < 5 else 5
        ncols = kwargs.pop('ncols', ncols)
        edgecolor = kwargs.pop('edgecolor', 'none')
        dot_size = kwargs.pop('s', 20)
        # Ceiling division: enough rows for every parameter. The previous
        # formula allocated too few axes, e.g. for 11 parameters in 5 columns.
        shape = (-(n_params // -ncols), ncols)

    data = model.generate(batch_size=n_samples, outputs=outputs, seed=seed)

    if func is not None:
        if hasattr(func, '__name__'):
            node_name = func.__name__
        else:
            node_name = 'func'
        data[node_name] = func(data[node.name])  # leaves rest of the code unmodified

    if data[node_name].shape != (n_samples,):
        raise NotImplementedError("The plotted quantity must have shape ({},), was {}."
                                  .format(n_samples, data[node_name].shape))

    # Default to a shared y-axis while still letting the caller override it
    # without triggering a duplicate-keyword error.
    kwargs.setdefault('sharey', True)
    axes, kwargs = _create_axes(axes, shape, **kwargs)
    axes = axes.ravel()

    if len(outputs) == 1:
        # `normed` was removed from matplotlib; `density` is its replacement.
        axes[0].hist(data[node_name], bins=bins, density=True)
        axes[0].set_xlabel(node_name)
    else:
        for idx, key in enumerate(parameters):
            axes[idx].scatter(data[key],
                              data[node_name],
                              s=dot_size,
                              edgecolor=edgecolor,
                              **kwargs)
            axes[idx].set_xlabel(key)
        axes[0].set_ylabel(node_name)

        # Hide the grid cells that have no parameter assigned.
        for idx in range(len(parameters), len(axes)):
            axes[idx].set_axis_off()

    return axes
def plot_discrepancy(gp, parameter_names, axes=None, **kwargs):
    """Plot acquired parameters vs. resulting discrepancy.

    Parameters
    ----------
    gp : GPyRegression target model, required
        Read attributes: `input_dim`, `bounds`, `X` and `Y`.
    parameter_names : sequence of str, required
        Parameter names indexed by position; entry ``ii`` labels the plot of
        column ``ii`` of `gp.X`.
    axes : plt.Axes or arraylike of plt.Axes
    kwargs
        'ncols' sets the number of subplot columns (default:
        ``len(gp.bounds)``, at most 5). 'sharey' defaults to True. Remaining
        non-figure items are passed to the scatter call.

    Returns
    -------
    axes : np.array of plt.Axes

    """
    n_plots = gp.input_dim
    ncols = len(gp.bounds) if len(gp.bounds) < 5 else 5
    ncols = kwargs.pop('ncols', ncols)
    kwargs['sharey'] = kwargs.get('sharey', True)

    # Ceiling division: exactly as many rows as needed for n_plots subplots.
    # The previous formulas could allocate too few axes (e.g. 16 plots in
    # 5 columns), which ended in an IndexError below.
    shape = (-(n_plots // -ncols), ncols)

    axes, kwargs = _create_axes(axes, shape, **kwargs)
    axes = axes.ravel()

    for ii in range(n_plots):
        axes[ii].scatter(gp.X[:, ii], gp.Y[:, 0], **kwargs)
        axes[ii].set_xlabel(parameter_names[ii])
        # Label the y-axis on the first subplot of every row.
        if ii % ncols == 0:
            axes[ii].set_ylabel('Discrepancy')

    # Hide the grid cells that have no plot in them.
    for idx in range(n_plots, len(axes)):
        axes[idx].set_axis_off()

    return axes
def plot_gp(gp, parameter_names, axes=None, resol=50,
            const=None, bounds=None, true_params=None, **kwargs):
    """Plot pairwise relationships as a matrix with parameters vs. discrepancy.

    The diagonal shows the evidence of one parameter against the discrepancy.
    The cells below the diagonal show filled contours of the predicted mean
    over a pair of parameters, with the evidence points on top. Cells above
    the diagonal are switched off.

    Parameters
    ----------
    gp : GPyRegression, required
    parameter_names : list, required
        Parameter names in format ['mu_0', 'mu_1', ..]
    axes : plt.Axes or arraylike of plt.Axes
    resol : int, optional
        Resolution of the plotted grid.
    const : np.array, optional
        Values for parameters in plots where held constant. Defaults to minimum evidence.
    bounds: list of tuples, optional
        List of tuples for axis boundaries. Defaults to `gp.bounds`.
    true_params : dict, optional
        Dictionary containing parameter names with corresponding true parameter values.

    Returns
    -------
    axes : np.array of plt.Axes

    """
    n_plots = gp.input_dim
    shape = (n_plots, n_plots)
    axes, kwargs = _create_axes(axes, shape, **kwargs)

    x_evidence = gp.X
    y_evidence = gp.Y
    if const is None:
        const = x_evidence[np.argmin(y_evidence), :]
    # Explicit check instead of `bounds or ...`, whose truth test fails for
    # numpy arrays; an empty sequence still falls back to the model bounds.
    if bounds is None or len(bounds) == 0:
        bounds = gp.bounds

    # plt.cm.get_cmap has been removed from recent matplotlib releases.
    cmap = plt.get_cmap("Blues")

    # The discrepancy range is the same for every diagonal plot.
    y_min = np.min(y_evidence)
    y_max = np.max(y_evidence)

    for ix in range(n_plots):
        for jy in range(n_plots):
            ax = axes[jy, ix]
            if ix == jy:
                ax.scatter(x_evidence[:, ix], y_evidence, edgecolors='black', alpha=0.6)
                ax.get_yaxis().set_ticklabels([])
                ax.yaxis.tick_right()
                ax.set_ylabel('Discrepancy')
                ax.yaxis.set_label_position("right")
                if true_params is not None:
                    true_x = true_params[parameter_names[ix]]
                    ax.plot([true_x, true_x],
                            [y_min, y_max],
                            color='red', alpha=1.0, linewidth=1)
                ax.axis([bounds[ix][0], bounds[ix][1], y_min, y_max])

            elif ix < jy:
                x1 = np.linspace(bounds[ix][0], bounds[ix][1], resol)
                y1 = np.linspace(bounds[jy][0], bounds[jy][1], resol)
                x, y = np.meshgrid(x1, y1)
                # All parameters are held at `const` except the plotted pair.
                predictors = np.tile(const, (resol * resol, 1))
                predictors[:, ix] = x.ravel()
                predictors[:, jy] = y.ravel()
                z = gp.predict_mean(predictors).reshape(resol, resol)
                ax.contourf(x, y, z, cmap=cmap)
                ax.scatter(x_evidence[:, ix],
                           x_evidence[:, jy],
                           color="red",
                           alpha=0.7,
                           s=5)
                if true_params is not None:
                    true_x = true_params[parameter_names[ix]]
                    true_y = true_params[parameter_names[jy]]
                    ax.plot([true_x, true_x],
                            [bounds[jy][0], bounds[jy][1]],
                            color='red', alpha=1.0, linewidth=1)
                    ax.plot([bounds[ix][0], bounds[ix][1]],
                            [true_y, true_y],
                            color='red', alpha=1.0, linewidth=1)
                if ix == 0:
                    ax.set_ylabel(parameter_names[jy])
                else:
                    ax.get_yaxis().set_ticklabels([])
                ax.axis([bounds[ix][0], bounds[ix][1], bounds[jy][0], bounds[jy][1]])

            else:
                ax.axis('off')

            # Only the bottom row keeps x tick labels and gets an x label.
            if jy < n_plots - 1:
                ax.get_xaxis().set_ticklabels([])
            else:
                ax.set_xlabel(parameter_names[ix])

    return axes
def plot_predicted_summaries(model=None,
                             summary_names=None,
                             n_samples=100,
                             seed=None,
                             bins=20,
                             axes=None,
                             add_observed=True,
                             draw_upper_triagonal=False,
                             **kwargs):
    """Pairplots of 1D summary statistics calculated from prior predictive distribution.

    Parameters
    ----------
    model: elfi.Model
        Model which is explored. Must be given; the None default is not usable.
    summary_names: list of strings
        Summary statistics which are pairplotted.
    n_samples: int, optional
        How many samples are drawn from the model.
    seed : int, optional
        Seed passed to `model.generate` when drawing the samples.
    bins : int, optional
        Number of bins in histograms.
    axes : one or an iterable of plt.Axes, optional
    add_observed: boolean, optional
        Add observed summary points in pairplots
    draw_upper_triagonal: boolean, optional
        Boolean indicating whether to draw symmetric upper triagonal part.
    kwargs
        's' sets the scatter marker size (default 8). Remaining items are
        forwarded to `plot_pairs`.

    Returns
    -------
    axes : np.array of plt.Axes
        The axes returned by `plot_pairs`.

    """
    dot_size = kwargs.pop('s', 8)
    samples = model.generate(batch_size=n_samples, outputs=summary_names, seed=seed)

    # Evaluate the observed summaries only when they are going to be drawn.
    reference_value = None
    if add_observed:
        reference_value = model.generate(with_values=model.observed, outputs=summary_names)

    # Return the axes and pass on the extra keyword arguments; both used to
    # be silently dropped.
    return plot_pairs(samples,
                      selector=None,
                      bins=bins,
                      axes=axes,
                      reference_value=reference_value,
                      s=dot_size,
                      draw_upper_triagonal=draw_upper_triagonal,
                      **kwargs)
class ProgressBar:
    """Progress bar monitoring the inference process.

    Attributes
    ----------
    prefix : str, optional
        Prefix string
    suffix : str, optional
        Suffix string
    decimals : int, optional
        Positive number of decimals in percent complete
    length : int, optional
        Character length of bar
    fill : str, optional
        Bar fill character
    scaling : int, optional
        Integer used to scale current iteration and total iterations of the progress bar
    finished : bool
        Whether the completed bar has already been printed in this round.

    """

    def __init__(self, prefix='', suffix='', decimals=1, length=100, fill='='):
        """Construct progressbar for monitoring.

        Parameters
        ----------
        prefix : str, optional
            Prefix string
        suffix : str, optional
            Suffix string
        decimals : int, optional
            Positive number of decimals in percent complete
        length : int, optional
            Character length of bar
        fill : str, optional
            Bar fill character

        """
        self.prefix = prefix
        self.suffix = suffix
        # Use the argument; it used to be ignored in favour of a fixed 1.
        self.decimals = decimals
        self.length = length
        self.fill = fill
        self.scaling = 0
        self.finished = False

    def update_progressbar(self, iteration, total):
        """Print updated progress bar in console.

        A completed bar is printed once, followed by a newline. An unfinished
        bar ends with a carriage return so that the next update overwrites it.

        Parameters
        ----------
        iteration : int
            Integer indicating completed iterations
        total : int
            Integer indicating total number of iterations

        """
        percent_format = "{0:." + str(self.decimals) + "f}"

        if iteration >= total:
            percent = percent_format.format(100.0)
            bar = self.fill * self.length
            if not self.finished:
                print('%s [%s] %s%% %s' % (self.prefix, bar, percent, self.suffix))
                self.finished = True
        elif total - self.scaling > 0:
            # Progress is measured relative to the offset set on reinit.
            done = iteration - self.scaling
            span = total - self.scaling
            percent = percent_format.format(100 * (done / float(span)))
            filled_length = int(self.length * done // span)
            bar = self.fill * filled_length + '-' * (self.length - filled_length)
            print('%s [%s] %s%% %s' % (self.prefix, bar, percent, self.suffix), end='\r')

    def reinit_progressbar(self, scaling=0, reinit_msg=""):
        """Reinitialize new round of progress bar.

        Parameters
        ----------
        scaling : int, optional
            Integer used to scale current and total iterations of the progress bar
        reinit_msg : str, optional
            Message printed before restarting an empty progress bar on a new line

        """
        self.scaling = scaling
        self.finished = False
        print(reinit_msg)