"""Implementation of flexible MCMC sampling for lymphatic progression models.
This module provides both helpful functions for programmatically building and running
sampling pipelines, as well as a CLI interface for the most common sampling use cases.
The core is the :py:func:`run_sampling` function. It has a flexible interface and
built-in convergence detection, as well as bookkeeping for monitoring and resuming
interrupted sampling runs. It can be used both during the burn-in phase and the actual
sampling phase.
.. warning::

    We strongly recommend setting the CLI's ``--cores`` argument to ``None`` (or
    ``null`` in the YAML config file) if you are on MacOS or Windows. This is because
    we haven't yet figured out how we can safely and efficiently use the
    ``multiprocess(ing)`` library on these two platforms.
"""
from __future__ import annotations
import os
import sys
from typing import Any
from loguru import logger
from lyscripts.cli import assemble_main
# Prefer the ``multiprocess`` package and fall back to the standard library's
# ``multiprocessing`` when it is not installed. Either way the module is bound to
# the name ``mp``.
try:
    import multiprocess as mp
except ModuleNotFoundError:
    import multiprocessing as mp

# On MacOS (``sys.platform == "darwin"``) the start method is explicitly set to
# "fork" and a warning is logged about it.
# NOTE(review): presumably this has to run before any pool is created, which would
# explain why it sits in the middle of the import section - confirm.
if sys.platform == "darwin":
    logger.warning("Detected MacOS. Setting multiprocess(ing) start method to 'fork'.")
    mp.set_start_method("fork")
from pathlib import Path
import emcee
import numpy as np
import pandas as pd
from lydata.utils import ModalityConfig
from lymph.types import ParamsType
from pydantic import BaseModel, Field
from rich.progress import Progress, ProgressColumn, Task, TimeElapsedColumn
from rich.text import Text
from lyscripts.configs import (
BaseCLI,
DataConfig,
DistributionConfig,
GraphConfig,
ModelConfig,
SamplingConfig,
add_distributions,
add_modalities,
construct_model,
)
from lyscripts.utils import console, get_hdf5_backend
[docs]
class CompletedItersColumn(ProgressColumn):
    """Progress column showing the total number of completed iterations.

    The rendered number is the task's own completed count plus a fixed offset
    ``it``, i.e. the number of iterations that were done before the task started.
    """

    def __init__(self, table_column=None, it: int = 0):
        """Set up the column and remember the number ``it`` of previous iterations."""
        super().__init__(table_column)
        self.it = it

    def render(self, task: Task) -> Text:
        """Render the completed iterations of ``task``, shifted by the stored offset.

        A ``"? it"`` placeholder is shown when the task reports no completed count.
        """
        done = task.completed
        label = "? it" if done is None else f"{done + self.it} it"
        return Text(label, style="progress.data.steps")
[docs]
class ItersPerSecondColumn(ProgressColumn):
    """Progress column showing the speed in iterations per second."""

    def render(self, task: Task) -> Text:
        """Render the speed of ``task`` with two decimals.

        The task's ``finished_speed`` is preferred; if it is falsy, the current
        ``speed`` is used. A ``"? it/s"`` placeholder is shown if neither is known.
        """
        rate = task.finished_speed or task.speed
        label = "? it/s" if rate is None else f"{rate:.2f} it/s"
        return Text(label, style="progress.data.speed")
[docs]
class AcorTime(BaseModel, validate_assignment=True):
    """Holds the previous (``old``) and the latest (``new``) autocorrelation time."""

    old: float
    new: float

    def update(self, new: float) -> None:
        """Move the current estimate to ``old`` and store ``new`` as the latest one."""
        self.old, self.new = self.new, new

    @property
    def relative_diff(self) -> float:
        """Absolute change from ``old`` to ``new``, divided by ``new``."""
        change = np.abs(self.new - self.old)
        return change / self.new
[docs]
class NumAccepted(BaseModel, validate_assignment=True):
    """Holds the previous (``old``) and latest (``new``) count of accepted proposals."""

    old: int
    new: int

    def update(self, new: int) -> None:
        """Move the current count to ``old`` and store ``new`` as the latest one."""
        self.old, self.new = self.new, new

    @property
    def newly_accepted(self) -> int:
        """Number of proposals accepted between the last two updates."""
        before, after = self.old, self.new
        return after - before
# Module-level model instance that ``log_prob_fn`` evaluates. It is kept as a global
# (instead of being passed around) because of pickling, and stays ``None`` until
# ``SampleCLI.cli_cmd`` assigns the constructed model to it.
MODEL = None
[docs]
def log_prob_fn(theta: ParamsType, inverse_temp: float = 1.0) -> tuple[float, float]:
    """Evaluate the log-probability of the global ``MODEL`` for parameters ``theta``.

    The model is read from a module-level global because of pickling. The likelihood
    is scaled by ``inverse_temp``, which can be used for thermodynamic integration.

    Returns a pair: the likelihood multiplied by ``inverse_temp`` and the unscaled
    likelihood. If the likelihood is infinite, ``(-inf, -inf)`` is returned.
    """
    log_likelihood = MODEL.likelihood(given_params=theta)
    if not np.isinf(log_likelihood):
        return inverse_temp * log_likelihood, log_likelihood

    # Short-circuit instead of multiplying: 0 * inf would result in NaN.
    return -np.inf, -np.inf
[docs]
def ensure_initial_state(sampler: emcee.EnsembleSampler) -> np.ndarray:
    """Return the last stored state of ``sampler`` or a random starting state.

    The backend of the ``sampler`` is asked for its last sample. If that lookup
    raises an ``AttributeError`` (which is taken to mean that no samples are stored),
    a state is drawn uniformly from the unit cube with shape
    ``(sampler.nwalkers, sampler.ndim)`` instead.

    Only the lookup itself is guarded by the ``try``. Backends that do not expose a
    ``filename`` attribute therefore still resume from their stored samples, instead
    of silently falling back to a random state because the log message failed.
    """
    try:
        state = sampler.backend.get_last_sample()
    except AttributeError:
        state = np.random.uniform(size=(sampler.nwalkers, sampler.ndim))  # noqa: NPY002
        logger.debug(f"No stored samples found. Starting from random state {state}.")
    else:
        # Not every backend is file based, so fall back to a generic label.
        source = getattr(sampler.backend, "filename", "in-memory backend")
        logger.info(
            f"Resuming from {source} with {sampler.iteration} "
            "stored iterations.",
        )
    return state
[docs]
def ensure_history_table(file: Path | None) -> pd.DataFrame:
    """Load a stored sampling history or create an empty one.

    The history is not looked up at ``file`` itself, but at the same location with
    the suffix replaced by ``.tmp``. That is where an interrupted sampling run is
    expected to have left its history.

    If ``file`` is ``None`` or no such ``.tmp`` file exists, an empty DataFrame
    indexed by ``steps`` with the columns ``acor_times``, ``accept_fracs`` and
    ``max_log_probs`` is returned.
    """
    tmp_file = None if file is None else file.with_suffix(".tmp")

    if tmp_file is not None and tmp_file.exists():
        return pd.read_csv(tmp_file, index_col="steps")

    column_names = ["steps", "acor_times", "accept_fracs", "max_log_probs"]
    empty_history = pd.DataFrame(columns=column_names)
    return empty_history.set_index("steps")
[docs]
def update_history_table(
    history: pd.DataFrame,
    history_file: Path | None,
    iteration: int,
    acor_time: float,
    accepted_frac: float,
    max_log_prob: float,
) -> pd.DataFrame:
    """Record the bookkeeping values of the current ``iteration`` in ``history``.

    The row at index ``iteration`` is set (in place) to ``acor_time``,
    ``accepted_frac`` and ``max_log_prob``, and the last row is logged at debug
    level. If a ``history_file`` is given, the whole table is written as CSV to that
    path with the suffix replaced by ``.tmp``.

    Returns the same ``history`` object that was passed in.
    """
    row = [acor_time, accepted_frac, max_log_prob]
    history.loc[iteration] = row
    logger.debug(history.iloc[-1].to_dict())

    if history_file is None:
        return history

    history.to_csv(history_file.with_suffix(".tmp"))
    return history
[docs]
def is_converged(
    iteration: int,
    acor_time: AcorTime,
    trust_factor: float,
    relative_thresh: float,
) -> bool:
    """Decide from the autocorrelation time whether the chain has converged.

    Two conditions must hold. First, the autocorrelation estimate must be
    trustworthy, which means that ``trust_factor`` times the latest estimate is
    smaller than the current ``iteration``. Second, the relative change of the
    estimate must be below ``relative_thresh``. The second condition is only
    evaluated when the first one holds.

    More details can be found in the `emcee documentation`_.

    .. _emcee documentation: https://emcee.readthedocs.io/en/stable/tutorials/autocorr/
    """
    estimate_is_trusted = acor_time.new * trust_factor < iteration
    if not estimate_is_trusted:
        return estimate_is_trusted

    return acor_time.relative_diff < relative_thresh
def _get_columns(it: int = 0) -> list[ProgressColumn]:
    """Assemble the progress columns used for the MCMC sampling.

    These are the default columns of ``Progress``, followed by the speed column, the
    completed iterations column (offset by ``it``) and the elapsed time column.
    """
    columns = list(Progress.get_default_columns())
    columns.append(ItersPerSecondColumn())
    columns.append(CompletedItersColumn(it=it))
    columns.append(TimeElapsedColumn())
    return columns
[docs]
def run_sampling(
    sampler: emcee.EnsembleSampler,
    initial_state: np.ndarray | None = None,
    num_steps: int | None = None,
    thin_by: int = 1,
    check_interval: int = 100,
    trust_factor: float = 50.0,
    relative_thresh: float = 0.05,
    history_file: Path | None = None,
    reset_backend: bool = False,
    description: str = "Burn-in phase",
) -> None:
    """Run MCMC sampling.

    This will run the ``sampler`` either until it has stored ``num_steps`` iterations
    or - if ``num_steps`` is ``None`` - until convergence. Convergence is checked by
    :py:func:`is_converged` whenever the iteration count reaches a multiple of
    ``check_interval``. The criterion is based on a trustworthy estimate of the
    autocorrelation time. This is elaborated in the `emcee documentation`_.

    If no ``initial_state`` is given, :py:func:`ensure_initial_state` provides one.

    Some bookkeeping parameters may be stored in a ``history_file``. During sampling,
    the history is stored in a temporary file with the suffix ``.tmp``. If the sampling
    is interrupted, the history and the last state of the ``sampler`` can be recovered
    and the sampling can be continued. When sampling ends, the temporary file (if
    one was written) is moved to ``history_file``.

    One may choose to ``reset_backend``, e.g. in case the previous sampling was run
    until convergence and now one wants to store a length of the converged chain. This
    may also be thinned by a factor of ``thin_by`` (directly passed on to the
    ``sample`` method of the ``sampler``).

    .. _emcee documentation: https://emcee.readthedocs.io/en/stable/tutorials/autocorr/
    """
    # Explicit ``None`` check: the truth value of an array with more than one element
    # is ambiguous, so ``initial_state or ...`` would raise for a real start state.
    if initial_state is None:
        state = ensure_initial_state(sampler)
    else:
        state = initial_state
    history = ensure_history_table(history_file)

    # The starting state was fetched above, so it survives resetting the backend.
    if reset_backend:
        logger.debug("Resetting backend of sampler.")
        sampler.backend.reset(sampler.nwalkers, sampler.ndim)

    acor_time = AcorTime(old=np.inf, new=np.inf)
    accepted = NumAccepted(old=0, new=sampler.backend.accepted.sum())

    with Progress(*_get_columns(it=sampler.iteration), console=console) as progress:
        task = progress.add_task(description=description, total=num_steps)

        # ``num_steps=0`` must mean "no steps", not "unlimited".
        while num_steps is None or sampler.iteration < num_steps:
            start_iteration = sampler.iteration
            # Run up to the next multiple of ``check_interval`` ...
            chunk_size = check_interval - start_iteration % check_interval
            # ... but never beyond the requested total number of steps.
            if num_steps is not None:
                chunk_size = min(chunk_size, num_steps - start_iteration)

            for state in sampler.sample(  # noqa: B007, B020
                initial_state=state,
                iterations=chunk_size,
                thin_by=thin_by,
            ):
                progress.update(task, advance=1)

            acor_time.update(new=sampler.get_autocorr_time(tol=0).mean())
            accepted.update(new=sampler.backend.accepted.sum())

            # Normalize by the steps actually taken in this chunk, which may be
            # fewer than ``check_interval`` (resumed run or final partial chunk).
            steps_taken = sampler.iteration - start_iteration
            history = update_history_table(
                history=history,
                history_file=history_file,
                iteration=sampler.iteration,
                acor_time=acor_time.new,
                accepted_frac=(
                    accepted.newly_accepted / (steps_taken * sampler.nwalkers)
                ),
                max_log_prob=np.max(state.log_prob),
            )

            if num_steps is None and is_converged(
                iteration=sampler.iteration,
                acor_time=acor_time,
                trust_factor=trust_factor,
                relative_thresh=relative_thresh,
            ):
                logger.info(f"Sampling converged after {sampler.iteration} steps.")
                break

    if history_file is not None:
        tmp_file = history_file.with_suffix(".tmp")
        # There is no temporary file if not a single chunk was sampled.
        if tmp_file.exists():
            # ``replace`` also overwrites an existing history file on every platform.
            tmp_file.replace(history_file)
[docs]
class DummyPool:
    """Stand-in context manager used when no multiprocessing pool is wanted.

    Entering it yields ``None``, and leaving it neither cleans anything up nor
    suppresses exceptions.
    """

    def __enter__(self) -> None:
        """Enter the context; the ``as`` target is bound to ``None``."""
        return None

    def __exit__(self, *args) -> None:
        """Leave the context without suppressing any exception."""
        return None
[docs]
def get_pool(num_cores: int | None) -> Any | DummyPool:  # type: ignore
    """Create the pool to use for sampling.

    When ``num_cores`` is ``None``, a ``DummyPool`` is returned. Otherwise, a
    ``multiprocess(ing)`` pool with ``num_cores`` cores is created.
    """
    if num_cores is None:
        return DummyPool()

    return mp.Pool(num_cores)
[docs]
def init_sampler(settings: SampleCLI, ndim: int, pool: Any) -> emcee.EnsembleSampler:
    """Create an ``emcee.EnsembleSampler`` as specified by the ``settings``.

    The number of walkers is ``ndim`` times ``settings.sampling.walkers_per_dim``.
    Samples are stored in the HDF5 backend obtained for the configured storage file
    and dataset. The sampler evaluates :py:func:`log_prob_fn` with the configured
    inverse temperature, proposes with a mix of ``DEMove`` (weight 0.8) and
    ``DESnookerMove`` (weight 0.2), and uses the given ``pool``. Parameter names are
    taken from the global ``MODEL``.
    """
    sampling = settings.sampling
    nwalkers = ndim * sampling.walkers_per_dim

    backend = get_hdf5_backend(
        file_path=sampling.storage_file,
        dataset=sampling.dataset,
        nwalkers=nwalkers,
        ndim=ndim,
    )
    log_prob_kwargs = {"inverse_temp": sampling.inverse_temp}
    weighted_moves = [
        (emcee.moves.DEMove(), 0.8),
        (emcee.moves.DESnookerMove(), 0.2),
    ]
    param_names = list(MODEL.get_named_params().keys())

    return emcee.EnsembleSampler(
        nwalkers=nwalkers,
        ndim=ndim,
        log_prob_fn=log_prob_fn,
        kwargs=log_prob_kwargs,
        moves=weighted_moves,
        backend=backend,
        pool=pool,
        blobs_dtype=[("log_prob", np.float64)],
        parameter_names=param_names,
    )
[docs]
class SampleCLI(BaseCLI):
    """Use MCMC to infer distributions over model parameters from data."""

    # NOTE: the class docstring and the field descriptions are kept verbatim, since
    # they may be displayed as help texts by the CLI.
    graph: GraphConfig
    model: ModelConfig = ModelConfig()
    distributions: dict[str, DistributionConfig] = Field(
        default={},
        description=(
            "Mapping of model T-categories to predefined distributions over "
            "diagnose times."
        ),
    )
    modalities: dict[str, ModalityConfig] = Field(
        default={},
        description=(
            "Maps names of diagnostic modalities to their specificity/sensitivity."
        ),
    )
    data: DataConfig
    sampling: SamplingConfig

    def cli_cmd(self) -> None:
        """Start the ``sample`` subcommand.

        The global ``MODEL`` is built in four steps: it is constructed from the
        ``graph`` and ``model`` arguments, then the ``distributions`` over diagnose
        times are added, then the diagnostic ``modalities`` are added, and finally
        the patient data is loaded as specified by the ``data`` argument.

        Afterwards, an :py:class:`emcee.EnsembleSampler` is initialized (see
        :py:func:`init_sampler`) and :py:func:`run_sampling` is executed twice: once
        for the burn-in phase and once for the actual sampling phase, for which the
        backend is reset. The ``sampling`` argument provides all necessary settings
        for both phases.
        """
        # ugly, but necessary for pickling
        global MODEL

        # as recommended in https://emcee.readthedocs.io/en/stable/tutorials/parallel/#
        os.environ["OMP_NUM_THREADS"] = "1"
        logger.debug(self.model_dump_json(indent=2))

        MODEL = construct_model(self.model, self.graph)
        MODEL = add_distributions(MODEL, self.distributions)
        MODEL = add_modalities(MODEL, self.modalities)
        MODEL.load_patient_data(**self.data.get_load_kwargs())
        num_dims = MODEL.get_num_dims()

        sampling = self.sampling
        # emcee does not support numpy's new random number generator yet.
        np.random.seed(sampling.seed)  # noqa: NPY002

        with get_pool(sampling.cores) as pool:
            sampler = init_sampler(settings=self, ndim=num_dims, pool=pool)

            burnin_kwargs = {
                "description": "Burn-in phase",
                "num_steps": sampling.burnin_steps,
                "check_interval": sampling.check_interval,
                "trust_factor": sampling.trust_factor,
                "relative_thresh": sampling.relative_thresh,
                "history_file": sampling.history_file,
            }
            run_sampling(sampler=sampler, **burnin_kwargs)

            # The check interval equals the number of steps here, so the whole
            # sampling phase is requested in one chunk.
            sampling_kwargs = {
                "description": "Sampling phase",
                "num_steps": sampling.num_steps,
                "check_interval": sampling.num_steps,
                "reset_backend": True,
                "thin_by": sampling.thin_by,
            }
            run_sampling(sampler=sampler, **sampling_kwargs)
# Script entry point: assemble the ``main`` function of the ``sample`` program from
# the ``SampleCLI`` settings class and call it.
if __name__ == "__main__":
    main = assemble_main(settings_cls=SampleCLI, prog_name="sample")
    main()