Source code for elfi.methods.results

"""Containers for results from inference."""

import io
import itertools
import logging
import os
import string
import sys
from collections import OrderedDict

import arviz as az
import matplotlib.pyplot as plt
import numpy as np

import elfi.visualization.visualization as vis
from elfi.methods.mcmc import eff_sample_size
from elfi.methods.utils import (numpy_to_python_type, sample_object_to_dict,
                                weighted_sample_quantile)

logger = logging.getLogger(__name__)


class ParameterInferenceResult:
    """Base class for results."""

    def __init__(self, method_name, outputs, parameter_names, **kwargs):
        """Initialize result.

        Parameters
        ----------
        method_name : string
            Name of inference method.
        outputs : dict
            Dictionary with outputs from the nodes, e.g. samples.
        parameter_names : list
            Names of the parameter nodes
        **kwargs
            Any other information from the inference algorithm, usually from its state.

        """
        self.method_name = method_name
        self.outputs = outputs.copy()
        self.parameter_names = parameter_names
        self.meta = kwargs

    @property
    def is_multivariate(self):
        """Check whether the result contains multivariate parameters."""
        for p in self.parameter_names:
            if self.outputs[p].ndim > 1:
                return True
        return False


class OptimizationResult(ParameterInferenceResult):
    """Base class for results from optimization."""

    def __init__(self, x_min, **kwargs):
        """Initialize result.

        Parameters
        ----------
        x_min
            The optimized parameters
        **kwargs
            See `ParameterInferenceResult`

        """
        super(OptimizationResult, self).__init__(**kwargs)
        self.x_min = x_min


class Sample(ParameterInferenceResult):
    """Sampling results from inference methods."""

    def __init__(self, method_name, outputs, parameter_names, discrepancy_name=None,
                 weights=None, **kwargs):
        """Initialize result.

        Parameters
        ----------
        method_name : string
            Name of inference method.
        outputs : dict
            Dictionary with outputs from the nodes, e.g. samples.
        parameter_names : list
            Names of the parameter nodes
        discrepancy_name : string, optional
            Name of the discrepancy in outputs.
        weights : array_like
        **kwargs
            Other meta information for the result

        """
        super(Sample, self).__init__(
            method_name=method_name, outputs=outputs, parameter_names=parameter_names,
            **kwargs)

        self.samples = OrderedDict()
        for n in self.parameter_names:
            self.samples[n] = self.outputs[n]

        self.discrepancy_name = discrepancy_name
        self.weights = weights

    def __getattr__(self, item):
        """Allow more convenient access to items under self.meta."""
        if item in self.meta.keys():
            return self.meta[item]
        else:
            raise AttributeError("No attribute '{}' in this sample".format(item))

    def __dir__(self):
        """Allow autocompletion for items under self.meta.

        http://stackoverflow.com/questions/13603088/python-dynamic-help-and-autocomplete-generation
        """
        items = dir(type(self)) + list(self.__dict__.keys())
        items.extend(self.meta.keys())
        return items

    @property
    def n_samples(self):
        """Return the number of samples."""
        return len(self.outputs[self.parameter_names[0]])

    @property
    def dim(self):
        """Return the number of parameters."""
        return len(self.parameter_names)

    @property
    def discrepancies(self):
        """Return the discrepancy values."""
        return None if self.discrepancy_name is None else \
            self.outputs[self.discrepancy_name]

    @property
    def samples_array(self):
        """Return the samples as an array.

        The columns are in the same order as in self.parameter_names.

        Returns
        -------
        np.array

        """
        return np.column_stack(tuple(self.samples.values()))

    def __str__(self):
        """Return a summary of results as a string."""
        # create a buffer for capturing the output from summary's print statement
        stdout0 = sys.stdout
        buffer = io.StringIO()
        sys.stdout = buffer
        self.summary()
        sys.stdout = stdout0  # revert to original stdout
        return buffer.getvalue()

    def __repr__(self):
        """Return a summary of results as a string."""
        return self.__str__()
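
    # Illustrative usage sketch (not part of the module): assuming `res` is a
    # Sample instance returned by an ELFI sampler (the variable name and the
    # parameter names 'mu' and 'sigma' below are hypothetical), the stored
    # draws and meta information can be accessed as:
    #
    #     res.samples['mu']        # draws for one parameter (np.ndarray)
    #     res.samples_array        # all draws, columns ordered as parameter_names
    #     res.n_samples, res.dim   # number of draws and number of parameters
    #     res.n_sim                # meta entries are exposed via __getattr__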

    def summary(self):
        """Print a verbose summary of contained results."""
        # TODO: include __str__ of Inference Task, seed?
        desc = "Method: {}\nNumber of samples: {}\n" \
            .format(self.method_name, self.n_samples)
        if hasattr(self, 'n_sim'):
            desc += "Number of simulations: {}\n".format(self.n_sim)
        if hasattr(self, 'threshold'):
            desc += "Threshold: {:.3g}\n".format(self.threshold)
        if hasattr(self, 'acc_rate'):
            desc += "MCMC Acceptance Rate: {:.3g}\n".format(self.acc_rate)
        print(desc, end='')
        try:
            self.sample_summary()
        except TypeError:
            pass

    def sample_means_summary(self):
        """Print a representation of sample means."""
        s = "Sample means: "
        s += ', '.join(["{}: {:.3g}".format(k, v) for k, v in self.sample_means.items()])
        print(s)

    def sample_summary(self):
        """Print sample mean and 95% credible interval."""
        print("{0:24} {1:18} {2:17} {3:5}".format("Parameter", "Mean", "2.5%", "97.5%"))
        print(''.join([
            "{0:10} "
            "{1:18.3f} "
            "{2:18.3f} "
            "{3:18.3f}\n"
            .format(k[:10] + ":", v[0], v[1], v[2])
            for k, v in self.sample_means_and_95CIs.items()]))

    @property
    def sample_means_and_95CIs(self):
        """Construct OrderedDict for mean and 95% credible interval."""
        return OrderedDict(
            [(k, (np.average(v, axis=0, weights=self.weights),
                  weighted_sample_quantile(v, alpha=0.025, weights=self.weights),
                  weighted_sample_quantile(v, alpha=0.975, weights=self.weights)))
             for k, v in self.samples.items()]
        )

    @property
    # Convert and return posterior sample to arviz InferenceData object
    def idata(self):
        """Convert posterior sample to arviz InferenceData object."""
        return az.from_dict(posterior=self.samples)

    @property
    def sample_means(self):
        """Evaluate weighted averages of sampled parameters.

        Returns
        -------
        OrderedDict

        """
        return OrderedDict([(k, np.average(v, axis=0, weights=self.weights))
                            for k, v in self.samples.items()])
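
    # Illustrative sketch (assumes the `arviz` package imported above as `az`):
    # the `idata` property wraps the posterior draws with az.from_dict, so the
    # result can be passed to standard ArviZ routines for a hypothetical
    # Sample instance `res`, e.g.
    #
    #     az.summary(res.idata)         # tabular posterior summary
    #     az.plot_posterior(res.idata)  # posterior density plots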

    def get_sample_covariance(self):
        """Return covariance of samples."""
        vals = np.array(list(self.samples.values()))
        cov_mat = np.cov(vals)
        return cov_mat

    def sample_quantiles(self, alpha=0.5):
        """Evaluate weighted sample quantiles of sampled parameters."""
        return OrderedDict([(k, weighted_sample_quantile(v, alpha=alpha, weights=self.weights))
                            for k, v in self.samples.items()])
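
    # Illustrative sketch: the weighted quantile helpers above can be combined
    # to report arbitrary credible intervals for a hypothetical Sample `res`:
    #
    #     lower = res.sample_quantiles(alpha=0.05)   # 5% weighted quantiles
    #     upper = res.sample_quantiles(alpha=0.95)   # 95% weighted quantiles
    #     res.sample_means_and_95CIs                 # mean, 2.5% and 97.5% per parameter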

    @property
    def sample_means_array(self):
        """Evaluate weighted averages of sampled parameters.

        Returns
        -------
        np.array

        """
        return np.array(list(self.sample_means.values()))

    def __getstate__(self):
        """Return the state to be pickled."""
        return self.meta, self.__dict__

    def __setstate__(self, state):
        """Restore the pickled state."""
        self.meta, self.__dict__ = state

    def save(self, fname=None):
        """Save samples in csv, json or pickle file formats.

        Clarification: csv saves only the samples, json saves the whole object's
        dictionary except the `outputs` key, and pickle saves the whole object.

        Parameters
        ----------
        fname : str, required
            File name to be saved. The type is inferred from the extension
            ('csv', 'json' or 'pkl').

        """
        import csv
        import json
        import pickle

        kind = os.path.splitext(fname)[1][1:]

        if kind == 'csv':
            with open(fname, 'w', newline='') as f:
                w = csv.writer(f)
                w.writerow(self.samples.keys())
                w.writerows(itertools.zip_longest(*self.samples.values(), fillvalue=''))
        elif kind == 'json':
            with open(fname, 'w') as f:
                data = OrderedDict()
                data['n_samples'] = self.n_samples
                data['discrepancies'] = self.discrepancies
                data['dim'] = self.dim

                # the populations key exists in the SMC-ABC sampler and contains the history
                # of all inferences with different numbers of simulations and thresholds
                populations = 'populations'
                if populations in self.__dict__:
                    # setting populations in the following form:
                    # data = {'populations': {'A': dict(), 'B': dict()}, ...}
                    # this helps to save all kinds of populations
                    pop_num = string.ascii_letters.upper()[:len(self.__dict__[populations])]
                    data[populations] = OrderedDict()
                    for n, elem in enumerate(self.__dict__[populations]):
                        data[populations][pop_num[n]] = OrderedDict()
                        sample_object_to_dict(data[populations][pop_num[n]], elem)

                    # convert numpy types into python types in the populations key
                    for key, val in data[populations].items():
                        numpy_to_python_type(val)

                # skip populations because it was processed previously
                sample_object_to_dict(data, self, skip='populations')

                # convert numpy types into python types
                numpy_to_python_type(data)

                js = json.dumps(data)
                f.write(js)
        elif kind == 'pkl':
            with open(fname, 'wb') as f:
                pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)
        else:
            print("Wrong file type format. Please use 'csv', 'json' or 'pkl'.")
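
    # Illustrative sketch of the three supported output formats for a
    # hypothetical Sample `res`; the file names are arbitrary:
    #
    #     res.save('posterior.csv')   # samples only, one column per parameter
    #     res.save('posterior.json')  # object dictionary without the raw `outputs`
    #     res.save('posterior.pkl')   # full object via pickle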

    def plot_marginals(self, selector=None, bins=20, axes=None, reference_value=None, **kwargs):
        """Plot marginal distributions for parameters.

        Supports only univariate distributions.

        Parameters
        ----------
        selector : iterable of ints or strings, optional
            Indices or keys to use from samples. Default to all.
        bins : int, optional
            Number of bins in histograms.
        axes : one or an iterable of plt.Axes, optional

        Returns
        -------
        axes : np.array of plt.Axes

        """
        if self.is_multivariate:
            print("Plotting multivariate distributions is unsupported.")
        else:
            return vis.plot_marginals(samples=self.samples, selector=selector, bins=bins,
                                      axes=axes, reference_value=reference_value, **kwargs)

    def plot_pairs(self, selector=None, bins=20, axes=None, reference_value=None,
                   draw_upper_triagonal=False, **kwargs):
        """Plot pairwise relationships as a matrix with marginals on the diagonal.

        The y-axis of the marginal histograms is scaled.
        Supports only univariate distributions.

        Parameters
        ----------
        selector : iterable of ints or strings, optional
            Indices or keys to use from samples. Default to all.
        bins : int, optional
            Number of bins in histograms.
        axes : one or an iterable of plt.Axes, optional

        Returns
        -------
        axes : np.array of plt.Axes

        """
        if self.is_multivariate:
            print("Plotting multivariate distributions is unsupported.")
        else:
            return vis.plot_pairs(samples=self.samples, selector=selector, bins=bins,
                                  reference_value=reference_value, axes=axes,
                                  draw_upper_triagonal=draw_upper_triagonal, **kwargs)
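
# Illustrative plotting sketch (comment only, not executed): for a hypothetical
# univariate Sample `res` with parameters 'mu' and 'sigma',
#
#     res.plot_marginals(bins=30)               # one histogram per parameter
#     res.plot_pairs(selector=['mu', 'sigma'])  # pairwise grid with marginals
#     plt.show()
#
# Multivariate parameter nodes are skipped with a message by the
# is_multivariate check above.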


class SmcSample(Sample):
    """Container for results from SMC-ABC."""

    def __init__(self, method_name, outputs, parameter_names, populations, *args, **kwargs):
        """Initialize result.

        Parameters
        ----------
        method_name : str
        outputs : dict
        parameter_names : list
        populations : list[Sample]
            List of Sample objects
        args
        kwargs

        """
        super(SmcSample, self).__init__(
            method_name=method_name, outputs=outputs, parameter_names=parameter_names,
            *args, **kwargs)
        self.populations = populations

        if self.weights is None:
            raise ValueError("No weights provided for the sample")

    @property
    def n_populations(self):
        """Return the number of populations."""
        return len(self.populations)

    def summary(self, all=False):
        """Print a verbose summary of contained results.

        Parameters
        ----------
        all : bool, optional
            Whether to print the summary for all populations separately,
            or just the final population (default).

        """
        super(SmcSample, self).summary()

        if all:
            for i, pop in enumerate(self.populations):
                print('\nPopulation {}:'.format(i))
                pop.summary()

    def sample_means_summary(self, all=False):
        """Print a representation of sample means.

        Parameters
        ----------
        all : bool, optional
            Whether to print the means for all populations separately,
            or just the final population (default).

        """
        if all is False:
            super(SmcSample, self).sample_means_summary()
            return

        out = ''
        for i, pop in enumerate(self.populations):
            out += "Sample means for population {}: ".format(i)
            out += ', '.join(["{}: {:.3g}".format(k, v) for k, v in pop.sample_means.items()])
            out += '\n'
        print(out)

    def plot_marginals(self, selector=None, bins=20, axes=None, all=False, **kwargs):
        """Plot marginal distributions for parameters for all populations.

        Parameters
        ----------
        selector : iterable of ints or strings, optional
            Indices or keys to use from samples. Default to all.
        bins : int, optional
            Number of bins in histograms.
        axes : one or an iterable of plt.Axes, optional
        all : bool, optional
            Plot the marginals of all populations

        """
        if all is False:
            super(SmcSample, self).plot_marginals(selector=selector, bins=bins,
                                                  axes=axes, **kwargs)
            return

        fontsize = kwargs.pop('fontsize', 13)
        for i, pop in enumerate(self.populations):
            pop.plot_marginals(selector=selector, bins=bins, axes=axes)
            plt.suptitle("Population {}".format(i), fontsize=fontsize)

    def plot_pairs(self, selector=None, bins=20, axes=None, all=False, **kwargs):
        """Plot pairwise relationships as a matrix with marginals on the diagonal.

        The y-axis of the marginal histograms is scaled.

        Parameters
        ----------
        selector : iterable of ints or strings, optional
            Indices or keys to use from samples. Default to all.
        bins : int, optional
            Number of bins in histograms.
        axes : one or an iterable of plt.Axes, optional
        all : bool, optional
            Plot for all populations

        """
        if all is False:
            super(SmcSample, self).plot_pairs(selector=selector, bins=bins,
                                              axes=axes, **kwargs)
            return

        fontsize = kwargs.pop('fontsize', 13)
        for i, pop in enumerate(self.populations):
            pop.plot_pairs(selector=selector, bins=bins, axes=axes)
            plt.suptitle("Population {}".format(i), fontsize=fontsize)
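
# Illustrative sketch for a hypothetical SmcSample `smc_res`: the `all` flag in
# the methods above switches between the final population and the full history,
#
#     smc_res.summary(all=True)          # summary for every population
#     smc_res.plot_marginals(all=True)   # one marginal figure per population
#     smc_res.populations[-1]            # the final population as a Sample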


class BolfiSample(Sample):
    """Container for results from BOLFI."""

    def __init__(self, method_name, chains, parameter_names, warmup, **kwargs):
        """Initialize result.

        Parameters
        ----------
        method_name : string
            Name of inference method.
        chains : np.array
            Chains from sampling, warmup included. Shape: (n_chains, n_samples, n_parameters).
        parameter_names : list of strings
            List of names in the outputs dict that refer to model parameters.
        warmup : int
            Number of warmup iterations in chains.

        """
        chains = chains.copy()
        shape = chains.shape
        n_chains = shape[0]
        warmed_up = chains[:, warmup:, :]
        concatenated = warmed_up.reshape((-1,) + shape[2:])
        outputs = dict(zip(parameter_names, concatenated.T))

        super(BolfiSample, self).__init__(
            method_name=method_name,
            outputs=outputs,
            parameter_names=parameter_names,
            chains=chains,
            n_chains=n_chains,
            warmup=warmup,
            **kwargs)

    def plot_traces(self, selector=None, axes=None, **kwargs):
        """Plot MCMC traces."""
        return vis.plot_traces(self, selector, axes, **kwargs)
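
# Illustrative sketch: a hypothetical BolfiSample `bolfi_res` keeps the full
# chains (warmup included) in its meta, so the traces can be inspected with
#
#     bolfi_res.plot_traces()   # one row per parameter, one line per chain
#     bolfi_res.chains.shape    # (n_chains, n_samples, n_parameters)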


class BslSample(Sample):
    """Container for results from BSL."""

    def __init__(self, method_name, samples_all, parameter_names, burn_in=0,
                 acc_rate=None, **kwargs):
        """Initialize result.

        Parameters
        ----------
        method_name : string
            Name of inference method.
        samples_all : dict
            Dictionary with all samples from the MCMC chain, burn-in included.
        parameter_names : list
            Names of the parameter nodes
        burn_in : int
            Number of samples to discard from the start of the MCMC chain.
        acc_rate : float
            The acceptance rate of proposed parameters in the MCMC chain
        **kwargs
            Other meta information for the result

        """
        outputs = {k: samples_all[k][burn_in:] for k in samples_all.keys()}
        super(BslSample, self).__init__(
            method_name=method_name,
            outputs=outputs,
            parameter_names=parameter_names,
            samples_all=samples_all,
            burn_in=burn_in,
            acc_rate=acc_rate,
            **kwargs)

    def plot_traces(self, selector=None, axes=None, **kwargs):
        """Plot MCMC traces."""
        # BSL uses only one chain; prepare the attributes expected by the
        # trace-plotting code shared with BOLFI
        self.n_chains = 1
        N_all = self.n_samples + self.burn_in
        k = self.dim
        self.warmup = self.burn_in  # different name
        self.chains = np.zeros((1, N_all, k))  # chains x samples x params
        for ii, s in enumerate(self.parameter_names):
            self.chains[0, :, ii] = self.samples_all[s]
        return vis.plot_traces(self, selector, axes, **kwargs)

    def compute_ess(self):
        """Compute the effective sample size of the MCMC chain.

        Returns
        -------
        dict
            Effective sample size for each parameter

        """
        return {p: eff_sample_size(self.samples[p]) for p in self.parameter_names}
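
# Illustrative sketch for a hypothetical BslSample `bsl_res`: the effective
# sample size of the single MCMC chain (burn-in removed) is obtained per
# parameter with
#
#     bsl_res.compute_ess()   # dict of ESS values, one entry per parameter
#     bsl_res.plot_traces()   # reshapes samples_all into a one-chain trace plot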


class BOLFIRESample(Sample):
    """Container for results from BOLFIRE."""

    def __init__(self, method_name, chains, parameter_names, warmup, *args, **kwargs):
        """Initialize BOLFIRE result.

        Parameters
        ----------
        method_name: str
            Name of the inference method.
        chains: np.ndarray (n_chains, n_samples, n_parameters)
            Chains from sampling, warmup included.
        parameter_names: list
            List of names in the outputs dict that refer to model parameters.
        warmup: int
            Number of warmup iterations in chains.

        """
        n_chains = chains.shape[0]
        warmed_up = chains[:, warmup:, :]
        concatenated = warmed_up.reshape((-1,) + chains.shape[2:])
        outputs = dict(zip(parameter_names, concatenated.T))

        super(BOLFIRESample, self).__init__(
            method_name=method_name,
            outputs=outputs,
            parameter_names=parameter_names,
            chains=chains,
            n_chains=n_chains,
            warmup=warmup,
            *args,
            **kwargs
        )


class RomcSample(Sample):
    """Container for results from ROMC."""

    def __init__(self, method_name, outputs, parameter_names,
                 discrepancy_name, weights, **kwargs):
        """Class constructor.

        Parameters
        ----------
        method_name: string
            Name of the inference method
        outputs: Dict
            Dict where the key is the parameter name and the values are the samples
        parameter_names: List[string]
            List of the parameter names
        discrepancy_name: string
            Name of the output (=distance) node
        weights: np.ndarray
            The weights of the samples
        kwargs

        """
        super(RomcSample, self).__init__(
            method_name, outputs, parameter_names,
            discrepancy_name=discrepancy_name, weights=weights, kwargs=kwargs)

    def samples_cov(self):
        """Return the weighted empirical covariance matrix of the samples.

        Returns
        -------
        np.ndarray (D,D)
            The covariance matrix

        """
        samples = self.samples_array
        weights = self.weights
        cov_mat = np.cov(samples, rowvar=False, aweights=weights)
        return cov_mat
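
# Illustrative sketch: `samples_cov` above relies on NumPy's weighted covariance.
# For a hypothetical (n, D) array `samples` and weight vector `weights` it computes
#
#     import numpy as np
#     cov = np.cov(samples, rowvar=False, aweights=weights)
#
# whereas `Sample.get_sample_covariance` computes the unweighted covariance and
# ignores the weights.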