Source code for lyscripts.data.generate

"""Script to generate a synthetic dataset.

The generation is done by the :py:meth:`~lymph.models.Unilateral.draw_patients` method
of
the `lymph`_ package, which is why this requires the specification of a model
via the :py:class:`~lyscripts.configs.ModelConfig` class.

.. _lymph: https://lymph-model.readthedocs.io/
"""

import numpy as np
from loguru import logger
from lydata.utils import ModalityConfig
from pydantic import Field

from lyscripts.cli import assemble_main
from lyscripts.configs import (
    BaseCLI,
    DistributionConfig,
    GraphConfig,
    ModelConfig,
    add_distributions,
    add_modalities,
    construct_model,
)
from lyscripts.data.utils import save_table_to_csv


[docs] class GenerateCLI(BaseCLI): """Settings for the command-line interface.""" graph: GraphConfig model: ModelConfig = ModelConfig() distributions: dict[str, DistributionConfig] = Field( default={}, description=( "Mapping of model T-categories to predefined distributions over " "diagnose times." ), ) t_stages_dist: dict[str, float] = Field( description=( "Specify what fraction of generated patients should come from the " "respective T-Stage." ), ) modalities: dict[str, ModalityConfig] params: dict[str, float] num_patients: int = 200 output_file: str seed: int = 42
[docs] def model_post_init(self, __context) -> None: """Make sure distribution over T-stages is normalized.""" total = 0.0 for t_stage in self.distributions: if t_stage not in self.t_stages_dist: raise ValueError(f"Missing distribution for T-stage {t_stage}.") total += self.t_stages_dist[t_stage] if not np.isclose(total, 1.0): raise ValueError("Sum of T-stage distributions must be 1.") return super().model_post_init(__context)
[docs] def cli_cmd(self) -> None: """Run the ``generate`` command. Here, the command constructs a model from the settings provided via the arguments. It then generates a synthetic dataset using the :py:meth:`~lymph.models.Unilateral.draw_patients` from the `lymph`_ package. .. _lymph: https://lymph-model.readthedocs.io/ """ logger.debug(self.model_dump_json(indent=2)) model = construct_model(self.model, self.graph) model = add_distributions(model, self.distributions) model = add_modalities(model, self.modalities) model.set_params(**self.params) logger.info(f"Set parameters: {model.get_params(as_dict=True)}") synth_data = model.draw_patients( num=self.num_patients, stage_dist=list(self.t_stages_dist.values()), seed=self.seed, ) logger.info(f"Generated synthetic data with shape {synth_data.shape}") save_table_to_csv(file_path=self.output_file, table=synth_data)
if __name__ == "__main__": main = assemble_main(settings_cls=GenerateCLI, prog_name="data generate") main()