Source code for elfi.store

"""This module contains implementations for storing simulated values for later use."""

import io
import logging
import os
import pickle
import shutil

import numpy as np
import numpy.lib.format as npformat

logger = logging.getLogger(__name__)

_default_prefix = 'pools'


class OutputPool:
    """Store node outputs to dictionary-like stores.

    The default store is a Python dictionary.

    Notes
    -----
    Saving the store requires that all the stores are pickleable.

    Arbitrary objects that support simple array indexing can be used as stores by
    using the `elfi.store.ArrayStore` class.

    See the `elfi.store.StoreBase` interface if you wish to implement your own ELFI
    compatible store. Basically any object that fulfills Python's dictionary API will
    work as a store in the pool.

    """

    _pkl_name = '_outputpool.pkl'

    def __init__(self, outputs=None, name=None, prefix=None):
        """Initialize OutputPool.

        Depending on the algorithm, some of these values may be reused after making
        some changes to `ElfiModel`, thus speeding up the inference significantly. For
        instance, if all the simulations are stored in rejection sampling, one can
        change the summaries and distances without having to rerun the simulator.

        Parameters
        ----------
        outputs : list, dict, optional
            List of node names to store, or a dictionary with existing stores. The
            stores are created on demand.
        name : str, optional
            Name of the pool. Used to open a saved pool from disk.
        prefix : str, optional
            Path to the directory under which `elfi.ArrayPool` will place its folder.
            Default is the relative path ./pools.

        Returns
        -------
        instance : OutputPool

        """
        if outputs is None:
            stores = {}
        elif isinstance(outputs, dict):
            stores = outputs
        else:
            stores = dict.fromkeys(outputs)

        self.stores = stores

        # Context information
        self.batch_size = None
        self.seed = None

        self.name = name
        self.prefix = prefix or _default_prefix

        if self.path and os.path.exists(self.path):
            raise ValueError("A pool with this name already exists in {}. You can use "
                             "OutputPool.open() to open it.".format(self.prefix))

    @property
    def output_names(self):
        """Return a list of stored names."""
        return list(self.stores.keys())

    @property
    def has_context(self):
        """Check if the current pool has context information."""
        return self.seed is not None and self.batch_size is not None

    def set_context(self, context):
        """Set the context of the pool.

        The pool needs to know the batch_size and the seed.

        Notes
        -----
        Also sets the name of the pool if it is not set already.

        Parameters
        ----------
        context : elfi.ComputationContext

        """
        if self.has_context:
            raise ValueError('Context is already set')

        self.batch_size = context.batch_size
        self.seed = context.seed

        if self.name is None:
            self.name = "{}_{}".format(self.__class__.__name__.lower(), self.seed)

    def get_batch(self, batch_index, output_names=None):
        """Return a batch from the stores of the pool.

        Parameters
        ----------
        batch_index : int
        output_names : list, optional
            Which outputs to include in the batch.

        Returns
        -------
        batch : dict

        """
        output_names = output_names or self.output_names
        batch = dict()
        for output in output_names:
            store = self.stores[output]
            if store is None:
                continue
            if batch_index in store:
                batch[output] = store[batch_index]
        return batch

    def add_batch(self, batch, batch_index):
        """Add the outputs from the batch to their stores."""
        for node, values in batch.items():
            if node not in self.stores:
                continue
            store = self._get_store_for(node)

            # Do not add again. The output should be the same.
            if batch_index in store:
                continue

            store[batch_index] = values

    def remove_batch(self, batch_index):
        """Remove the batch from all stores."""
        for store in self.stores.values():
            # Stores are created on demand and may still be None
            if store is not None and batch_index in store:
                del store[batch_index]

    def has_store(self, node):
        """Check if `node` is in stores."""
        return node in self.stores

    def get_store(self, node):
        """Return the store for `node`."""
        return self.stores[node]

    def add_store(self, node, store=None):
        """Add a store object for the node.

        Parameters
        ----------
        node : str
        store : dict, StoreBase, optional

        """
        if node in self.stores and self.stores[node] is not None:
            raise ValueError("Store for '{}' already exists".format(node))

        store = store if store is not None else self._make_store_for(node)
        self.stores[node] = store

    def remove_store(self, node):
        """Remove and return a store from the pool.

        Parameters
        ----------
        node : str

        Returns
        -------
        store
            The removed store

        """
        store = self.stores.pop(node)
        return store

    def _get_store_for(self, node):
        """Get or make a store."""
        if self.stores[node] is None:
            self.stores[node] = self._make_store_for(node)
        return self.stores[node]

    def _make_store_for(self, node):
        """Make a default store for a node.

        All the default stores will be created through this method.
        """
        return {}

    def __len__(self):
        """Return the largest number of batches in any of the stores."""
        largest = 0
        for store in self.stores.values():
            if store is None:
                continue
            largest = max(largest, len(store))
        return largest

    def __getitem__(self, batch_index):
        """Return the batch."""
        return self.get_batch(batch_index)

    def __setitem__(self, batch_index, batch):
        """Add `batch` into location `batch_index`."""
        return self.add_batch(batch, batch_index)

    def __contains__(self, batch_index):
        """Check if the pool contains `batch_index`."""
        return len(self) > batch_index

    def clear(self):
        """Remove all data from the stores."""
        for store in self.stores.values():
            # Stores are created on demand and may still be None
            if store is not None:
                store.clear()

    def save(self):
        """Save the pool to disk.

        This will use pickle to store the pool under self.path.
        """
        if not self.has_context:
            raise ValueError("Pool context is not set, cannot save. Please see the "
                             "set_context method.")

        os.makedirs(self.path, exist_ok=True)

        # Change the working directory so that relative paths to the pool data folder
        # can be reliably used. This allows moving and renaming of the folder.
        cwd = os.getcwd()
        os.chdir(self.path)

        # Pickle the stores separately
        for node, store in self.stores.items():
            filename = node + '.pkl'
            try:
                with open(filename, 'wb') as f:
                    pickle.dump(store, f)
            except Exception:
                raise IOError('Failed to pickle the store for node {}, please check '
                              'that it is pickleable or remove it before '
                              'saving.'.format(node))

        os.chdir(cwd)

        # Save the pool itself with the stores replaced with Nones
        stores = self.stores
        self.stores = dict.fromkeys(stores.keys())

        filename = os.path.join(self.path, self._pkl_name)
        with open(filename, 'wb') as f:
            pickle.dump(self, f)

        # Restore the original stores to the object
        self.stores = stores

    def close(self):
        """Save and close the stores that support it.

        The pool will not be usable afterwards.
        """
        self.save()
        for store in self.stores.values():
            if hasattr(store, 'close'):
                store.close()

    def flush(self):
        """Flush all data from the stores.

        If a store does not support flushing, do nothing.
        """
        for store in self.stores.values():
            if hasattr(store, 'flush'):
                store.flush()

    def delete(self):
        """Remove all persisted data from disk."""
        for store in self.stores.values():
            if hasattr(store, 'close'):
                store.close()

        if self.path is None:
            return
        elif not os.path.exists(self.path):
            return

        shutil.rmtree(self.path)

    @classmethod
    def open(cls, name, prefix=None):
        """Open a closed or saved ArrayPool from disk.

        Parameters
        ----------
        name : str
        prefix : str, optional

        Returns
        -------
        ArrayPool

        """
        prefix = prefix or _default_prefix
        path = cls._make_path(name, prefix)
        filename = os.path.join(path, cls._pkl_name)
        with open(filename, 'rb') as f:
            pool = pickle.load(f)

        # Load the stores. Change the working directory temporarily so that pickled
        # stores can find their data dependencies even if the folder has been renamed.
        cwd = os.getcwd()
        os.chdir(path)
        for node in list(pool.stores.keys()):
            filename = node + '.pkl'
            try:
                with open(filename, 'rb') as f:
                    store = pickle.load(f)
            except Exception as e:
                logger.warning('Failed to load the store for node {}. Reason: {}'
                               .format(node, str(e)))
                del pool.stores[node]
                continue
            pool.stores[node] = store
        os.chdir(cwd)

        # Update the name and prefix in case the pool folder was moved
        pool.name = name
        pool.prefix = prefix
        return pool

    @classmethod
    def _make_path(cls, name, prefix):
        return os.path.join(prefix, name)

    @property
    def path(self):
        """Return the path to the pool."""
        if self.name is None:
            return None

        return self._make_path(self.name, self.prefix)
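
# Example (illustrative, not part of the original module): a plain OutputPool keeps
# each node's outputs in an ordinary dict keyed by batch index. The node names 't1'
# and 'd' below are hypothetical.
#
#   import numpy as np
#   import elfi
#
#   pool = elfi.OutputPool(['t1', 'd'])
#   pool[0] = {'t1': np.zeros(10), 'd': np.ones(10)}  # add_batch via __setitem__
#   batch = pool[0]                                   # get_batch via __getitem__
#   pool.remove_batch(0)
#
# Note that save() requires the context (batch_size and seed), which is normally set
# by the inference method through set_context().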


class ArrayPool(OutputPool):
    """OutputPool that uses binary .npy files as default stores.

    The default store medium for output data is a NumPy binary `.npy` file. You can,
    however, also add other types of stores.

    Notes
    -----
    The default store is implemented in elfi.store.NpyStore, which uses NpyArray
    objects as stores. The NpyArray is a wrapper over a NumPy .npy binary file for
    array data and supports appending to the .npy file. It uses .npy format 2.0 files.

    """

    def _make_store_for(self, node):
        if not self.has_context:
            raise ValueError('ArrayPool has no context set')

        # Make the directory for the array pool
        os.makedirs(self.path, exist_ok=True)

        filename = os.path.join(self.path, node)
        return NpyStore(filename, self.batch_size)
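
# Example (illustrative sketch): passing an ArrayPool to rejection sampling so that
# simulations are persisted to .npy files and can be reused, e.g. after changing the
# summaries. It assumes the bundled MA2 example model; the node names ('t1', 't2',
# 'MA2', 'd') follow that example and may differ in your own model.
#
#   import elfi
#   from elfi.examples import ma2
#
#   m = ma2.get_model()
#   pool = elfi.ArrayPool(['t1', 't2', 'MA2'], name='ma2_pool')
#   rej = elfi.Rejection(m['d'], batch_size=1000, pool=pool)
#   rej.sample(100, quantile=0.1)
#   pool.save()                              # persists under ./pools/ma2_pool
#
#   pool2 = elfi.ArrayPool.open('ma2_pool')  # reopen later from disk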


class StoreBase:
    """Base class for output stores for the pools.

    A store holds the outputs of a single node in an ElfiModel. The interface is a
    subset of the Python dictionary API.

    Notes
    -----
    Any dictionary-like object will work directly as an ELFI store.

    """

    def __getitem__(self, batch_index):
        """Return a batch from location `batch_index`."""
        raise NotImplementedError

    def __setitem__(self, batch_index, data):
        """Set array to `data` at location `batch_index`."""
        raise NotImplementedError

    def __delitem__(self, batch_index):
        """Delete data from location `batch_index`."""
        raise NotImplementedError

    def __contains__(self, batch_index):
        """Check if array contains `batch_index`."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of batches in the store."""
        raise NotImplementedError

    def clear(self):
        """Remove all batches from the store."""
        raise NotImplementedError

    def close(self):
        """Close the store.

        Optional to implement. Useful for closing e.g. file streams.
        """
        pass

    def flush(self):
        """Flush the store.

        Optional to implement.
        """
        pass


# TODO: add mask for missing items. It should replace the use of `n_batches`.
#       This should make it possible to also append further than directly to the end
#       of the current n_batches index.
class ArrayStore(StoreBase):
    """Wrap an array object as an ELFI store to be used within a pool.

    This class is intended to make it easy to use objects that support array indexing
    as output stores for nodes.

    Attributes
    ----------
    array : array-like
        The array that the batches are stored to.
    batch_size : int
    n_batches : int
        How many batches are available from the underlying array.

    """

    def __init__(self, array, batch_size, n_batches=-1):
        """Initialize ArrayStore.

        Parameters
        ----------
        array
            Any array-like object supporting Python list indexing.
        batch_size : int
            Size of a batch of data.
        n_batches : int, optional
            How many batches should be made available from the array. Default is -1,
            meaning all available batches.
""" if n_batches == -1: if len(array) % batch_size != 0: logger.warning("The array length is not divisible by the batch size.") n_batches = len(array) // batch_size self.array = array self.batch_size = batch_size self.n_batches = n_batches def __getitem__(self, batch_index): """Return a batch from location `batch_index`.""" sl = self._to_slice(batch_index) return self.array[sl] def __setitem__(self, batch_index, data): """Set array to `data` at location `batch_index`.""" if batch_index > self.n_batches: raise IndexError("Appending further than to the end of the store array is " "currently not supported.") sl = self._to_slice(batch_index) if sl.stop > len(self.array): raise IndexError("There is not enough space left in the store array.") self.array[sl] = data if batch_index == self.n_batches: self.n_batches += 1 def __contains__(self, batch_index): """Check if array contains `batch_index`.""" return batch_index < self.n_batches def __delitem__(self, batch_index): """Delete data from location `batch_index`.""" if batch_index not in self: raise IndexError("Cannot remove, batch index {} is not in the array" .format(batch_index)) elif batch_index != self.n_batches - 1: raise IndexError("Removing batches from the middle of the store array is " "currently not supported.") # Move the n_batches index down if batch_index == self.n_batches - 1: self.n_batches -= 1 def __len__(self): """Return the number of batches in store.""" return self.n_batches def _to_slice(self, batch_index): """Return a slice object that covers the batch at `batch_index`.""" a = self.batch_size * batch_index return slice(a, a + self.batch_size) def clear(self): """Clear array from store.""" if hasattr(self.array, 'clear'): self.array.clear() self.n_batches = 0 def flush(self): """Flush any changes in memory to array.""" if hasattr(self.array, 'flush'): self.array.flush() def close(self): """Close array.""" if hasattr(self.array, 'close'): self.array.close() def __del__(self): """Close array.""" self.close() class NpyStore(ArrayStore): """Store data to binary .npy files. Uses the NpyArray objects as an array store. """ def __init__(self, file, batch_size, n_batches=-1): """Initialize NpyStore. Parameters ---------- file : NpyArray or str NpyArray object or path to the .npy file batch_size n_batches : int, optional How many batches to make available from the file. Default -1 indicates that all available batches. """ array = file if isinstance(file, NpyArray) else NpyArray(file) super(NpyStore, self).__init__(array, batch_size, n_batches) def __setitem__(self, batch_index, data): """Set array to `data` at location `batch_index`.""" sl = self._to_slice(batch_index) # NpyArray supports appending if batch_index == self.n_batches and sl.start == len(self.array): self.array.append(data) self.n_batches += 1 return super(NpyStore, self).__setitem__(batch_index, data) def __delitem__(self, batch_index): """Delete data from location `batch_index`.""" super(NpyStore, self).__delitem__(batch_index) sl = self._to_slice(batch_index) self.array.truncate(sl.start) def delete(self): """Delete array.""" self.array.delete() class NpyArray: """Extension to NumPy's .npy format. The NpyArray is a wrapper over NumPy .npy binary file for array data and supports appending the .npy file. Notes ----- - Supports only binary files. 


class NpyArray:
    """Extension to NumPy's .npy format.

    The NpyArray is a wrapper over a NumPy .npy binary file for array data and
    supports appending to the .npy file.

    Notes
    -----
    - Supports only binary files.
    - Supports only .npy version 2.0 files.
    - See numpy.lib.format for documentation of the .npy format.

    """

    MAX_SHAPE_LEN = 2**64

    # Version 2.0 header prefix length
    HEADER_DATA_OFFSET = 12
    HEADER_DATA_SIZE_OFFSET = 8

    def __init__(self, filename, array=None, truncate=False):
        """Initialize NpyArray.

        Parameters
        ----------
        filename : str
            File name
        array : ndarray, optional
            Initial array
        truncate : bool
            Whether to truncate the file or not

        """
        self.header_length = None
        self.itemsize = None

        # Header data fields
        self.shape = None
        self.fortran_order = False
        self.dtype = None

        # The header bytes must be prepared in advance, because there is an import in
        # `numpy.lib.format._write_array_header` (1.11.3) that fails if the program is
        # being closed on exception and would corrupt the .npy file.
        self._header_bytes_to_write = None

        if filename[-4:] != '.npy':
            filename += '.npy'
        self.filename = filename

        if array is not None:
            truncate = True

        self.fs = None
        if truncate is False and os.path.exists(self.filename):
            self.fs = open(self.filename, 'r+b')
            self._init_from_file_header()
        else:
            self.fs = open(self.filename, 'w+b')

        # NumPy memmap for the file array data
        self._memmap = None

        if array is not None:
            self.append(array)
            self.flush()

    def __getitem__(self, sl):
        """Return a slice `sl` of data."""
        return self.memmap[sl]

    def __setitem__(self, sl, value):
        """Set data at slice `sl` to `value`."""
        self.memmap[sl] = value

    def __len__(self):
        """Return the length of the array."""
        return self.shape[0] if self.shape else 0

    @property
    def size(self):
        """Return the number of items in the array."""
        return np.prod(self.shape)

    def append(self, array):
        """Append data from `array` to self."""
        if self.closed:
            raise ValueError('Array is not opened.')

        if not self.initialized:
            self.init_from_array(array)

        if array.shape[1:] != self.shape[1:]:
            raise ValueError("Appended array is of different shape.")
        elif array.dtype != self.dtype:
            raise ValueError("Appended array is of different dtype.")

        # Append the new data
        pos = self.header_length + self.size * self.itemsize
        self.fs.seek(pos)
        self.fs.write(array.tobytes('C'))
        self.shape = (self.shape[0] + len(array), ) + self.shape[1:]

        # Only prepare the header bytes; they need to be flushed to take effect
        self._prepare_header_data()

        # Invalidate the memmap
        self._memmap = None

    @property
    def memmap(self):
        """Return a NumPy memory map to the array data."""
        if not self.initialized:
            raise IndexError("NpyArray is not initialized")

        if self._memmap is None:
            order = 'F' if self.fortran_order else 'C'
            self._memmap = np.memmap(self.fs, dtype=self.dtype, shape=self.shape,
                                     offset=self.header_length, order=order)
        return self._memmap

    def _init_from_file_header(self):
        """Initialize the object from an existing file."""
        self.fs.seek(self.HEADER_DATA_SIZE_OFFSET)
        try:
            self.shape, fortran_order, self.dtype = \
                npformat.read_array_header_2_0(self.fs)
        except ValueError:
            raise ValueError('Npy file {} header is not in 2.0 format. You can make '
                             'the conversion using elfi.store.NpyArray by passing the '
                             'preloaded array as an argument.'.format(self.filename))
        self.header_length = self.fs.tell()

        if fortran_order:
            raise ValueError('Column major (Fortran-style) files are not supported. '
                             'Please translate it first to row major (C-style).')

        # Determine the itemsize
        shape = (0, ) + self.shape[1:]
        self.itemsize = np.empty(shape=shape, dtype=self.dtype).itemsize

    def init_from_array(self, array):
        """Initialize the object from an array.

        Sets the header_length so large that it is possible to append to the array.
        """
        if self.initialized:
            raise ValueError("The array has been initialized already!")

        self.shape = (0, ) + array.shape[1:]
        self.dtype = array.dtype
        self.itemsize = array.itemsize

        # Read the header data from the array and modify it to accommodate the maximum
        # length (the 1_0 header data is the same for 2_0)
        d = npformat.header_data_from_array_1_0(array)
        d['shape'] = (self.MAX_SHAPE_LEN, ) + d['shape'][1:]
        d['fortran_order'] = False

        # Write a header prefix for a very long array to make it large enough for
        # appending new data
        h_bytes = io.BytesIO()
        npformat.write_array_header_2_0(h_bytes, d)
        self.header_length = h_bytes.tell()

        # Write the header prefix to the file
        self.fs.seek(0)
        h_bytes.seek(0)
        self.fs.write(h_bytes.read(self.HEADER_DATA_OFFSET))

        # Write header data for the zero length to make it a valid file
        self._prepare_header_data()
        self._write_header_data()

    def truncate(self, length=0):
        """Truncate the array to the specified length.

        Parameters
        ----------
        length : int
            Length (=`shape[0]`) of the array to truncate to. Default 0.

        """
        if not self.initialized:
            raise ValueError('The array must be initialized before it can be '
                             'truncated. Please see init_from_array.')
        if self.closed:
            raise ValueError('The array has been closed.')

        # Reset the length
        self.shape = (length, ) + self.shape[1:]
        self._prepare_header_data()

        self.fs.seek(self.header_length + self.size * self.itemsize)
        self.fs.truncate()

        # Invalidate the memmap
        self._memmap = None

    def close(self):
        """Close the file."""
        if self.initialized:
            self._write_header_data()
            self.fs.close()
            # Invalidate the memmap
            self._memmap = None

    def clear(self):
        """Truncate the array to 0."""
        self.truncate(0)

    def delete(self):
        """Remove the file and invalidate this array."""
        if self.deleted:
            return
        name = self.fs.name
        self.close()
        os.remove(name)
        self.fs = None
        self.header_length = None
        # Invalidate the memmap
        self._memmap = None

    def flush(self):
        """Flush any changes in memory to the array."""
        self._write_header_data()
        self.fs.flush()

    def __del__(self):
        """Close the array."""
        self.close()

    def _prepare_header_data(self):
        # Make the header data
        d = {
            'shape': self.shape,
            'fortran_order': self.fortran_order,
            'descr': npformat.dtype_to_descr(self.dtype)
        }

        h_bytes = io.BytesIO()
        npformat.write_array_header_2_0(h_bytes, d)

        # Pad the end of the header
        fill_len = self.header_length - h_bytes.tell()
        if fill_len < 0:
            raise OverflowError("File {} cannot be appended. "
                                "The header is too short.".format(self.filename))
        elif fill_len > 0:
            h_bytes.write(b'\x20' * fill_len)

        h_bytes.seek(0)
        self._header_bytes_to_write = h_bytes.read()

    def _write_header_data(self):
        if not self._header_bytes_to_write:
            return

        # Rewrite the header data
        self.fs.seek(self.HEADER_DATA_OFFSET)
        h_bytes = self._header_bytes_to_write[self.HEADER_DATA_OFFSET:]
        self.fs.write(h_bytes)

        # Flag the bytes off as they are now written
        self._header_bytes_to_write = None

    @property
    def deleted(self):
        """Check whether the file has been deleted."""
        return self.fs is None

    @property
    def closed(self):
        """Check whether the file has been deleted or closed."""
        return self.deleted or self.fs.closed

    @property
    def initialized(self):
        """Check whether the file is open and the header has been initialized."""
        return (not self.closed) and (self.header_length is not None)

    def __getstate__(self):
        """Return a dictionary with a key `filename` for pickling."""
        if not self.fs.closed:
            self.flush()
        return {'filename': self.filename}

    def __setstate__(self, state):
        """Initialize with the `filename` from the dictionary `state`."""
        filename = state.pop('filename')
        basename = os.path.basename(filename)
        if os.path.exists(filename):
            self.__init__(filename)
        elif os.path.exists(basename):
            self.__init__(basename)
        else:
            self.fs = None
            raise FileNotFoundError('Could not find the file {}'.format(filename))
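
# Example (illustrative, not part of the original module): using NpyArray directly.
# The file name 'demo.npy' is hypothetical.
#
#   import numpy as np
#   from elfi.store import NpyArray
#
#   a = NpyArray('demo.npy', array=np.random.rand(5, 2))  # creates the file
#   a.append(np.random.rand(3, 2))  # grows the file in place
#   assert len(a) == 8
#   a.flush()                       # rewrites the header so the file stays valid
#   first_rows = a[0:2]             # slice access through a NumPy memmap
#   a.close()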