From 9d8debf2955048fab1712f453373615cff7e309f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rn=20Wei=C3=9Fenborn?= Date: Fri, 15 Oct 2021 16:58:51 +0200 Subject: [PATCH] Added model generators --- glotaran/project/generators/__init__.py | 1 + glotaran/project/generators/generator.py | 147 ++++++++++++++++++ .../test/test_genenerate_decay_model.py | 40 +++++ 3 files changed, 188 insertions(+) create mode 100644 glotaran/project/generators/__init__.py create mode 100644 glotaran/project/generators/generator.py create mode 100644 glotaran/project/generators/test/test_genenerate_decay_model.py diff --git a/glotaran/project/generators/__init__.py b/glotaran/project/generators/__init__.py new file mode 100644 index 000000000..0ccbd15f1 --- /dev/null +++ b/glotaran/project/generators/__init__.py @@ -0,0 +1 @@ +"""The glotaran generator package.""" diff --git a/glotaran/project/generators/generator.py b/glotaran/project/generators/generator.py new file mode 100644 index 000000000..81c8abb8a --- /dev/null +++ b/glotaran/project/generators/generator.py @@ -0,0 +1,147 @@ +"""The glotaran generator module.""" +from __future__ import annotations + +from typing import Callable + +from yaml import dump + +from glotaran.model import Model + + +def _generate_decay_model(nr_compartments: int, irf: bool, decay_type: str) -> dict: + """Generate a decay model dictionary. + + Parameters + ---------- + nr_compartments : int + The number of compartments. + irf : bool + Whether to add a gaussian irf. + decay_type : str + The dype of the decay + + Returns + ------- + dict : + The generated model dictionary. + """ + compartments = [f"species_{i+1}" for i in range(nr_compartments)] + rates = [f"decay.species_{i+1}" for i in range(nr_compartments)] + model = { + "megacomplex": { + f"megacomplex_{decay_type}_decay": { + "type": f"decay-{decay_type}", + "compartments": compartments, + "rates": rates, + }, + }, + "dataset": {"dataset_1": {"megacomplex": [f"megacomplex_{decay_type}_decay"]}}, + } + if irf: + model["dataset"]["dataset_1"]["irf"] = "gaussian_irf" # type:ignore[index] + model["irf"] = { + "gaussian_irf": {"type": "gaussian", "center": "irf.center", "width": "irf.width"}, + } + return model + + +def generate_parallel_model(nr_compartments: int = 1, irf: bool = False) -> dict: + """Generate a parallel decay model dictionary. + + Parameters + ---------- + nr_compartments : int + The number of compartments. + irf : bool + Whether to add a gaussian irf. + + Returns + ------- + dict : + The generated model dictionary. + """ + return _generate_decay_model(nr_compartments, irf, "parallel") + + +def generate_sequential_model(nr_compartments: int = 1, irf: bool = False) -> dict: + """Generate a sequential decay model dictionary. + + Parameters + ---------- + nr_compartments : int + The number of compartments. + irf : bool + Whether to add a gaussian irf. + + Returns + ------- + dict : + The generated model dictionary. + """ + return _generate_decay_model(nr_compartments, irf, "sequential") + + +generators: dict[str, Callable] = { + "decay-parallel": generate_parallel_model, + "decay-sequential": generate_sequential_model, +} + +available_generators: list[str] = list(generators.keys()) + + +def generate_model(generator: str, **generator_arguments: dict) -> Model: + """Generate a model. + + Parameters + ---------- + generator : str + The generator to use. + generator_arguments : dict + Arguments for the generator. + + Returns + ------- + Model + The generated model + + Raises + ------ + ValueError + Raised when an unknown generator is specified. + """ + if generator not in generators: + raise ValueError( + f"Unknown model generator '{generator}'. " + f"Known generators are: {list(generators.keys())}" + ) + model = generators[generator](**generator_arguments) + return Model.from_dict(model) + + +def generate_model_yml(generator: str, **generator_arguments: dict) -> str: + """Generate a model as yml string. + + Parameters + ---------- + generator : str + The generator to use. + generator_arguments : dict + Arguments for the generator. + + Returns + ------- + str + The generated model yml string. + + Raises + ------ + ValueError + Raised when an unknown generator is specified. + """ + if generator not in generators: + raise ValueError( + f"Unknown model generator '{generator}'. " + f"Known generators are: {list(generators.keys())}" + ) + model = generators[generator](**generator_arguments) + return dump(model) diff --git a/glotaran/project/generators/test/test_genenerate_decay_model.py b/glotaran/project/generators/test/test_genenerate_decay_model.py new file mode 100644 index 000000000..670ef5647 --- /dev/null +++ b/glotaran/project/generators/test/test_genenerate_decay_model.py @@ -0,0 +1,40 @@ +import pytest + +from glotaran.project.generators.generator import generate_model + + +@pytest.mark.parametrize("megacomplex_type", ["parallel", "sequential"]) +@pytest.mark.parametrize("irf", [True, False]) +def test_generate_parallel_model(megacomplex_type: str, irf: bool): + nr_compartments = 5 + model = generate_model( + f"decay-{megacomplex_type}", + **{"nr_compartments": nr_compartments, "irf": irf}, # type:ignore[arg-type] + ) + print(model) # T001 + + assert ( + f"megacomplex_{megacomplex_type}_decay" in model.megacomplex # type:ignore[attr-defined] + ) + megacomplex = model.megacomplex[ # type:ignore[attr-defined] + f"megacomplex_{megacomplex_type}_decay" + ] + assert megacomplex.type == f"decay-{megacomplex_type}" + assert megacomplex.compartments == [f"species_{i+1}" for i in range(nr_compartments)] + assert [r.full_label for r in megacomplex.rates] == [ + f"decay.species_{i+1}" for i in range(nr_compartments) + ] + + assert "dataset_1" in model.dataset # type:ignore[attr-defined] + dataset = model.dataset["dataset_1"] # type:ignore[attr-defined] + assert dataset.megacomplex == [f"megacomplex_{megacomplex_type}_decay"] + if irf: + assert dataset.irf == "gaussian_irf" + assert "gaussian_irf" in model.irf # type:ignore[attr-defined] + assert ( + model.irf["gaussian_irf"].center.full_label # type:ignore[attr-defined] + == "irf.center" + ) + assert ( + model.irf["gaussian_irf"].width.full_label == "irf.width" # type:ignore[attr-defined] + )