Source code for lyscripts.data.filter
"""Filter a dataset according to some common criteria.
This is essentially a command line interface to building a
:py:class:`query object <lydata.querier.Q>` and applying it to the dataset.
"""
from pathlib import Path
from typing import Literal
from loguru import logger
from lydata import Q
from pydantic import Field
from pydantic_settings import CliImplicitFlag
from lyscripts.cli import assemble_main
from lyscripts.configs import BaseCLI, DataConfig
from lyscripts.data.utils import save_table_to_csv
[docs]
class FilterCLI(BaseCLI):
    """In- or exclude patients where a certain column fulfills a certain condition."""

    input: DataConfig
    # When set, keep the rows matching the query; otherwise drop them.
    include: CliImplicitFlag[bool] = Field(
        False,
        description="Include patients where the condition is met (default: exclude).",
    )
    # Normalized in ``model_post_init``: a one-element list becomes a plain
    # string, a three-element list becomes a tuple.
    column: list[str] | str = Field(
        description=(
            "The column to filter by. May be a tuple of three strings, since data "
            "has a three-level header. If it is only one string, the lydata package "
            "tries to map that to a three-level header."
        ),
    )
    operator: Literal["==", "!=", ">", "<", ">=", "<=", "in", "contains"] = Field(
        description="The operator to use for comparison.",
    )
    # String inputs are cast to ``int``/``float`` in ``model_post_init`` if possible.
    value: float | int | str = Field(description="The value to compare against.")
    output_file: Path = Field(description="The path to save the filtered dataset to.")

    def model_post_init(self, __context):
        """Normalize ``column`` and cast a string ``value`` to a number if possible.

        ``column``: a one-element list is unwrapped to its single string, a
        three-element list is converted to a tuple (one entry per header level).
        Any other list length raises a ``ValueError``.

        ``value``: if it is a string, it is cast to ``int`` if possible, otherwise
        to ``float`` if possible, otherwise it is kept as ``str``. Values that are
        already numeric are left untouched (so e.g. ``2.5`` is never truncated).

        Raises:
            ValueError: If ``column`` is a list whose length is neither 1 nor 3.
        """
        if isinstance(self.column, list):
            if len(self.column) == 1:
                self.column = self.column[0]
            elif len(self.column) == 3:
                self.column = tuple(self.column)
            else:
                raise ValueError(
                    "The column attribute must be an iterable of three strings or a "
                    f"single string, but it is {self.column}.",
                )

        if isinstance(self.value, str):
            # Try the narrower type first: every string ``int()`` accepts is also
            # accepted by ``float()``, so trying ``float`` first would make the
            # ``int`` cast unreachable and turn e.g. "5" into 5.0.
            for cast in (int, float):
                try:
                    self.value = cast(self.value)
                    break
                except ValueError:
                    continue

        return super().model_post_init(__context)

    def cli_cmd(self):
        """Execute the ``filter`` command.

        This command uses the :py:class:`~lydata.querier.Q` objects of the `lydata`_
        library to filter the dataset according to the given criteria.

        .. _lydata: https://lydata.readthedocs.io
        """
        logger.debug(self.model_dump_json(indent=2))
        data = self.input.load()

        query = Q(column=self.column, operator=self.operator, value=self.value)
        logger.debug(f"Created query object: {query}")
        mask = query.execute(data)
        # Number of rows matching the query, reported in either mode.
        num_matching = sum(mask)

        if self.include:
            filtered = data[mask]
            logger.info(f"Keeping {num_matching} of {len(data)} patients.")
        else:
            filtered = data[~mask]
            logger.info(f"Excluding {num_matching} of {len(data)} patients.")

        save_table_to_csv(file_path=self.output_file, table=filtered)
if __name__ == "__main__":
    # Assemble the ``filter`` command-line entry point for FilterCLI and run it.
    assemble_main(settings_cls=FilterCLI, prog_name="filter")()