Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrated entry points #27

Merged
merged 6 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions {{cookiecutter.project_name}}/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,15 @@ classifiers = [
{%- if cookiecutter.backend == "hatch" %}
dynamic = ["version"]
{%- endif %}
dependencies = ["pybamm"]
dependencies = ["pybamm",]

[project.optional-dependencies]
dev = [
"pytest >=6",
"pytest-cov >=3",
"nox",
"nox[uv]",
"pre-commit",
"pytest-cookies",
santacodes marked this conversation as resolved.
Show resolved Hide resolved
]
docs = [
"sphinx",
Expand All @@ -76,6 +77,13 @@ Homepage = "{{ cookiecutter.url }}"
"Bug Tracker" = "{{ cookiecutter.url }}/issues"
Discussions = "{{ cookiecutter.url }}/discussions"
Changelog = "{{ cookiecutter.url }}/releases"

[project.entry-points."parameter_sets"]
Chen2020 = "{{ cookiecutter.__project_slug }}.parameters.input.Chen2020:get_parameter_values"

[project.entry-points."models"]
SPM = "{{ cookiecutter.__project_slug }}.models.input.SPM:SPM"

{# keep this line here for newline #}
{%- if cookiecutter.backend == "hatch" %}
[tool.hatch]
Expand All @@ -88,8 +96,8 @@ envs.default.dependencies = [
{# keep this line here for newline #}
{%- if cookiecutter.mypy %}
[tool.mypy]
python_version = "3.8"
strict = true
python_version = "3.11"
strict = false
warn_return_any = false
show_error_codes = true
enable_error_code = [
Expand All @@ -99,6 +107,9 @@ enable_error_code = [
]
disallow_untyped_defs = false
disallow_untyped_calls = false
ignore_missing_imports = true
allow_redefinition = true
disable_error_code = ["call-overload", "operator"]
{%- endif %}

[tool.coverage]
Expand All @@ -112,7 +123,7 @@ select = [
"E", "F", "W", # flake8
"B", # flake8-bugbear
"I", # isort
"ARG", # flake8-unused-arguments
#"ARG", # flake8-unused-arguments
"C4", # flake8-comprehensions
"EM", # flake8-errmsg
"ICN", # flake8-import-conventions
Expand All @@ -123,7 +134,7 @@ select = [
"PL", # pylint
"PT", # flake8-pytest-style
"PTH", # flake8-use-pathlib
"RET", # flake8-return
#"RET", # flake8-return
"RUF", # Ruff-specific
"SIM", # flake8-simplify
"T20", # flake8-print
Expand All @@ -138,6 +149,11 @@ unfixable = [
"T20", # Removes print statements
"F841", # Removes unused variables
]
ignore = [
"E741", # Ambiguous variable name
"E501", # Line too long
"PLR2004", # Magic value used in comparison
]
line-length = 100
exclude = []
flake8-unused-arguments.ignore-variadic-names = true
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,17 @@
{%- endif %}

from ._version import version as __version__
import pybamm
from .entry_point import Model, parameter_sets, models
{# keep this line here for newline #}
{%- if cookiecutter.mypy %}
__all__: tuple[str] = ("__version__",)
__all__: list[str] = [
{%- else %}
__all__ = ("__version__",)
__all__ = [
{%- endif %}
"__version__",
"pybamm",
"parameter_sets",
"Model",
"models",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""
This code is adopted from the PyBaMM project under the BSD-3-Clause

Copyright (c) 2018-2024, the PyBaMM team.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""


import importlib.metadata
import sys
import textwrap
from collections.abc import Mapping
from typing import Callable

class EntryPoint(Mapping):
"""
Dict-like interface for accessing parameter sets and models through entry points in cookiecutter template.
Access via :py:data:`pybamm_cookiecutter.parameter_sets` for parameter_sets
Access via :py:data:`pybamm_cookiecutter.Model` for Models

Examples
--------
Listing available parameter sets:
>>> import pybamm_cookiecutter
>>> list(pybamm_cookiecutter.parameter_sets)
['Chen2020', ...]
>>> list(pybamm_cookiecutter.models)
['SPM', ...]

Get the docstring for a parameter set/model:


>>> print(pybamm_cookiecutter.parameter_sets.get_docstring("Ai2020"))
<BLANKLINE>
Parameters for the Enertech cell (Ai2020), from the papers :footcite:t:`Ai2019`,
:footcite:t:`rieger2016new` and references therein.
...

>>> print(pybamm_cookiecutter.models.get_docstring("SPM"))
<BLANKLINE>
Single Particle Model (SPM) model of a lithium-ion battery, from :footcite:t:`Marquis2019`. This class differs from the :class:`pybamm.lithium_ion.SPM` model class in that it shows the whole model in a single class. This comes at the cost of flexibility in combining different physical effects, and in general the main SPM class should be used instead.
...
See also: :ref:`adding-parameter-sets`
"""

_instances = 0
def __init__(self, group):
"""Dict of entry points for parameter sets or models, lazily load entry points as"""
if not hasattr(self, 'initialized'): # Ensure __init__ is called once per instance
self.initialized = True
EntryPoint._instances += 1
self._all_entries = dict()
self.group = group
for entry_point in self.get_entries(self.group):
self._all_entries[entry_point.name] = entry_point

@staticmethod
def get_entries(group_name):
"""Wrapper for the importlib version logic"""
if sys.version_info < (3, 10): # pragma: no cover
return importlib.metadata.entry_points()[group_name]
else:
return importlib.metadata.entry_points(group=group_name)

def __new__(cls, group):
"""Ensure only two instances of entry points exist, one for parameter sets and the other for models"""
if EntryPoint._instances < 2:
cls.instance = super().__new__(cls)
return cls.instance

def __getitem__(self, key) -> dict:
return self._load_entry_point(key)()

def _load_entry_point(self, key) -> Callable:
"""Check that ``key`` is a registered ``parameter_sets`` or ``models` ,
and return the entry point for the parameter set/model, loading it needed."""
if key not in self._all_entries:
raise KeyError(f"Unknown parameter set or model: {key}")
ps = self._all_entries[key]
try:
ps = self._all_entries[key] = ps.load()
except AttributeError:
pass
return ps

def __iter__(self):
return self._all_entries.__iter__()

def __len__(self) -> int:
return len(self._all_entries)

def get_docstring(self, key):
"""Return the docstring for the ``key`` parameter set or model"""
return textwrap.dedent(self._load_entry_point(key).__doc__)

def __getattribute__(self, name):
try:
return super().__getattribute__(name)
except AttributeError as error:
raise error

#: Singleton Instance of :class:ParameterSets """
parameter_sets = EntryPoint(group="parameter_sets")

#: Singleton Instance of :class:ModelEntryPoints"""
models = EntryPoint(group="models")

def Model(model:str):
"""
Returns the loaded model object

Parameters
----------
model : str
The model name or author name of the model mentioned at the model entry point.
Returns
-------
pybamm.model
Model object of the initialised model.
Examples
--------
Listing available models:
>>> import pybamm_cookiecutter
>>> list(pybamm_cookiecutter.models)
['SPM', ...]
>>> pybamm_cookiecutter.Model('Author/Year')
<pybamm_cookiecutter.models.input.SPM.SPM object>
"""
return models[model]
Loading
Loading