"""This module contains sampling based inference methods."""
__all__ = ['Rejection', 'SMC', 'AdaptiveDistanceSMC', 'AdaptiveThresholdSMC']
import logging
from math import ceil
import numpy as np
import elfi.visualization.interactive as visin
from elfi.loader import get_sub_seed
from elfi.methods.density_ratio_estimation import (DensityRatioEstimation,
calculate_densratio_basis_sigma)
from elfi.methods.inference.parameter_inference import ParameterInference
from elfi.methods.results import Sample, SmcSample
from elfi.methods.utils import (GMDistribution, arr2d_to_batch,
weighted_sample_quantile, weighted_var)
from elfi.model.elfi_model import AdaptiveDistance
from elfi.model.extensions import ModelPrior
from elfi.utils import is_array
logger = logging.getLogger(__name__)
class Sampler(ParameterInference):
    """Common base for the sampling based inference methods in this module."""

    def sample(self, n_samples, *args, **kwargs):
        """Draw samples from the approximate posterior.

        The remaining arguments are forwarded to `infer`; see the
        `set_objective` method of the concrete sampler for their meaning.

        Parameters
        ----------
        n_samples : int
            Number of samples to generate from the (approximate) posterior
        *args
        **kwargs
            The optional keyword `bar` (default True) is stored on the instance
            as `self.bar` and passed on to `infer`.

        Returns
        -------
        result : Sample

        """
        show_bar = kwargs.pop('bar', True)
        self.bar = show_bar
        return self.infer(n_samples, *args, bar=show_bar, **kwargs)

    def _extract_result_kwargs(self):
        """Extend the parent's result kwargs with sampler specific state entries."""
        result_kwargs = super()._extract_result_kwargs()

        # Copy over only those state entries that are actually present.
        for key in ('threshold', 'accept_rate'):
            if key not in self.state:
                continue
            result_kwargs[key] = self.state[key]

        if hasattr(self, 'discrepancy_name'):
            result_kwargs['discrepancy_name'] = self.discrepancy_name

        return result_kwargs
class Rejection(Sampler):
    """Parallel ABC rejection sampler.

    For a description of the rejection sampler and a general introduction to ABC, see e.g.
    Lintusaari et al. 2016.

    References
    ----------
    Lintusaari J, Gutmann M U, Dutta R, Kaski S, Corander J (2016). Fundamentals and
    Recent Developments in Approximate Bayesian Computation. Systematic Biology.
    http://dx.doi.org/10.1093/sysbio/syw077.

    """

    def __init__(self, model, discrepancy_name=None, output_names=None, **kwargs):
        """Initialize the Rejection sampler.

        Parameters
        ----------
        model : ElfiModel or NodeReference
        discrepancy_name : str, NodeReference, optional
            Only needed if model is an ElfiModel
        output_names : list, optional
            Additional outputs from the model to be included in the inference result, e.g.
            corresponding summaries to the acquired samples
        kwargs:
            See ParameterInference

        """
        model, discrepancy_name = self._resolve_model(model, discrepancy_name)
        # The concatenation builds a new list, so the caller's list is never mutated below.
        output_names = [discrepancy_name] + model.parameter_names + (output_names or [])

        self.adaptive = isinstance(model[discrepancy_name], AdaptiveDistance)
        if self.adaptive:
            model[discrepancy_name].init_adaptation_round()
            # Summaries are needed as adaptation data
            self.sums = [sumstat.name for sumstat in model[discrepancy_name].parents]
            for k in self.sums:
                if k not in output_names:
                    output_names.append(k)

        super(Rejection, self).__init__(model, output_names, **kwargs)
        self.discrepancy_name = discrepancy_name

    def set_objective(self, n_samples, threshold=None, quantile=None, n_sim=None):
        """Set objective for inference.

        If none of `threshold`, `quantile` and `n_sim` is given, `quantile`
        defaults to 0.01.

        Parameters
        ----------
        n_samples : int
            number of samples to generate
        threshold : float
            Acceptance threshold
        quantile : float
            In between (0,1). Define the threshold as the p-quantile of all the
            simulations. n_sim = n_samples/quantile.
        n_sim : int
            Total number of simulations. The threshold will be the n_samples-th smallest
            discrepancy among n_sim simulations.

        """
        if quantile is None and threshold is None and n_sim is None:
            quantile = .01

        # np.inf instead of the np.Inf alias, which was removed in NumPy 2.0.
        self.state = dict(samples=None, threshold=np.inf,
                          n_sim=0, accept_rate=1, n_batches=0)

        if quantile:
            n_sim = ceil(n_samples / quantile)

        # Set initial n_batches estimate
        if n_sim:
            n_batches = ceil(n_sim / self.batch_size)
        else:
            n_batches = self.max_parallel_batches

        self.objective = dict(n_samples=n_samples,
                              threshold=threshold, n_batches=n_batches)

        # Reset the inference
        self.batches.reset()

    def update(self, batch, batch_index):
        """Update the inference state with a new batch.

        Parameters
        ----------
        batch : dict
            dict with `self.outputs` as keys and the corresponding outputs for the batch
            as values
        batch_index : int

        """
        super(Rejection, self).update(batch, batch_index)
        if self.state['samples'] is None:
            # Lazy initialization of the outputs dict
            self._init_samples_lazy(batch)
        self._merge_batch(batch)
        self._update_state_meta()
        self._update_objective_n_batches()

    def _init_samples_lazy(self, batch):
        """Initialize the outputs dict based on the received batch.

        Raises
        ------
        KeyError
            If the batch lacks one of the requested output nodes.
        ValueError
            If a node output is not an array or its length differs from the batch size.

        """
        samples = {}
        e_noarr = "Node {} output must be in a numpy array of length {} (batch_size)."
        e_len = "Node {} output has array length {}. It should be equal to the batch size {}."

        for node in self.output_names:
            # Check the requested outputs
            if node not in batch:
                raise KeyError(
                    "Did not receive outputs for node {}".format(node))

            nbatch = batch[node]
            if not is_array(nbatch):
                raise ValueError(e_noarr.format(node, self.batch_size))
            elif len(nbatch) != self.batch_size:
                raise ValueError(e_len.format(
                    node, len(nbatch), self.batch_size))

            # Prepare samples: room for the kept samples plus one incoming batch
            shape = (self.objective['n_samples'] +
                     self.batch_size, ) + nbatch.shape[1:]
            dtype = nbatch.dtype

            if node == self.discrepancy_name:
                # Initialize the distances to inf
                samples[node] = np.ones(shape, dtype=dtype) * np.inf
            else:
                samples[node] = np.empty(shape, dtype=dtype)

        self.state['samples'] = samples

    def _merge_batch(self, batch):
        """Merge the accepted part of a batch into the stored samples and re-sort them."""
        # TODO: add index vector so that you can recover the original order
        samples = self.state['samples']

        # Add current batch to adaptation data
        if self.adaptive:
            observed_sums = [batch[s] for s in self.sums]
            self.model[self.discrepancy_name].add_data(*observed_sums)

        # Check acceptance condition
        if self.objective.get('threshold') is None:
            accepted = slice(None, None)
            num_accepted = self.batch_size
        else:
            accepted = batch[self.discrepancy_name] <= self.objective.get('threshold')
            # With several distance columns a sample must pass every one of them
            accepted = np.all(np.atleast_2d(np.transpose(accepted)), axis=0)
            num_accepted = np.sum(accepted)

        # Put the acquired samples to the end
        if num_accepted > 0:
            for node, v in samples.items():
                v[-num_accepted:] = batch[node][accepted]

        # Sort the smallest to the beginning
        # note: last (-1) distance measure is used when distance calculation is nested
        sort_distance = np.atleast_2d(np.transpose(samples[self.discrepancy_name]))[-1]
        sort_mask = np.argsort(sort_distance)
        for k, v in samples.items():
            v[:] = v[sort_mask]

    def _update_state_meta(self):
        """Update `threshold` and `accept_rate` in the state."""
        o = self.objective
        s = self.state
        s['threshold'] = s['samples'][self.discrepancy_name][o['n_samples'] - 1]
        s['accept_rate'] = min(1, o['n_samples'] / s['n_sim'])

    def _update_objective_n_batches(self):
        """Re-estimate the objective `n_batches` when a fixed threshold is in use."""
        # Only in the case that the threshold is used
        if self.objective.get('threshold') is None:
            return

        s = self.state
        t, n_samples = [self.objective.get(k)
                        for k in ('threshold', 'n_samples')]

        # noinspection PyTypeChecker
        if s['samples']:
            accepted = s['samples'][self.discrepancy_name] <= t
            n_acceptable = np.sum(np.all(np.atleast_2d(np.transpose(accepted)), axis=0))
        else:
            n_acceptable = 0

        if n_acceptable == 0:
            # No acceptable samples found yet, increase n_batches of objective by one in
            # order to keep simulating
            n_batches = self.objective['n_batches'] + 1
        else:
            accept_rate_t = n_acceptable / s['n_sim']
            # Add some margin to estimated n_batches. One could also use confidence
            # bounds here
            margin = .2 * self.batch_size * int(n_acceptable < n_samples)
            n_batches = (n_samples / accept_rate_t + margin) / self.batch_size
            n_batches = ceil(n_batches)

        self.objective['n_batches'] = n_batches
        logger.debug('Estimated objective n_batches=%d', self.objective['n_batches'])

    def _update_distances(self):
        """Recompute the distances of the kept samples with the updated adaptive distance."""
        # Update adaptive distance node
        self.model[self.discrepancy_name].update_distance()

        # Recalculate distances in current sample
        nums = self.objective['n_samples']
        data = {s: self.state['samples'][s][:nums] for s in self.sums}
        ds = self.model[self.discrepancy_name].generate(with_values=data)

        # Sort based on new distance measure
        sort_distance = np.atleast_2d(np.transpose(ds))[-1]
        sort_mask = np.argsort(sort_distance)

        # Update state. The distances are reordered with the same mask as the other
        # outputs so that each distance stays aligned with its sample and the threshold
        # read in `_update_state_meta` is the largest kept distance.
        self.state['samples'][self.discrepancy_name] = sort_distance[sort_mask]
        for k in self.state['samples'].keys():
            if k != self.discrepancy_name:
                self.state['samples'][k][:nums] = self.state['samples'][k][sort_mask]

        self._update_state_meta()

    def plot_state(self, **options):
        """Plot the current state of the inference algorithm.

        This feature is still experimental and only supports 1d or 2d cases.
        """
        displays = []
        if options.get('interactive'):
            # Imported lazily: IPython is only needed for the interactive display
            from IPython import display
            displays.append(
                display.HTML('<span>Threshold: {}</span>'.format(self.state['threshold'])))

        visin.plot_sample(
            self.state['samples'],
            nodes=self.parameter_names,
            n=self.objective['n_samples'],
            displays=displays,
            **options)
class SMC(Sampler):
    """Sequential Monte Carlo ABC sampler."""

    def __init__(self, model, discrepancy_name=None, output_names=None, **kwargs):
        """Initialize the SMC-ABC sampler.

        Parameters
        ----------
        model : ElfiModel or NodeReference
        discrepancy_name : str, NodeReference, optional
            Only needed if model is an ElfiModel
        output_names : list, optional
            Additional outputs from the model to be included in the inference result, e.g.
            corresponding summaries to the acquired samples
        kwargs:
            See ParameterInference

        """
        model, discrepancy_name = self._resolve_model(model, discrepancy_name)
        output_names = [discrepancy_name] + model.parameter_names + (output_names or [])

        super(SMC, self).__init__(model, output_names, **kwargs)

        self._prior = ModelPrior(self.model)
        self.discrepancy_name = discrepancy_name
        self.state['round'] = 0
        self._populations = []
        self._rejection = None
        self._round_random_state = None
        self._quantiles = None

    def set_objective(self, n_samples, thresholds=None, quantiles=None):
        """Set objective for ABC-SMC inference.

        If both `thresholds` and `quantiles` are given, `thresholds` is used.

        Parameters
        ----------
        n_samples : int
            Number of samples to generate
        thresholds : list, optional
            List of thresholds for ABC-SMC
        quantiles : list, optional
            List of selection quantiles used to determine sample thresholds

        Raises
        ------
        ValueError
            If neither `thresholds` nor `quantiles` is given.

        """
        if thresholds is None and quantiles is None:
            raise ValueError("Either thresholds or quantiles is required to run ABC-SMC.")

        if thresholds is None:
            rounds = len(quantiles) - 1
        else:
            rounds = len(thresholds) - 1

        # Take previous iterations into account in case continued estimation
        self.state['round'] = len(self._populations)
        rounds = rounds + self.state['round']

        if thresholds is None:
            thresholds = np.full((rounds+1), None)
            self._quantiles = np.concatenate((np.full((self.state['round']), None), quantiles))
        else:
            thresholds = np.concatenate((np.full((self.state['round']), None), thresholds))
            # Drop quantiles left over from an earlier quantile-based run; otherwise
            # `_init_new_round` would overwrite the explicit thresholds given here.
            self._quantiles = None

        self.objective.update(
            dict(
                n_samples=n_samples,
                n_batches=self.max_parallel_batches,
                round=rounds,
                thresholds=thresholds))

        self._init_new_round()
        self._update_objective()

    def update(self, batch, batch_index):
        """Update the inference state with a new batch.

        Parameters
        ----------
        batch : dict
            dict with `self.outputs` as keys and the corresponding outputs for the batch
            as values
        batch_index : int

        """
        super(SMC, self).update(batch, batch_index)
        self._rejection.update(batch, batch_index)

        if self._rejection.finished:
            self.batches.cancel_pending()
            if self.bar:
                self.progress_bar.update_progressbar(self.progress_bar.scaling + 1,
                                                     self.progress_bar.scaling + 1)

            if self.state['round'] < self.objective['round']:
                self._populations.append(self._extract_population())
                self.state['round'] += 1
                self._init_new_round()

        self._update_objective()

    def prepare_new_batch(self, batch_index):
        """Prepare values for a new batch.

        Parameters
        ----------
        batch_index : int
            next batch_index to be submitted

        Returns
        -------
        batch : dict or None
            Keys should match to node names in the model. These values will override any
            default values or operations in those nodes.

        """
        if self.state['round'] == 0:
            # Use the actual prior
            return

        # Sample from the proposal, condition on actual prior
        params = GMDistribution.rvs(*self._gm_params, size=self.batch_size,
                                    prior_logpdf=self._prior.logpdf,
                                    random_state=self._round_random_state)

        batch = arr2d_to_batch(params, self.parameter_names)
        return batch

    def _init_new_round(self):
        """Create the rejection sampler for the current round and set its objective."""
        self._set_rejection_round(self.state['round'])

        if self.state['round'] == 0 and self._quantiles is not None:
            self._rejection.set_objective(
                self.objective['n_samples'], quantile=self._quantiles[0])
        else:
            if self._quantiles is not None:
                self._set_threshold()
            self._rejection.set_objective(
                self.objective['n_samples'], threshold=self.current_population_threshold)

    def _set_rejection_round(self, round):
        """Instantiate a fresh `Rejection` sampler and random state for the given round."""
        self._update_round_info(self.state['round'])

        # Get a subseed for this round for ensuring consistent results for the round
        seed = self.seed if round == 0 else get_sub_seed(self.seed, round)
        self._round_random_state = np.random.RandomState(seed)

        self._rejection = Rejection(
            self.model,
            discrepancy_name=self.discrepancy_name,
            output_names=self.output_names,
            batch_size=self.batch_size,
            seed=seed,
            max_parallel_batches=self.max_parallel_batches)

    def _update_round_info(self, round):
        """Reinitialise the progress bar (if enabled) and log the start of the round."""
        if self.bar:
            reinit_msg = 'ABC-SMC Round {0} / {1}'.format(
                round + 1, self.objective['round'] + 1)
            self.progress_bar.reinit_progressbar(
                scaling=(self.state['n_batches']), reinit_msg=reinit_msg)
        dashes = '-' * 16
        logger.info('%s Starting round %d %s', dashes, round, dashes)

    def _extract_population(self):
        """Extract the rejection result and attach means, weights and covariance to it."""
        sample = self._rejection.extract_result()
        # Append the sample object
        sample.method_name = "Rejection within SMC-ABC"
        means, w, cov = self._compute_weights_means_and_cov(sample)
        sample.means = means
        sample.weights = w
        sample.meta['cov'] = cov
        return sample

    def _compute_weights_means_and_cov(self, pop):
        """Compute weights, component means and proposal covariance for a population.

        Raises
        ------
        RuntimeError
            If all sample weights are zero.

        """
        params = np.column_stack(tuple([pop.outputs[p] for p in self.parameter_names]))

        if self._populations:
            # Weight = prior density / proposal density of the previous population
            q_logpdf = GMDistribution.logpdf(params, *self._gm_params)
            p_logpdf = self._prior.logpdf(params)
            w = np.exp(p_logpdf - q_logpdf)
        else:
            w = np.ones(pop.n_samples)

        means = params.copy()

        if np.count_nonzero(w) == 0:
            raise RuntimeError("All sample weights are zero. If you are using a prior "
                               "with a bounded support, this may be caused by specifying "
                               "a too small sample size.")

        # New covariance
        cov = 2 * np.diag(weighted_var(params, w))

        if not np.all(np.isfinite(cov)):
            logger.warning("Could not estimate the sample covariance. This is often "
                           "caused by majority of the sample weights becoming zero. "
                           "Falling back to using unit covariance.")
            cov = np.diag(np.ones(params.shape[1]))

        return means, w, cov

    def _update_objective(self):
        """Update the objective n_batches."""
        n_batches = sum([pop.n_batches for pop in self._populations])
        self.objective['n_batches'] = n_batches + \
            self._rejection.objective['n_batches']

    def _set_threshold(self):
        """Set the current round's threshold as a weighted quantile of the previous population."""
        previous_population = self._populations[self.state['round']-1]
        threshold = weighted_sample_quantile(
            x=previous_population.discrepancies,
            alpha=self._quantiles[self.state['round']],
            weights=previous_population.weights)
        logger.info('ABC-SMC: Selected threshold for next population %.3f', threshold)
        self.objective['thresholds'][self.state['round']] = threshold

    @property
    def _gm_params(self):
        """Return means, covariance and weights of the latest stored population."""
        sample = self._populations[-1]
        return sample.means, sample.cov, sample.weights

    @property
    def current_population_threshold(self):
        """Return the threshold for current population."""
        return self.objective['thresholds'][self.state['round']]
class AdaptiveDistanceSMC(SMC):
    """SMC-ABC sampler with adaptive threshold and distance function.

    Notes
    -----
    Algorithm 5 in Prangle (2017)

    References
    ----------
    Prangle D (2017). Adapting the ABC Distance Function. Bayesian
    Analysis 12(1):289-309, 2017.
    https://projecteuclid.org/euclid.ba/1460641065

    """

    def __init__(self, model, discrepancy_name=None, output_names=None, **kwargs):
        """Initialize the adaptive distance SMC-ABC sampler.

        Parameters
        ----------
        model : ElfiModel or NodeReference
        discrepancy_name : str, NodeReference, optional
            Only needed if model is an ElfiModel
        output_names : list, optional
            Additional outputs from the model to be included in the inference result, e.g.
            corresponding summaries to the acquired samples
        kwargs:
            See ParameterInference

        Raises
        ------
        TypeError
            If the discrepancy node is not an `AdaptiveDistance`.

        """
        model, discrepancy_name = self._resolve_model(model, discrepancy_name)
        if not isinstance(model[discrepancy_name], AdaptiveDistance):
            raise TypeError('This method requires an adaptive distance node.')

        # Initialise adaptive distance node
        model[discrepancy_name].init_state()
        # Add summaries in additional outputs as these are needed to update the distance node
        sums = [sumstat.name for sumstat in model[discrepancy_name].parents]
        if output_names is None:
            output_names = sums
        else:
            # Work on a copy so the list passed in by the caller is not modified
            output_names = list(output_names)
            for k in sums:
                if k not in output_names:
                    output_names.append(k)
        super(AdaptiveDistanceSMC, self).__init__(model, discrepancy_name,
                                                  output_names=output_names, **kwargs)

    def set_objective(self, n_samples, rounds, quantile=0.5):
        """Set objective for adaptive distance ABC-SMC inference.

        Each round simulates `ceil(n_samples / quantile)` samples, of which the
        first `n_samples` are kept as the population.

        Parameters
        ----------
        n_samples : int
            Number of samples to generate
        rounds : int
            Number of populations to sample
        quantile : float, optional
            Selection quantile used to determine sample thresholds

        """
        super(AdaptiveDistanceSMC, self).set_objective(ceil(n_samples/quantile),
                                                       quantiles=[1]*rounds)
        self.population_size = n_samples
        self.quantile = quantile

    def _extract_population(self):
        """Build the population `Sample` from the first `population_size` rejection samples."""
        # Extract population and metadata based on rejection sample
        rejection_sample = self._rejection.extract_result()
        outputs = dict()
        for k in self.output_names:
            outputs[k] = rejection_sample.outputs[k][:self.population_size]
        meta = rejection_sample.meta
        meta['adaptive_distance_w'] = self.model[self.discrepancy_name].state['w'][-1]
        meta['threshold'] = max(outputs[self.discrepancy_name])
        meta['accept_rate'] = self.population_size/meta['n_sim']
        method_name = "Rejection within adaptive distance SMC-ABC"
        sample = Sample(method_name, outputs, self.parameter_names, **meta)

        # Append the sample object
        means, w, cov = self._compute_weights_means_and_cov(sample)
        sample.means = means
        sample.weights = w
        sample.meta['cov'] = cov
        return sample

    def _extract_result_kwargs(self):
        """Add the per-population adaptive distance weights to the result kwargs."""
        kwargs = super(AdaptiveDistanceSMC, self)._extract_result_kwargs()
        kwargs['adaptive_distance_w'] = [pop.adaptive_distance_w for pop in self._populations]
        return kwargs

    def _set_threshold(self):
        """Use the previous population's threshold as the current round's threshold."""
        round = self.state['round']
        self.objective['thresholds'][round] = self._populations[round-1].threshold

    @property
    def current_population_threshold(self):
        """Return the threshold for current population.

        This is a list: infinity followed by the threshold of every stored population.
        """
        return [np.inf] + [pop.threshold for pop in self._populations]
class AdaptiveThresholdSMC(SMC):
    """ABC-SMC sampler with adaptive threshold selection.

    References
    ----------
    Simola U, Cisewski-Kehe J, Gutmann M U, Corander J (2021). Adaptive
    Approximate Bayesian Computation Tolerance Selection. Bayesian Analysis.
    https://doi.org/10.1214/20-BA1211

    """

    def __init__(self,
                 model,
                 discrepancy_name=None,
                 output_names=None,
                 initial_quantile=0.20,
                 q_threshold=0.99,
                 densratio_estimation=None,
                 **kwargs):
        """Initialize the adaptive threshold SMC-ABC sampler.

        Parameters
        ----------
        model : ElfiModel or NodeReference
        discrepancy_name : str, NodeReference, optional
            Only needed if model is an ElfiModel
        output_names : list, optional
            Additional outputs from the model to be included in the inference result, e.g.
            corresponding summaries to the acquired samples
        initial_quantile : float, optional
            Initial selection quantile for the first round of adaptive-ABC-SMC
        q_threshold : float, optional
            Termination criterion for adaptive-ABC-SMC: a new round is started only
            while the adapted quantile stays below this value
        densratio_estimation : DensityRatioEstimation, optional
            Density ratio estimation object defining parameters for KLIEP
        kwargs:
            See ParameterInference

        """
        model, discrepancy_name = self._resolve_model(model, discrepancy_name)
        output_names = [discrepancy_name] + model.parameter_names + (output_names or [])

        # Deliberately skips SMC.__init__ (which would prepend the default outputs a
        # second time); the SMC attributes are therefore initialised here by hand.
        super(SMC, self).__init__(model, output_names, **kwargs)

        self._prior = ModelPrior(self.model)
        self.discrepancy_name = discrepancy_name
        self.state['round'] = 0
        self._populations = []
        self._rejection = None
        self._round_random_state = None
        # Filled in by `set_objective` / `update`; defined up front so that early
        # attribute access does not raise AttributeError.
        self._quantiles = None
        self._new_population = None
        self.q_threshold = q_threshold
        self.initial_quantile = initial_quantile
        self.densratio = densratio_estimation or DensityRatioEstimation(n=100,
                                                                        epsilon=0.001,
                                                                        max_iter=200,
                                                                        abs_tol=0.01,
                                                                        fold=5,
                                                                        optimize=False)

    def set_objective(self,
                      n_samples,
                      max_iter=10):
        """Set objective for ABC-SMC inference.

        Parameters
        ----------
        n_samples : int
            Number of samples to generate
        max_iter : int, optional
            Maximum number of iterations

        """
        rounds = max_iter - 1

        # Take previous iterations into account in case continued estimation
        self.state['round'] = len(self._populations)
        rounds = rounds + self.state['round']

        # Initialise threshold selection and adaptive quantile
        thresholds = np.full((rounds+1), None)
        self._quantiles = np.full((rounds+1), None)
        self._quantiles[0] = self.initial_quantile

        self.objective.update(
            dict(
                n_samples=n_samples,
                n_batches=self.max_parallel_batches,
                round=rounds,
                thresholds=thresholds))

        self._init_new_round()
        self._update_objective()

    def update(self, batch, batch_index):
        """Update the inference state with a new batch.

        Parameters
        ----------
        batch : dict
            dict with `self.outputs` as keys and the corresponding outputs for the batch
            as values
        batch_index : int

        """
        # Skips SMC.update on purpose: the round handling below replaces it.
        super(SMC, self).update(batch, batch_index)
        self._rejection.update(batch, batch_index)

        if self._rejection.finished:
            self.batches.cancel_pending()
            if self.bar:
                self.progress_bar.update_progressbar(self.progress_bar.scaling + 1,
                                                     self.progress_bar.scaling + 1)
            self._new_population = self._extract_population()

            if self.state['round'] < self.objective['round']:
                self._set_adaptive_quantile()
                # Start a further round only while the adapted quantile is below the
                # termination level; otherwise no new round is initialised.
                if self._quantiles[self.state['round']+1] < self.q_threshold:
                    self._populations.append(self._new_population)
                    self.state['round'] += 1
                    self._init_new_round()

        self._update_objective()

    def _set_adaptive_quantile(self):
        """Set adaptively the new threshold for current population."""
        logger.info("ABC-SMC: Adapting quantile threshold...")
        sample_data_current = self._resolve_sample(backwards_index=0)
        sample_data_previous = self._resolve_sample(backwards_index=-1)

        if self.densratio.optimize:
            sigma = list(10.0 ** np.arange(-1, 6))
        else:
            sigma = calculate_densratio_basis_sigma(sample_data_current['sigma_max'],
                                                    sample_data_previous['sigma_max'])

        self.densratio.fit(x=sample_data_current['samples'],
                           y=sample_data_previous['samples'],
                           weights_x=sample_data_current['weights'],
                           weights_y=sample_data_previous['weights'],
                           sigma=sigma)

        max_value = self.densratio.max_ratio()
        max_value = 1.0 if max_value < 1.0 else max_value
        # The next quantile is the inverse of the maximum ratio, floored at 0.05
        self._quantiles[self.state['round']+1] = max(1 / max_value, 0.05)
        logger.info('ABC-SMC: Estimated maximum density ratio %.5f', 1 / max_value)

    def _resolve_sample(self, backwards_index):
        """Get properties of the samples used in ratio estimation.

        `backwards_index` 0 selects the newly extracted population; any other value
        indexes `self._populations`. If `round + backwards_index` is negative, a
        sample drawn from the prior is used instead.
        """
        if self.state['round'] + backwards_index < 0:
            return self._densityratio_initial_sample()
        elif backwards_index == 0:
            sample = self._new_population
        else:
            sample = self._populations[backwards_index]

        weights = sample.weights
        samples = sample.samples_array
        sample_sigma = np.sqrt(np.diag(sample.cov))
        # Despite the key name, this is the smallest per-dimension standard deviation
        sigma_max = np.min(sample_sigma)
        sample_data = dict(samples=samples, weights=weights, sigma_max=sigma_max)
        return sample_data

    def _densityratio_initial_sample(self):
        """Draw an equally weighted sample from the prior, sized like the new population."""
        n_samples = self._new_population.weights.shape[0]
        samples = self._prior.rvs(size=n_samples, random_state=self._round_random_state)
        weights = np.ones(n_samples)
        sample_cov = np.atleast_2d(np.cov(samples.reshape(n_samples, -1), rowvar=False))
        sigma_max = np.min(np.sqrt(np.diag(sample_cov)))
        return dict(samples=samples,
                    weights=weights,
                    sigma_max=sigma_max)