# Source code for elfi.methods.inference.parameter_inference

"""This module contains base class for inference methods."""

import logging
from math import ceil

import elfi.client
from elfi.model.elfi_model import ComputationContext, ElfiModel, NodeReference
from elfi.visualization.visualization import ProgressBar

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


# TODO: refactor the plotting functions

class ParameterInference:
    """A base class for parameter inference methods.

    Attributes
    ----------
    model : elfi.ElfiModel
        The ELFI graph used by the algorithm (a copy of the model given to the constructor)
    output_names : list
        Names of the nodes whose outputs are included in the batches
    client : elfi.client.ClientBase
        The batches are computed in the client
    max_parallel_batches : int
        Maximum number of batches allowed to be in computation at the same time
    state : dict
        Stores any changing data related to achieving the objective. Must include a key
        ``n_batches`` for determining when the inference is finished.
    objective : dict
        Holds the data for the algorithm to internally determine how many batches are
        still needed. It must contain either ``n_batches`` or ``n_sim`` (the latter is
        converted to a number of batches using ``batch_size``). By default the algorithm
        is finished when ``n_batches`` in the state dictionary is equal to or greater
        than the corresponding objective value.
    batches : elfi.client.BatchHandler
        Helper class for submitting batches to the client and keeping track of their
        indexes.
    computation_context : ComputationContext
        Holds the ``batch_size``, ``seed`` and ``pool`` used for the computations.
    pool : elfi.store.OutputPool
        Pool object for storing and reusing node outputs.
    progress_bar : ProgressBar
        Progress bar updated by `infer` when ``bar=True``.

    """

    def __init__(self,
                 model,
                 output_names,
                 batch_size=1,
                 seed=None,
                 pool=None,
                 max_parallel_batches=None):
        """Construct the inference algorithm object.

        If you are implementing your own algorithm do not forget to call `super`.

        Parameters
        ----------
        model : ElfiModel
            Model to perform the inference with. A NodeReference is also accepted, in
            which case its model is used.
        output_names : list
            Names of the nodes whose outputs will be requested from the ELFI graph.
        batch_size : int, optional
            The number of parameter evaluations in each pass through the ELFI graph.
            When using a vectorized simulator, using a suitably large batch_size can provide
            a significant performance boost.
        seed : int, optional
            Seed for the data generation from the ElfiModel
        pool : OutputPool, optional
            OutputPool both stores and provides precomputed values for batches.
        max_parallel_batches : int, optional
            Maximum number of batches allowed to be in computation at the same time.
            Defaults to number of cores in the client

        Raises
        ------
        ValueError
            If the model defines no parameters, if an output name is invalid or missing
            from the model, or if the resulting ``max_parallel_batches`` is below one.

        """
        model = model.model if isinstance(model, NodeReference) else model
        if not model.parameter_names:
            raise ValueError('Model {} defines no parameters'.format(model))

        # Work on a copy of the model rather than on the caller's instance.
        self.model = model.copy()
        self.output_names = self._check_outputs(output_names)

        self.client = elfi.client.get_client()

        # Prepare the computation_context
        context = ComputationContext(batch_size=batch_size, seed=seed, pool=pool)
        # NOTE(review): the raw `output_names` (not the validated, de-duplicated
        # `self.output_names`) is forwarded to the BatchHandler; confirm this is intended.
        self.batches = elfi.client.BatchHandler(
            self.model, context=context, output_names=output_names, client=self.client)
        self.computation_context = context
        # Any falsy value (None, but also an explicit 0) falls back to the core count.
        self.max_parallel_batches = max_parallel_batches or self.client.num_cores

        if self.max_parallel_batches <= 0:
            msg = 'Value for max_parallel_batches ({}) must be at least one.'.format(
                self.max_parallel_batches)
            if self.client.num_cores == 0:
                msg += ' Client has currently no workers available. Please make sure ' \
                       'the cluster has fully started or set the max_parallel_batches ' \
                       'parameter by hand.'
            raise ValueError(msg)

        # State and objective should contain all information needed to continue the
        # inference after an iteration.
        self.state = dict(n_sim=0, n_batches=0)
        self.objective = dict()
        self.progress_bar = ProgressBar(prefix='Progress', suffix='Complete',
                                        decimals=1, length=50, fill='=')

    @property
    def pool(self):
        """Return the output pool of the inference."""
        return self.computation_context.pool

    @property
    def seed(self):
        """Return the seed of the inference."""
        return self.computation_context.seed

    @property
    def parameter_names(self):
        """Return the parameters to be inferred."""
        return self.model.parameter_names

    @property
    def batch_size(self):
        """Return the current batch_size."""
        return self.computation_context.batch_size

    def set_objective(self, *args, **kwargs):
        """Set the objective of the inference.

        This method sets the objective of the inference (values typically stored in the
        `self.objective` dict). Subclasses must implement it.

        Returns
        -------
        None

        """
        raise NotImplementedError

    def extract_result(self):
        """Prepare the result from the current state of the inference.

        ELFI calls this method in the end of the inference to return the result.
        Subclasses must implement it.

        Returns
        -------
        result : elfi.methods.result.Result

        """
        raise NotImplementedError

    def update(self, batch, batch_index):
        """Update the inference state with a new batch.

        ELFI calls this method when a new batch has been computed and the state of
        the inference should be updated with it. It is also possible to bypass ELFI and
        call this directly to update the inference.

        The base implementation only increments ``state['n_batches']`` by one and
        ``state['n_sim']`` by ``batch_size``.

        Parameters
        ----------
        batch : dict
            dict with `self.output_names` as keys and the corresponding outputs for the
            batch as values
        batch_index : int

        Returns
        -------
        None

        """
        self.state['n_batches'] += 1
        self.state['n_sim'] += self.batch_size

    def prepare_new_batch(self, batch_index):
        """Prepare values for a new batch.

        ELFI calls this method before submitting a new batch with an increasing index
        `batch_index`. This is an optional method to override. Use this if you need to
        do preparations, e.g. in the Bayesian optimization algorithm, the next
        acquisition points would be acquired here.

        If you need to provide values for certain nodes, you can do so by constructing a
        batch dictionary and returning it. See e.g. BayesianOptimization for an example.

        Parameters
        ----------
        batch_index : int
            next batch_index to be submitted

        Returns
        -------
        batch : dict or None
            Keys should match to node names in the model. These values will override any
            default values or operations in those nodes.

        """
        pass

    def plot_state(self, **kwargs):
        """Plot the current state of the algorithm.

        Parameters
        ----------
        axes : matplotlib.axes.Axes (optional)
        figure : matplotlib.figure.Figure (optional)
        xlim
            x-axis limits
        ylim
            y-axis limits
        interactive : bool (default False)
            If true, uses IPython.display to update the cell figure
        close
            Close figure in the end of plotting. Used in the end of interactive mode.

        Returns
        -------
        None

        """
        raise NotImplementedError

    def infer(self, *args, vis=None, bar=True, **kwargs):
        """Set the objective and start the iterate loop until the inference is finished.

        See the other arguments from the `set_objective` method. Once the objective is
        reached, any still-pending batches are cancelled.

        Parameters
        ----------
        vis : dict, optional
            Plotting options. More info in self.plot_state method. Any truthy non-dict
            value enables plotting with default options.
        bar : bool, optional
            Flag to remove (False) or keep (True) the progress bar from/in output.

        Returns
        -------
        result : Sample
            The value returned by `extract_result`.

        """
        vis_opt = vis if isinstance(vis, dict) else {}

        self.set_objective(*args, **kwargs)

        while not self.finished:
            self.iterate()
            if vis:
                self.plot_state(interactive=True, **vis_opt)

            if bar:
                self.progress_bar.update_progressbar(self.state['n_batches'],
                                                     self._objective_n_batches)

        self.batches.cancel_pending()
        if vis:
            self.plot_state(close=True, **vis_opt)

        return self.extract_result()

    def iterate(self):
        """Advance the inference by one iteration.

        This is a way to manually progress the inference. One iteration consists of
        waiting and processing the result of the next batch in succession and possibly
        submitting new batches.

        Notes
        -----
        If the next batch is ready, it will be processed immediately and no new batches
        are submitted.

        New batches are submitted only while waiting for the next one to complete. There
        will never be more batches submitted in parallel than the `max_parallel_batches`
        setting allows.

        Returns
        -------
        None

        """
        # Submit new batches if allowed
        while self._allow_submit(self.batches.next_index):
            next_batch = self.prepare_new_batch(self.batches.next_index)
            # Pass the index as a logger argument so formatting happens only if emitted.
            logger.debug("Submitting batch %d", self.batches.next_index)
            self.batches.submit(next_batch)

        # Handle the next ready batch in succession
        batch, batch_index = self.batches.wait_next()
        logger.debug('Received batch %d', batch_index)
        self.update(batch, batch_index)

    @property
    def finished(self):
        """Check whether the objective number of batches has been reached."""
        return self._objective_n_batches <= self.state['n_batches']

    def _allow_submit(self, batch_index):
        """Return True if another batch may be submitted now.

        Submission is allowed only when fewer than `max_parallel_batches` batches are
        pending, the objective still needs more batches, and no computed batch is
        waiting to be processed. `batch_index` is currently unused.
        """
        return (self.max_parallel_batches > self.batches.num_pending
                and self._has_batches_to_submit and (not self.batches.has_ready()))

    @property
    def _has_batches_to_submit(self):
        """Return True if completed plus pending batches fall short of the objective."""
        return self._objective_n_batches > self.state['n_batches'] + self.batches.num_pending

    @property
    def _objective_n_batches(self):
        """Return the number of batches required by the objective.

        Uses ``objective['n_batches']`` if present, otherwise
        ``ceil(objective['n_sim'] / batch_size)``.

        Raises
        ------
        ValueError
            If the objective defines neither ``n_batches`` nor ``n_sim``.

        """
        if 'n_batches' in self.objective:
            n_batches = self.objective['n_batches']
        elif 'n_sim' in self.objective:
            n_batches = ceil(self.objective['n_sim'] / self.batch_size)
        else:
            raise ValueError('Objective must define either `n_batches` or `n_sim`.')
        return n_batches

    def _extract_result_kwargs(self):
        """Extract common arguments for the ParameterInferenceResult object."""
        return {
            'method_name': self.__class__.__name__,
            'parameter_names': self.parameter_names,
            'seed': self.seed,
            'n_sim': self.state['n_sim'],
            'n_batches': self.state['n_batches']
        }

    @staticmethod
    def _resolve_model(model, target, default_reference_class=NodeReference):
        """Resolve the model and the name of the inference target node.

        Parameters
        ----------
        model : ElfiModel or NodeReference
            If a NodeReference is given, it is used as the target and its model is used.
        target : str or NodeReference or None
            Target node, or its name to be looked up in `model`.
        default_reference_class : type, optional
            Class the resolved target node must be an instance of.

        Returns
        -------
        model : ElfiModel
        target_name : str

        Raises
        ------
        NotImplementedError
            If `model` is an ElfiModel and no `target` is given.
        ValueError
            If the resolved target is not an instance of `default_reference_class`.

        """
        if isinstance(model, ElfiModel) and target is None:
            raise NotImplementedError("Please specify the target node of the inference method")

        if isinstance(model, NodeReference):
            target = model
            model = target.model

        if isinstance(target, str):
            target = model[target]

        if not isinstance(target, default_reference_class):
            raise ValueError('Unknown target node class')

        return model, target.name

    def _check_outputs(self, output_names):
        """Filter out duplicates and check that corresponding nodes exist.

        Preserves the order.

        Parameters
        ----------
        output_names : list or None
            Node names (str) or NodeReference objects. A falsy value is treated as an
            empty list.

        Returns
        -------
        list of str
            The unique node names, in their original order.

        Raises
        ------
        ValueError
            If an entry is neither a string nor a NodeReference, or names a node that is
            not in the model.

        """
        output_names = output_names or []
        checked_names = []
        seen = set()
        for name in output_names:
            if isinstance(name, NodeReference):
                name = name.name

            # Validate the type before the set lookup so that an unhashable object
            # raises the informative ValueError instead of a TypeError.
            if not isinstance(name, str):
                raise ValueError(
                    'All output names must be strings, object {} was given'.format(name))
            if name in seen:
                continue
            if not self.model.has_node(name):
                raise ValueError(
                    'Node {} output was requested, but it is not in the model.'.format(name))

            seen.add(name)
            checked_names.append(name)

        return checked_names