Source code for lyscripts.compute.priors

"""Given samples drawn during an MCMC round, compute the (prior) state distributions.

This is done for each sample and for a list of specified scenarios. The computation is
cached at a location specified by the ``--cache_dir`` argument using ``joblib``.
"""

from typing import Literal

import numpy as np
from loguru import logger
from pydantic import Field
from rich import progress

from lyscripts.cli import assemble_main
from lyscripts.compute.utils import BaseComputeCLI, HDF5FileStorage, get_cached
from lyscripts.configs import (
    DistributionConfig,
    GraphConfig,
    ModelConfig,
    add_distributions,
    construct_model,
)
from lyscripts.utils import console


[docs] def compute_priors( model_config: ModelConfig, graph_config: GraphConfig, dist_configs: dict[str, DistributionConfig], samples: np.ndarray, t_stages: list[int | str], t_stages_dist: list[float], mode: Literal["HMM", "BN"] = "HMM", progress_desc: str = "Computing priors from samples", ) -> np.ndarray: """Compute prior state distributions from the ``samples`` for the ``model``. This will call the ``model`` method :py:meth:`~lymph.types.Model.state_dist` for each of the ``samples``. The prior state distributions are computed for each of the ``t_stages`` and marginalized over using the ``t_stages_dist``. """ model = construct_model(model_config, graph_config) model = add_distributions(model, dist_configs) priors = [] for sample in progress.track( sequence=samples, description=progress_desc, total=len(samples), console=console, ): model.set_params(*sample) priors.append( sum( model.state_dist(t_stage=t, mode=mode) * p for t, p in zip(t_stages, t_stages_dist, strict=False) ), ) return np.stack(priors)
[docs] class PriorsCLI(BaseComputeCLI): """Compute the prior state distributions from MCMC samples.""" priors: HDF5FileStorage = Field(description="Storage for the computed priors.")
[docs] def cli_cmd(self) -> None: """Start the ``priors`` subcommand. Given a ``graph``, ``model``, ``distributions`` over diagnosis times, and MCMC samples loaded from the ``sampling`` argument, this command computes the prior state distributions for each of the specified ``scenarios``. Precomputing these state distributions is useful, because they largely only depend on T-stage and not on the diagnosis or involvement of interest. Hence, computing the :py:mod:`~lyscripts.compute.posteriors` and :py:mod:`~lyscripts.compute.risks` can be sped up. Note that this command will use `joblib`_ to cache its computations. .. _joblib: https://joblib.readthedocs.io/ """ logger.debug(self.model_dump_json(indent=2)) global_attrs = self.model_dump(include={"model", "graph", "distributions"}) self.priors.set_attrs(attrs=global_attrs, dataset="/") samples = self.sampling.load() cached_compute_priors = get_cached(compute_priors, self.cache_dir) num_scenarios = len(self.scenarios) for i, scenario in enumerate(self.scenarios): _fields = {"t_stages", "t_stages_dist", "mode"} prior_kwargs = scenario.model_dump(include=_fields) priors = cached_compute_priors( model_config=self.model, graph_config=self.graph, dist_configs=self.distributions, samples=samples, progress_desc=f"Computing priors for scenario {i + 1}/{num_scenarios}", **prior_kwargs, ) self.priors.save(values=priors, dataset=f"{i:03d}") self.priors.set_attrs(attrs=prior_kwargs, dataset=f"{i:03d}")
if __name__ == "__main__": main = assemble_main(settings_cls=PriorsCLI, prog_name="compute priors") main()