From 02bd211b971bdfcf7f786fa039907a79b15d8366 Mon Sep 17 00:00:00 2001 From: Varchas Gopalaswamy <2359219+varchasgopalaswamy@users.noreply.github.com> Date: Fri, 7 Jun 2024 18:45:14 -0400 Subject: [PATCH] Fixes for CI (#5) * ci should work now * fix a pyright error * ci uses pre-commit for checks * fixed requirements * fix dependency * add all the optional packaged to CI --- .github/workflows/python-test.yml | 9 +- .pre-commit-config.yaml | 53 +++-- hooks/fix_line_endings.py | 56 +++++ hooks/generate_init.py | 48 ++++ prepper/__init__.py | 52 ++++- prepper/__init__.pyi | 80 +++++++ prepper/caching.py | 56 ++--- prepper/enums.py | 1 - prepper/exceptions.py | 7 + prepper/exportable.py | 142 +++++------- prepper/io_handlers.py | 242 +++++++++----------- prepper/tests/__init__.py | 27 +++ prepper/tests/__init__.pyi | 50 ++++ prepper/tests/test_IO.py | 11 +- prepper/tests/test_decorators.py | 20 +- prepper/utils.py | 51 ++--- pyproject.toml | 58 +---- ruff.toml | 80 +++++++ typings/periodictable/__init__.pyi | 7 - typings/periodictable/activation.pyi | 14 +- typings/periodictable/core.pyi | 62 ++--- typings/periodictable/covalent_radius.pyi | 1 - typings/periodictable/cromermann.pyi | 5 - typings/periodictable/crystal_structure.pyi | 1 - typings/periodictable/density.pyi | 3 - typings/periodictable/fasta.pyi | 6 - typings/periodictable/formulas.pyi | 35 +-- typings/periodictable/magnetic_ff.pyi | 13 +- typings/periodictable/mass.pyi | 3 - typings/periodictable/nsf.pyi | 43 +--- typings/periodictable/nsf_resonances.pyi | 2 - typings/periodictable/plot.pyi | 1 - typings/periodictable/util.pyi | 8 +- typings/periodictable/xsf.pyi | 17 +- 34 files changed, 688 insertions(+), 576 deletions(-) create mode 100755 hooks/fix_line_endings.py create mode 100755 hooks/generate_init.py create mode 100644 prepper/__init__.pyi create mode 100644 prepper/exceptions.py create mode 100644 prepper/tests/__init__.pyi create mode 100644 ruff.toml diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index 1dfc228..4c08e64 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -19,11 +19,12 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install pylint pytest pytest-cov hypothesis - pip install . - - name: Lint with pylint + pip install pre-commit pytest pytest-cov hypothesis + pre-commit install + pip install .[CI] + - name: Run pre-commit checks run: | - pylint --fail-under=8 prepper + pre-commit run --all-files - name: Test with pytest run: | pytest --cov=prepper --cov-report=xml --cov-report=html diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4826f37..8e58de1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,32 +1,37 @@ repos: -- repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.275 +- repo: local hooks: - - id: ruff - args: [ --fix, --exit-non-zero-on-fix ] - -- repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - name: isort (python) - -- repo: https://github.com/psf/black.git - rev: 23.1.0 - hooks: - - id: black - language_version: python3.9 + + - id: generate-init + name: Generates __init__.py files + language: python + entry: python hooks/generate_init.py + always_run: true + require_serial: true + additional_dependencies: ["mkinit", "ruff"] + - id: fix-line-endings + name: Convert CRLF/CR endings to LF + language: python + require_serial: true + entry: python hooks/fix_line_endings.py + types: ["text"] + - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.6.0 hooks: - - id: end-of-file-fixer - - id: fix-encoding-pragma - - id: trailing-whitespace - - id: check-case-conflict - id: check-executables-have-shebangs - - id: check-merge-conflict - id: check-symlinks - - id: debug-statements - - id: mixed-line-ending + +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.4.6 + hooks: + #Run the formatter. + - id: ruff-format + types_or: [ python, pyi, jupyter ] + #Run the linter. + - id: ruff + types_or: [ python, pyi, jupyter ] + args: [ --fix, --exit-non-zero-on-fix ] \ No newline at end of file diff --git a/hooks/fix_line_endings.py b/hooks/fix_line_endings.py new file mode 100755 index 0000000..22b2dc3 --- /dev/null +++ b/hooks/fix_line_endings.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +from pathlib import Path +import sys + + +def main(source_path: str) -> bool: + """ + Main entry point of the script. + + Parameters + ---------- + function : Callable + Function to execute for the specified validation type. + source_path : str + Source path representing path to a file/directory. + output_format : str + Output format of the error message. + file_extensions_to_check : str + Comma separated values of what file extensions to check. + excluded_file_paths : str + Comma separated values of what file paths to exclude during the check. + + Returns + ------- + bool + True if found any patterns are found related to the given function. + + Raises + ------ + ValueError + If the `source_path` is not pointing to existing file/directory. + """ + + for file_path in source_path: + with Path(file_path).open("r", encoding="utf-8") as file_obj: + file_text = file_obj.read() + + invalid_ending = "\r" in file_text + if invalid_ending: + with Path(file_path).open("w", encoding="utf-8") as file_obj: + file_obj.write(file_text) + + return invalid_ending + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="CR/CRLF -> LF converter.") + + parser.add_argument("paths", nargs="*", help="Source paths of files to check.") + + args = parser.parse_args() + + sys.exit(main(source_path=args.paths)) diff --git a/hooks/generate_init.py b/hooks/generate_init.py new file mode 100755 index 0000000..dae8cbb --- /dev/null +++ b/hooks/generate_init.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import subprocess +import sys + +from mkinit import static_mkinit + + +def ruff_format(): ... +def make_init(): + options = { + "with_attrs": True, + "with_mods": True, + "with_all": True, + "relative": True, + "lazy_import": False, + "lazy_loader": True, + "lazy_loader_typed": True, + "lazy_boilerplate": None, + "use_black": False, + } + + static_mkinit.autogen_init( + "prepper", + respect_all=True, + options=options, + dry=False, + diff=False, + recursive=True, + ) + subprocess.run(["ruff", "format"]) + + +if __name__ == "__main__": + make_init() + + changed_files1 = subprocess.run( + ["git", "diff", "--name-only", "--diff-filter=ACM", "--exit-code"] + ) + changed_files2 = subprocess.run( + ["git", "ls-files", "--exclude-standard", "--others"], capture_output=True + ) + retcode = changed_files1.returncode + changed_files2.returncode + retcode += len(changed_files2.stderr) + retcode += len(changed_files2.stdout) + + sys.exit(retcode) diff --git a/prepper/__init__.py b/prepper/__init__.py index 5bf7565..235e190 100644 --- a/prepper/__init__.py +++ b/prepper/__init__.py @@ -1,23 +1,53 @@ -# -*- coding: utf-8 -*- from __future__ import annotations -import h5py +__protected__ = ["utils", "io_handlers", "enums", "utils"] + + +# +import lazy_loader + + +__getattr__, __dir__, __all__ = lazy_loader.attach_stub(__name__, __file__) __all__ = [ + "ExportableClassMixin", + "H5StoreException", + "NATIVE_DTYPES", + "NotASaveableClass", + "SimpleSaveableClass", + "SimpleSaveableClass2", + "break_key", "cached_property", + "caching", + "ensure_required_attributes", + "enums", + "exceptions", + "exportable", + "io_handlers", "local_cache", - "ExportableClassMixin", + "make_cache_name", + "req_class_attrs", + "req_dataset_attrs", + "req_none_attrs", + "roundtrip", "saveable_class", + "test_IO", + "test_cached_property", + "test_decorators", + "test_local_cache", + "test_saveable_class", + "test_with_float_list", + "test_with_floats", + "test_with_heterogenous_list", + "test_with_int_list", + "test_with_ints", + "test_with_str_list", + "tests", + "utils", ] +# +import h5py # Make h5py write groups in order h5py.get_config().track_order = True - - -class H5StoreException(Exception): - "An exception for when the HDF5 store does not meet spec" - - -from .caching import cached_property, local_cache # noqa E402 -from .exportable import ExportableClassMixin, saveable_class # noqa E402 diff --git a/prepper/__init__.pyi b/prepper/__init__.pyi new file mode 100644 index 0000000..809c681 --- /dev/null +++ b/prepper/__init__.pyi @@ -0,0 +1,80 @@ +from . import caching +from . import enums +from . import exceptions +from . import exportable +from . import io_handlers +from . import tests +from . import utils + +from .caching import ( + break_key, + cached_property, + local_cache, + make_cache_name, +) +from .exceptions import ( + H5StoreException, +) +from .exportable import ( + ExportableClassMixin, + saveable_class, +) +from .tests import ( + NATIVE_DTYPES, + NotASaveableClass, + SimpleSaveableClass, + SimpleSaveableClass2, + ensure_required_attributes, + req_class_attrs, + req_dataset_attrs, + req_none_attrs, + roundtrip, + test_IO, + test_cached_property, + test_decorators, + test_local_cache, + test_saveable_class, + test_with_float_list, + test_with_floats, + test_with_heterogenous_list, + test_with_int_list, + test_with_ints, + test_with_str_list, +) + +__all__ = [ + "ExportableClassMixin", + "H5StoreException", + "NATIVE_DTYPES", + "NotASaveableClass", + "SimpleSaveableClass", + "SimpleSaveableClass2", + "break_key", + "cached_property", + "caching", + "ensure_required_attributes", + "enums", + "exceptions", + "exportable", + "io_handlers", + "local_cache", + "make_cache_name", + "req_class_attrs", + "req_dataset_attrs", + "req_none_attrs", + "roundtrip", + "saveable_class", + "test_IO", + "test_cached_property", + "test_decorators", + "test_local_cache", + "test_saveable_class", + "test_with_float_list", + "test_with_floats", + "test_with_heterogenous_list", + "test_with_int_list", + "test_with_ints", + "test_with_str_list", + "tests", + "utils", +] diff --git a/prepper/caching.py b/prepper/caching.py index 0225257..5ea1486 100644 --- a/prepper/caching.py +++ b/prepper/caching.py @@ -1,24 +1,20 @@ -# -*- coding: utf-8 -*- from __future__ import annotations -import functools from collections.abc import Callable +import functools from functools import update_wrapper, wraps from typing import ( Any, - Dict, + Concatenate, Generic, - List, - overload, - Tuple, - Type, + Self, TypeVar, - Union, + overload, ) from joblib import hash as joblib_hash from numpy import ndarray -from typing_extensions import Concatenate, ParamSpec, Self +from typing_extensions import ParamSpec __all__ = [ "break_key", @@ -62,7 +58,7 @@ def __hash__(self): return self.hashvalue -def break_key(key: Any) -> Tuple[List[Any], Dict[str, Any]]: +def break_key(key: Any) -> tuple[list[Any], dict[str, Any]]: "Breaks a function cache key into the args and kwargs" args = [] kwargs = {} @@ -97,10 +93,7 @@ def _make_key(args, kwds): if kwds: key += (KWD_SENTINEL,) for k, v in kwds.items(): - if isinstance(v, ndarray): - v2 = tuple(v.tolist()) - else: - v2 = v + v2 = tuple(v.tolist()) if isinstance(v, ndarray) else v key += (k, v2) return _HashedSeq(key) @@ -145,16 +138,12 @@ def __init__(self, func: Callable[[Instance], Value]): self.func = func @overload - def __get__(self, instance: Instance, owner: object) -> Value: - ... + def __get__(self, instance: Instance, owner: object) -> Value: ... @overload - def __get__(self, instance: None, owner: object) -> Self: - ... + def __get__(self, instance: None, owner: object) -> Self: ... - def __get__( - self, instance: Union[Instance, None], owner: object - ) -> Union[Self, Value]: + def __get__(self, instance: Instance | None, owner: object) -> Self | Value: if instance is None: return self @@ -177,30 +166,24 @@ class local_cache(Generic[Instance, Arguments, Value]): user_func: Callable[Concatenate[Instance, Arguments], Value] - def __init__( - self, wrapped_func: Callable[Concatenate[Instance, Arguments], Value] - ): + def __init__(self, wrapped_func: Callable[Concatenate[Instance, Arguments], Value]): self.user_func = wrapped_func @overload def __get__( - self, instance: Instance, owner: Type[Instance] - ) -> Callable[Arguments, Value]: - ... + self, instance: Instance, owner: type[Instance] + ) -> Callable[Arguments, Value]: ... @overload def __get__( - self, instance: None, owner: Type[Instance] - ) -> Callable[Concatenate[Instance, Arguments], Value]: - ... + self, instance: None, owner: type[Instance] + ) -> Callable[Concatenate[Instance, Arguments], Value]: ... - def __get__(self, instance: Instance | None, owner: Type[Instance]): + def __get__(self, instance: Instance | None, owner: type[Instance]): if instance is None: return self.user_func else: - partial_function: Callable[ - Arguments, Value - ] = functools.update_wrapper( + partial_function: Callable[Arguments, Value] = functools.update_wrapper( functools.partial(_cache_wrapper(self.user_func), instance), self.user_func, ) # type: ignore @@ -215,6 +198,5 @@ def __set__(self, obj, value): if isinstance(key, _HashedSeq): obj.__dict__[fname][key] = return_value else: - raise ValueError( - f"Can't assign {value} to the cache of {self.user_func.__qualname__}!" - ) + msg = f"Can't assign {value} to the cache of {self.user_func.__qualname__}!" + raise TypeError(msg) diff --git a/prepper/enums.py b/prepper/enums.py index 1deec4d..17ba955 100644 --- a/prepper/enums.py +++ b/prepper/enums.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from __future__ import annotations from typing import TYPE_CHECKING diff --git a/prepper/exceptions.py b/prepper/exceptions.py new file mode 100644 index 0000000..92221ee --- /dev/null +++ b/prepper/exceptions.py @@ -0,0 +1,7 @@ +from __future__ import annotations + +__all__ = ["H5StoreException"] + + +class H5StoreException(Exception): + "An exception for when the HDF5 store does not meet spec" diff --git a/prepper/exportable.py b/prepper/exportable.py index ab63042..56d50c8 100644 --- a/prepper/exportable.py +++ b/prepper/exportable.py @@ -1,32 +1,33 @@ -# -*- coding: utf-8 -*- from __future__ import annotations +from abc import ABCMeta +from collections.abc import Callable import copy import datetime import importlib.metadata import inspect -import os +from inspect import Parameter, signature +from pathlib import Path import shutil import tempfile import traceback +from typing import Any, TypeVar import uuid import warnings -from abc import ABCMeta -from inspect import Parameter, signature -from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar import h5py import joblib import loguru import numpy as np -from prepper import cached_property, H5StoreException +from prepper import H5StoreException, cached_property from prepper.caching import break_key, make_cache_name from prepper.enums import H5StoreTypes from prepper.utils import check_equality __all__ = [ "ExportableClassMixin", + "saveable_class", ] @@ -36,10 +37,10 @@ class ExportableClassMixin(metaclass=ABCMeta): """ - _constructor_args: Dict[str, Any] + _constructor_args: dict[str, Any] api_version: float - _exportable_attributes: List[str] - _exportable_functions: List[str] + _exportable_attributes: list[str] + _exportable_functions: list[str] def __copy__(self): return self.__class__(**self._constructor_args) @@ -56,18 +57,17 @@ def __new__(cls, *args, **kwargs): instance._constructor_args = {} try: instance.__init__(*args, **kwargs) - except Exception as e: + except Exception: loguru.logger.error(f"Failed to initialize {cls.__name__}!") - raise e + raise sig = signature(instance.__init__) bound_args = sig.bind(*args, **kwargs) # bound_args.apply_defaults() for _, (key, value) in enumerate(bound_args.arguments.items()): if sig.parameters[key].kind == Parameter.POSITIONAL_ONLY: - raise ValueError( - "Cannot save arguments that are positional only!" - ) + msg = "Cannot save arguments that are positional only!" + raise ValueError(msg) if sig.parameters[key].kind == Parameter.VAR_KEYWORD: for kwkey, kwvalue in value.items(): instance._constructor_args[kwkey] = kwvalue @@ -126,9 +126,11 @@ def initialized_from_file(self): return getattr(self, "_initialized_from_file", False) @classmethod - def from_hdf5(cls, path, group="/"): - if not os.path.exists(path): - raise FileNotFoundError(f"Could not find file {path}") + def from_hdf5(cls, path: Path, group="/"): + path = Path(path) + if not Path.exists(path): + msg = f"Could not find file {path}" + raise FileNotFoundError(msg) with h5py.File(path, mode="r", track_order=True) as hdf5_file: if group not in hdf5_file: @@ -142,7 +144,7 @@ def from_hdf5(cls, path, group="/"): path, f"{group}/{cls.__name__}" ) if init_kw_type != H5StoreTypes.ClassConstructor: - raise H5StoreException() + raise H5StoreException else: init_kws = {} instance = cls(**init_kws) @@ -161,9 +163,8 @@ def _read_hdf5_contents(self, file, group): base = hdf5_file[group] entry_type = ExportableClassMixin._get_group_type(base) if entry_type != H5StoreTypes.PythonClass: - raise ValueError( - f"_read_hdf5_contents was called on a HDF5 group {group} that is not a python class spec!" - ) + msg = f"_read_hdf5_contents was called on a HDF5 group {group} that is not a python class spec!" + raise ValueError(msg) try: class_name = read_h5_attr(base, "class") # type: ignore @@ -194,52 +195,48 @@ def _read_hdf5_contents(self, file, group): self._initialized_from_file = True @staticmethod - def _load_h5_entry(file: str, group: str) -> Tuple[H5StoreTypes, Any]: + def _load_h5_entry(file: Path, group: str) -> tuple[H5StoreTypes, Any]: from prepper.io_handlers import load_custom_h5_type with h5py.File(file, mode="r", track_order=True) as hdf5_file: if group not in hdf5_file: - raise FileNotFoundError( - f"Could not find {group} in the cached result!" - ) + msg = f"Could not find {group} in the cached result!" + raise FileNotFoundError(msg) entry = hdf5_file[group] entry_type = ExportableClassMixin._get_group_type(entry) return entry_type, load_custom_h5_type(file, group, entry_type) - def to_hdf5(self, path): + def to_hdf5(self, path: Path): """ Save this object to an h5 file """ - - if os.path.exists(path): + path = Path(path) + if Path.exists(path): loguru.logger.warning(f"HDF5 file {path} exists... overwriting.") - if not os.path.exists(os.path.dirname(path)): - raise FileNotFoundError( - f"The parent directory for {path} does not exist!" - ) + if not Path.exists(Path(path).parent): + msg = f"The parent directory for {path} does not exist!" + raise FileNotFoundError(msg) with tempfile.TemporaryDirectory() as temp_dir: - temp_file = os.path.join(temp_dir, str(uuid.uuid1())) + temp_file = Path(temp_dir).joinpath(str(uuid.uuid1())) file = h5py.File(temp_file, "w") file.close() with warnings.catch_warnings(): warnings.simplefilter("ignore", category=FutureWarning) - warnings.simplefilter( - "ignore", category=np.VisibleDeprecationWarning - ) + warnings.simplefilter("ignore", category=np.VisibleDeprecationWarning) self._write_hdf5_contents( - temp_file, group="/", existing_groups={} + Path(temp_file), group="/", existing_groups={} ) shutil.copyfile(src=temp_file, dst=path) def _write_hdf5_contents( self, - file: str, + file: Path, group: str, - existing_groups: Dict[str, Any], + existing_groups: dict[str, Any], attributes=None, ): from prepper.io_handlers import dump_class_constructor, write_h5_attr @@ -273,9 +270,7 @@ def _write_hdf5_contents( # Store the class constructor arguments, if applicable if len(self._constructor_args) > 0: - existing_groups = dump_class_constructor( - file, group, self, existing_groups - ) + existing_groups = dump_class_constructor(file, group, self, existing_groups) # Save the attributes, if populated for symbol in self._exportable_attributes: @@ -314,9 +309,7 @@ def _write_hdf5_contents( for idx, (key, value) in enumerate(function_cache.items()): pad_number = str(idx + 1).zfill(pad_digits) key_args, key_kwargs = break_key(key) - call_group_name = ( - f"{group}/{fname}/{fname}_call{pad_number}" - ) + call_group_name = f"{group}/{fname}/{fname}_call{pad_number}" arg_group_name = f"{call_group_name}/args" kwarg_group_name = f"{call_group_name}/kwargs" value_group_name = f"{call_group_name}/value" @@ -346,11 +339,11 @@ def _write_hdf5_contents( @staticmethod def _dump_h5_entry( - file: str, + file: Path, entry_name: str, value: Any, - existing_groups: Dict[str, Any], - attributes: Dict[str, Any] | None = None, + existing_groups: dict[str, Any], + attributes: dict[str, Any] | None = None, ): from prepper.io_handlers import dump_custom_h5_type, write_h5_attr @@ -374,9 +367,7 @@ def _dump_h5_entry( try: write_h5_attr(new_entry, k, v) # type: ignore except H5StoreException: - msg = ( - f"Failed to write attribute {k} to group {new_entry}!" - ) + msg = f"Failed to write attribute {k} to group {new_entry}!" loguru.logger.error(msg) return existing_groups @@ -419,25 +410,24 @@ def _get_bound_name(file, groupname): def saveable_class( api_version: float, - attributes: List[str] | None = None, - functions: List[str] | None = None, -) -> Callable[[Type[E]], Type[E]]: + attributes: list[str] | None = None, + functions: list[str] | None = None, +) -> Callable[[type[E]], type[E]]: if attributes is None: attributes = [] if functions is None: functions = [] - def decorator(cls: Type[E]) -> Type[E]: + def decorator(cls: type[E]) -> type[E]: if not issubclass(cls, ExportableClassMixin): - raise ValueError( - "Only subclasses of ExportableClassMixin can be decorated with saveable_class" - ) + msg = "Only subclasses of ExportableClassMixin can be decorated with saveable_class" + raise TypeError(msg) - attribute_names: List[str] = [] - function_names: List[str] = [] + attribute_names: list[str] = [] + function_names: list[str] = [] - exportable_functions: List[str] = [] - exportable_attributes: List[str] = [] + exportable_functions: list[str] = [] + exportable_attributes: list[str] = [] for parent in reversed(inspect.getmro(cls)): if hasattr(parent, "_exportable_attributes"): @@ -448,29 +438,23 @@ def decorator(cls: Type[E]) -> Type[E]: for fcn in parent._exportable_functions: # type: ignore bound_class, symbol = fcn.split(".") function_names.append(symbol) - for attr in attributes: - attribute_names.append(attr) - - for fcn in functions: - function_names.append(fcn) + attribute_names += attributes + function_names += functions for symbol in attribute_names: if not hasattr(cls, symbol): - raise ValueError( - f"{cls} and its parents does not have property/attribute {symbol} at runtime. Dynamically added attributes are not supported." - ) + msg = f"{cls} and its parents does not have property/attribute {symbol} at runtime. Dynamically added attributes are not supported." + raise ValueError(msg) try: exportable_attributes.append(getattr(cls, symbol).__qualname__) except AttributeError: - raise ValueError( - f"{cls}.{symbol} is a property. Saving properties is not supported as they dont have __dict__ entries. Make {symbol} a cached property instead." - ) from None + msg = f"{cls}.{symbol} is a property. Saving properties is not supported as they dont have __dict__ entries. Make {symbol} a cached property instead." + raise ValueError(msg) from None for symbol in function_names: if not hasattr(cls, symbol): - raise ValueError( - f"{cls} and its parents does not have function {symbol}" - ) + msg = f"{cls} and its parents does not have function {symbol}" + raise ValueError(msg) exportable_functions.append(getattr(cls, symbol).__qualname__) cls._exportable_functions = list(set(exportable_functions)) @@ -500,13 +484,13 @@ def test_string(self): @cached_property def test_array(self): - return np.random.random(size=(1000, 1000)) + return np.random.random(size=(1000, 1000)) # noqa: NPY002 test_instance = SimpleSaveableClass() _ = test_instance.test_array _ = test_instance.test_string with tempfile.NamedTemporaryFile() as tmp: - test_instance.to_hdf5(tmp.name) + test_instance.to_hdf5(Path(tmp.name)) - new_instanace = SimpleSaveableClass.from_hdf5(tmp.name) + new_instanace = SimpleSaveableClass.from_hdf5(Path(tmp.name)) assert test_instance.test_string == new_instanace.test_string diff --git a/prepper/io_handlers.py b/prepper/io_handlers.py index b6d1213..514971c 100644 --- a/prepper/io_handlers.py +++ b/prepper/io_handlers.py @@ -1,23 +1,24 @@ -# -*- coding: utf-8 -*- from __future__ import annotations +from collections.abc import Iterable, Sequence +import contextlib import datetime +from enum import Enum import importlib import numbers +from pathlib import Path import re import tempfile import traceback -from collections.abc import Iterable -from enum import Enum -from typing import Any, Dict, List, Sequence, Tuple, Type, TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Any import h5py import loguru import numpy as np -from prepper import H5StoreException from prepper.caching import _HashedSeq, _make_key from prepper.enums import H5StoreTypes +from prepper.exceptions import H5StoreException from prepper.exportable import ExportableClassMixin from prepper.utils import check_equality, get_element_from_number_and_weight @@ -32,7 +33,7 @@ pt = None try: - from auto_uncertainties import nominal_values, Uncertainty + from auto_uncertainties import Uncertainty, nominal_values except ImportError: Uncertainty = None nominal_values = None @@ -72,7 +73,7 @@ _EMPTY_TYPE_SENTINEL = "__python_Empty_sentinel__" PYTHON_BASIC_TYPES = (int, float, str) -TYPES_TO_SKIP_DUPLICATE_CHECKING = PYTHON_BASIC_TYPES + (bool,) +TYPES_TO_SKIP_DUPLICATE_CHECKING = (*PYTHON_BASIC_TYPES, bool) if az is not None: TYPES_TO_SKIP_DUPLICATE_CHECKING += (az.InferenceData,) @@ -95,7 +96,7 @@ def get_hdf5_compression(): return HDF5_COMPRESSION -def set_hdf5_compression(compression: Dict[str, Any]): +def set_hdf5_compression(compression: dict[str, Any]): global HDF5_COMPRESSION HDF5_COMPRESSION = compression @@ -113,15 +114,13 @@ def write_h5_attr(base: h5py.Group, name: str, value: Any): try: base.attrs[name] = np.asarray(value) except TypeError as exc2: - raise H5StoreException( - f"Could not write attribute {name} with type {type(value)}!" - ) from exc2 + msg = f"Could not write attribute {name} with type {type(value)}!" + raise H5StoreException(msg) from exc2 except KeyError: if name in base.attrs: pass - raise H5StoreException( - f"Could not write attribute {name} with type {type(value)}!" - ) from exc + msg = f"Could not write attribute {name} with type {type(value)}!" + raise H5StoreException(msg) from exc except KeyError: if name in base.attrs: pass @@ -156,7 +155,7 @@ def decorator(func): def dump_custom_h5_type( - file: str, group: str, value: Any, existing_groups: Dict[str, Any] + file: Path, group: str, value: Any, existing_groups: dict[str, Any] ): writers = {} writers.update(CUSTOM_H5_WRITERS) @@ -171,11 +170,12 @@ def dump_custom_h5_type( except Exception: is_equal = False - if is_equal: - if isinstance(v, type(value)): - class_already_written = True - clone_group = k # This is the group that this class is already written to - break + if is_equal and isinstance(v, type(value)): + class_already_written = True + clone_group = ( + k # This is the group that this class is already written to + ) + break if class_already_written: # This class has already been written to the file, so we just need to write a reference to it with h5py.File(file, mode="a", track_order=True) as hdf5_file: @@ -190,9 +190,7 @@ def dump_custom_h5_type( loguru.logger.error(msg, exc_info=True) is_valid = False if is_valid: - attrs, existing_groups = writer( - file, group, value, existing_groups - ) + attrs, existing_groups = writer(file, group, value, existing_groups) with h5py.File(file, mode="a", track_order=True) as hdf5_file: try: @@ -203,29 +201,22 @@ def dump_custom_h5_type( try: write_h5_attr(entry, k, v) # type: ignore except H5StoreException: - msg = ( - f"Failed to write attribute {k} to group {group}!" - ) + msg = f"Failed to write attribute {k} to group {group}!" loguru.logger.error(msg, exc_info=True) return existing_groups - raise H5StoreException( - f"None of the custom HDF5 writer functions supported a value of type {type(value)}!" - ) + msg = f"None of the custom HDF5 writer functions supported a value of type {type(value)}!" + raise H5StoreException(msg) -def load_custom_h5_type( - file: str, group: str, entry_type: H5StoreTypes -) -> Any: +def load_custom_h5_type(file: Path, group: str, entry_type: H5StoreTypes) -> Any: loaders = {} loaders.update(CUSTOM_H5_LOADERS) loaders.update(DEFAULT_H5_LOADERS) for loader_type, loader in loaders.values(): if loader_type == entry_type: return loader(file, group) - msg = ( - f"No loader found for group {group} with HDF5 store type {entry_type}!" - ) + msg = f"No loader found for group {group} with HDF5 store type {entry_type}!" loguru.logger.error(msg, exc_info=True) raise H5StoreException(msg) @@ -264,21 +255,21 @@ def key_to_group_name(key): #### NONE #### @_register(DEFAULT_H5_WRITERS, lambda x: x is None) def dump_None( - file: str, group: str, value: None, existing_groups: Dict[str, Any] -) -> Tuple[Dict[str, Any], Dict[str, Any]]: + file: Path, group: str, value: None, existing_groups: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, Any]]: attributes = {} attributes["type"] = H5StoreTypes.Null.name return attributes, existing_groups @_register(DEFAULT_H5_LOADERS, H5StoreTypes.Null) -def load_None(file: str, group: str): +def load_None(file: Path, group: str): return None #### FunctionCache #### @_register(DEFAULT_H5_LOADERS, H5StoreTypes.FunctionCache) -def load_hdf5_function_cache(file: str, group: str) -> Dict[_HashedSeq, Any]: +def load_hdf5_function_cache(file: Path, group: str) -> dict[_HashedSeq, Any]: function_calls = {} with h5py.File(file, mode="r", track_order=True) as hdf5_file: function_group = get_group(hdf5_file, group) @@ -287,12 +278,8 @@ def load_hdf5_function_cache(file: str, group: str) -> Dict[_HashedSeq, Any]: kwarg_group_name = f"{group}/{function_call}/kwargs" value_group_name = f"{group}/{function_call}/value" _, args = ExportableClassMixin._load_h5_entry(file, arg_group_name) - _, kwargs = ExportableClassMixin._load_h5_entry( - file, kwarg_group_name - ) - _, value = ExportableClassMixin._load_h5_entry( - file, value_group_name - ) + _, kwargs = ExportableClassMixin._load_h5_entry(file, kwarg_group_name) + _, value = ExportableClassMixin._load_h5_entry(file, value_group_name) key = _make_key(tuple(args), dict(sorted(kwargs.items()))) function_calls[key] = value @@ -302,11 +289,11 @@ def load_hdf5_function_cache(file: str, group: str) -> Dict[_HashedSeq, Any]: #### HDF5 Group ###### @_register(DEFAULT_H5_WRITERS, lambda x: isinstance(x, h5py.Group)) def dump_hdf5_group( - file: str, + file: Path, group: str, value: h5py.Group, - existing_groups: Dict[str, Any], -) -> Tuple[Dict[str, Any], Dict[str, Any]]: + existing_groups: dict[str, Any], +) -> tuple[dict[str, Any], dict[str, Any]]: attributes = {} with h5py.File(file, mode="a", track_order=True) as hdf5_file: new_entry = hdf5_file.require_group(group) @@ -319,7 +306,7 @@ def dump_hdf5_group( @_register(DEFAULT_H5_LOADERS, H5StoreTypes.HDF5Group) -def load_hdf5_group(file: str, group: str): +def load_hdf5_group(file: Path, group: str): tf = tempfile.TemporaryFile() with h5py.File(file, "r", track_order=True) as hdf5_file: f = h5py.File(tf, "w") @@ -329,15 +316,13 @@ def load_hdf5_group(file: str, group: str): #### python class #### -@_register( - DEFAULT_H5_WRITERS, lambda x: issubclass(type(x), ExportableClassMixin) -) +@_register(DEFAULT_H5_WRITERS, lambda x: issubclass(type(x), ExportableClassMixin)) def dump_exportable_class( - file: str, + file: Path, group: str, value: ExportableClassMixin, - existing_groups: Dict[str, Any], -) -> Tuple[Dict[str, Any], Dict[str, Any]]: + existing_groups: dict[str, Any], +) -> tuple[dict[str, Any], dict[str, Any]]: attributes = {} existing_groups = value._write_hdf5_contents( file=file, @@ -348,7 +333,7 @@ def dump_exportable_class( @_register(DEFAULT_H5_LOADERS, H5StoreTypes.PythonClass) -def load_exportable_class(file: str, group: str) -> ExportableClassMixin: +def load_exportable_class(file: Path, group: str) -> ExportableClassMixin: with h5py.File(file, "r", track_order=True) as hdf5_file: entry = get_group(hdf5_file, group) if "module" not in entry.attrs: @@ -367,8 +352,8 @@ def load_exportable_class(file: str, group: str) -> ExportableClassMixin: #### python enum #### @_register(DEFAULT_H5_WRITERS, lambda x: isinstance(x, Enum)) def dump_python_enum( - file: str, group: str, value: Type[Enum], existing_groups: Dict[str, Any] -) -> Tuple[Dict[str, Any], Dict[str, Any]]: + file: Path, group: str, value: type[Enum], existing_groups: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, Any]]: attributes = {} with h5py.File(file, mode="a", track_order=True) as hdf5_file: new_entry = hdf5_file.require_group(group) @@ -382,24 +367,20 @@ def dump_python_enum( @_register(DEFAULT_H5_LOADERS, H5StoreTypes.Enumerator) -def load_python_enum(file: str, group: str) -> Type[Enum]: +def load_python_enum(file: Path, group: str) -> type[Enum]: with h5py.File(file, "r", track_order=True) as hdf5_file: entry = get_group(hdf5_file, group) try: enum_class_ = get_dataset(entry, "enum_class")[()].decode("utf-8") enum_value_ = get_dataset(entry, "enum_value")[()].decode("utf-8") - enum_module_ = get_dataset(entry, "enum_module")[()].decode( - "utf-8" - ) + enum_module_ = get_dataset(entry, "enum_module")[()].decode("utf-8") except KeyError as exc: msg = f"Failed to load {group} because it was a enum entry, but didnt have the enum name or value!" loguru.logger.error(msg) raise H5StoreException(msg) from exc except Exception as exc: error = traceback.format_exc() - msg = ( - f"Failed to load enum from {group} because of error '{error}'!" - ) + msg = f"Failed to load enum from {group} because of error '{error}'!" loguru.logger.error(msg) raise H5StoreException(msg) from exc enum_module = importlib.import_module(enum_module_) @@ -419,21 +400,20 @@ def load_python_enum(file: str, group: str) -> Type[Enum]: #### generic HDF5 dataset #### @_register( DEFAULT_H5_WRITERS, - lambda x: isinstance( - x, (PYTHON_BASIC_TYPES, NUMPY_NUMERIC_TYPES, np.ndarray) - ) + lambda x: isinstance(x, (PYTHON_BASIC_TYPES, NUMPY_NUMERIC_TYPES, np.ndarray)) # noqa: UP038 or ( isinstance(x, Iterable) and (not isinstance(x, dict)) and all( - isinstance(v, ALL_VALID_DATASET_TYPES) for v in x # type: ignore + isinstance(v, ALL_VALID_DATASET_TYPES) # type: ignore + for v in x # type: ignore ) and all(isinstance(v, type(x[0])) for v in x) # type: ignore ), ) def dump_python_types_or_ndarray( - file: str, group: str, value: Any, existing_groups: Dict[str, Any] -) -> Tuple[Dict[str, Any], Dict[str, Any]]: + file: Path, group: str, value: Any, existing_groups: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, Any]]: attributes = {} with h5py.File(file, mode="a", track_order=True) as hdf5_file: try: @@ -450,7 +430,7 @@ def dump_python_types_or_ndarray( @_register(DEFAULT_H5_LOADERS, H5StoreTypes.HDF5Dataset) -def load_python_types_or_ndarray(file: str, group: str): +def load_python_types_or_ndarray(file: Path, group: str): with h5py.File(file, "r", track_order=True) as hdf5_file: entry = get_dataset(hdf5_file, group) try: @@ -464,20 +444,18 @@ def load_python_types_or_ndarray(file: str, group: str): #### class constructor #### def dump_class_constructor( - file: str, + file: Path, group: str, value: ExportableClassMixin, - existing_groups: Dict[str, Any], -) -> Dict[str, Any]: + existing_groups: dict[str, Any], +) -> dict[str, Any]: with h5py.File(file, mode="a", track_order=True) as hdf5_file: my_group = hdf5_file.require_group(group) constructor_group = my_group.require_group(value.__class__.__name__) constructor_attributes = {} constructor_attributes["module"] = value.__class__.__module__ constructor_attributes["class"] = value.__class__.__name__ - constructor_attributes[ - "timestamp" - ] = datetime.datetime.now().isoformat() + constructor_attributes["timestamp"] = datetime.datetime.now().isoformat() constructor_attributes["type"] = H5StoreTypes.ClassConstructor.name # Write out the attributes for k, v in constructor_attributes.items(): @@ -498,14 +476,12 @@ def dump_class_constructor( @_register(DEFAULT_H5_LOADERS, H5StoreTypes.ClassConstructor) -def load_class_constructor(file: str, group: str): +def load_class_constructor(file: Path, group: str): kwargs = {} with h5py.File(file, mode="r", track_order=True) as hdf5_file: my_group = get_group(hdf5_file, group) for key in my_group: - kwargs[key] = ExportableClassMixin._load_h5_entry( - file, f"{group}/{key}" - )[1] + kwargs[key] = ExportableClassMixin._load_h5_entry(file, f"{group}/{key}")[1] return kwargs @@ -513,11 +489,11 @@ def load_class_constructor(file: str, group: str): #### python dict #### @_register(DEFAULT_H5_WRITERS, lambda x: isinstance(x, dict)) def dump_dictionary( - file: str, + file: Path, group: str, - value: Dict[Any, Any], - existing_groups: Dict[str, Any], -) -> Tuple[Dict[str, Any], Dict[str, Any]]: + value: dict[Any, Any], + existing_groups: dict[str, Any], +) -> tuple[dict[str, Any], dict[str, Any]]: attributes = {} with h5py.File(file, mode="a", track_order=True) as hdf5_file: _ = hdf5_file.require_group(group) @@ -571,7 +547,7 @@ def dump_dictionary( @_register(DEFAULT_H5_LOADERS, H5StoreTypes.Dictionary) -def load_dictionary(file: str, group: str): +def load_dictionary(file: Path, group: str): with h5py.File(file, mode="r", track_order=True) as hdf5_file: entry = get_group(hdf5_file, group) @@ -596,11 +572,11 @@ def load_dictionary(file: str, group: str): #### python sequence #### @_register(DEFAULT_H5_WRITERS, lambda x: isinstance(x, Iterable)) def dump_generic_sequence( - file: str, + file: Path, group: str, value: Sequence[Any], - existing_groups: Dict[str, Any], -) -> Tuple[Dict[str, Any], Dict[str, Any]]: + existing_groups: dict[str, Any], +) -> tuple[dict[str, Any], dict[str, Any]]: basename = group.split("/")[-1] attributes = {} with h5py.File(file, mode="a", track_order=True) as hdf5_file: @@ -631,14 +607,12 @@ def dump_generic_sequence( @_register(DEFAULT_H5_LOADERS, H5StoreTypes.Sequence) -def load_generic_sequence(file: str, group: str) -> List[Any]: +def load_generic_sequence(file: Path, group: str) -> list[Any]: with h5py.File(file, mode="r", track_order=True) as hdf5_file: entry = get_group(hdf5_file, group) basename = group.split("/")[-1] items = list(entry) - malformed_entries = [ - k for k in items if not re.match(rf"{basename}_[0-9]+", k) - ] + malformed_entries = [k for k in items if not re.match(rf"{basename}_[0-9]+", k)] if len(malformed_entries) > 0: msg = f"Failed to load {group} because it was a dictionary, but contained non-item groups {','.join(malformed_entries)}!" loguru.logger.error(msg) @@ -658,11 +632,11 @@ def load_generic_sequence(file: str, group: str) -> List[Any]: #### xarray #### @register_writer(lambda x: isinstance(x, xr.Dataset)) def dump_xarray( - file: str, + file: Path, group: str, value: Dataset, - existing_groups: Dict[str, Any], - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + existing_groups: dict[str, Any], + ) -> tuple[dict[str, Any], dict[str, Any]]: attributes = {} compression_args = {} @@ -686,23 +660,21 @@ def dump_xarray( return attributes, existing_groups @register_loader(H5StoreTypes.XArrayDataset) - def load_xarray(file: str, group: str) -> Dataset: - return xr.load_dataset( - file, group=group, format="NETCDF4", engine="h5netcdf" - ) + def load_xarray(file: Path, group: str) -> Dataset: + return xr.load_dataset(file, group=group, format="NETCDF4", engine="h5netcdf") if pt is not None: @register_writer( - lambda x: isinstance(x, (pt.core.Isotope, pt.core.Element)), + lambda x: isinstance(x, pt.core.Isotope | pt.core.Element), ) def dump_periodictable_element( - file: str, + file: Path, group: str, - value: Union[Isotope, Element], - existing_groups: Dict[str, Any], - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + value: Isotope | Element, + existing_groups: dict[str, Any], + ) -> tuple[dict[str, Any], dict[str, Any]]: attributes = {} with h5py.File(file, mode="a", track_order=True) as hdf5_file: new_entry = hdf5_file.require_group(group) @@ -715,31 +687,27 @@ def dump_periodictable_element( return attributes, existing_groups @register_loader(H5StoreTypes.PeriodicTableElement) - def load_periodictable_element( - file: str, group: str - ) -> Union[Isotope, Element]: + def load_periodictable_element(file: Path, group: str) -> Isotope | Element: with h5py.File(file, "r", track_order=True) as hdf5_file: entry = get_group(hdf5_file, group) atomic_weight = float(get_dataset(entry, "element_A")[()]) atomic_number = float(get_dataset(entry, "element_Z")[()]) - return get_element_from_number_and_weight( - z=atomic_number, a=atomic_weight - ) + return get_element_from_number_and_weight(z=atomic_number, a=atomic_weight) if az is not None: #### arviz #### @register_writer(lambda x: isinstance(x, az.InferenceData)) def dump_inferencedata( - file: str, + file: Path, group: str, value: InferenceData, - existing_groups: Dict[str, Any], - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + existing_groups: dict[str, Any], + ) -> tuple[dict[str, Any], dict[str, Any]]: attributes = {} value.to_netcdf( - file, + str(file), engine="h5netcdf", base_group=group, overwrite_existing=False, @@ -752,10 +720,8 @@ def dump_inferencedata( return attributes, existing_groups @register_loader(H5StoreTypes.ArViz) - def load_inferencedata(file: str, group: str) -> InferenceData: - return az.InferenceData.from_netcdf( - file, base_group=group, engine="h5netcdf" - ) + def load_inferencedata(file: Path, group: str) -> InferenceData: + return az.InferenceData.from_netcdf(file, base_group=group, engine="h5netcdf") #### ndarray with units/error #### @@ -763,8 +729,8 @@ def load_inferencedata(file: str, group: str) -> InferenceData: lambda x: isinstance(x, Uncertainty) or hasattr(x, "units"), ) def dump_unit_or_error_ndarrays( - file: str, group: str, value: Any, existing_groups: Dict[str, Any] -) -> Tuple[Dict[str, Any], Dict[str, Any]]: + file: Path, group: str, value: Any, existing_groups: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, Any]]: if hasattr(value, "units"): u = str(value.units) v = value.magnitude @@ -805,7 +771,7 @@ def dump_unit_or_error_ndarrays( @register_loader(H5StoreTypes.DimensionalNDArray) -def load_unit_or_error_ndarrays(file: str, group: str): +def load_unit_or_error_ndarrays(file: Path, group: str): with h5py.File(file, "r", track_order=True) as hdf5_file: entry = get_group(hdf5_file, group) v = get_dataset(entry, "value")[()] @@ -816,16 +782,14 @@ def load_unit_or_error_ndarrays(file: str, group: str): if np.any(e > 0): v = Uncertainty(v, e) else: - raise ImportError( - f"Need auto_uncertainties to load uncertainty arrays from group {entry}!" - ) + msg = f"Need auto_uncertainties to load uncertainty arrays from group {entry}!" + raise ImportError(msg) try: unit = read_h5_attr(entry, "unit") if unit is not None: if ur is None: - raise ValueError( - "The dataset had unit information, but pint is not installed!" - ) + msg = "The dataset had unit information, but pint is not installed!" + raise ValueError(msg) else: v *= ur(unit) except KeyError: @@ -840,32 +804,32 @@ def load_unit_or_error_ndarrays(file: str, group: str): sz = np.size(v) if sz > 1: raise H5StoreException - try: + with contextlib.suppress(IndexError, TypeError): v = v[0] - except (IndexError, TypeError): - pass return v def get_group(hdf5_handle: h5py.File | h5py.Group, name: str) -> h5py.Group: if name not in hdf5_handle: - raise KeyError(f"Could not find group {name}!") + msg = f"Could not find group {name}!" + raise KeyError(msg) else: ret = hdf5_handle[name] if isinstance(ret, h5py.Group): return ret else: - raise TypeError(f"Expected group {name} but found {type(ret)}!") + msg = f"Expected group {name} but found {type(ret)}!" + raise TypeError(msg) -def get_dataset( - hdf5_handle: h5py.File | h5py.Group, name: str -) -> h5py.Dataset: +def get_dataset(hdf5_handle: h5py.File | h5py.Group, name: str) -> h5py.Dataset: if name not in hdf5_handle: - raise KeyError(f"Could not find group {name}!") + msg = f"Could not find group {name}!" + raise KeyError(msg) else: ret = hdf5_handle[name] if isinstance(ret, h5py.Dataset): return ret else: - raise TypeError(f"Expected group {name} but found {type(ret)}!") + msg = f"Expected group {name} but found {type(ret)}!" + raise TypeError(msg) diff --git a/prepper/tests/__init__.py b/prepper/tests/__init__.py index e69de29..a914edc 100644 --- a/prepper/tests/__init__.py +++ b/prepper/tests/__init__.py @@ -0,0 +1,27 @@ +import lazy_loader + + +__getattr__, __dir__, __all__ = lazy_loader.attach_stub(__name__, __file__) + +__all__ = [ + "NATIVE_DTYPES", + "NotASaveableClass", + "SimpleSaveableClass", + "SimpleSaveableClass2", + "ensure_required_attributes", + "req_class_attrs", + "req_dataset_attrs", + "req_none_attrs", + "roundtrip", + "test_IO", + "test_cached_property", + "test_decorators", + "test_local_cache", + "test_saveable_class", + "test_with_float_list", + "test_with_floats", + "test_with_heterogenous_list", + "test_with_int_list", + "test_with_ints", + "test_with_str_list", +] diff --git a/prepper/tests/__init__.pyi b/prepper/tests/__init__.pyi new file mode 100644 index 0000000..42c28b5 --- /dev/null +++ b/prepper/tests/__init__.pyi @@ -0,0 +1,50 @@ +from . import test_IO +from . import test_decorators + +from .test_IO import ( + NATIVE_DTYPES, + SimpleSaveableClass, + ensure_required_attributes, + req_class_attrs, + req_dataset_attrs, + req_none_attrs, + roundtrip, + test_cached_property, + test_with_float_list, + test_with_floats, + test_with_heterogenous_list, + test_with_int_list, + test_with_ints, + test_with_str_list, +) +from .test_decorators import ( + NotASaveableClass, + SimpleSaveableClass, + SimpleSaveableClass2, + test_cached_property, + test_local_cache, + test_saveable_class, +) + +__all__ = [ + "NATIVE_DTYPES", + "NotASaveableClass", + "SimpleSaveableClass", + "SimpleSaveableClass2", + "ensure_required_attributes", + "req_class_attrs", + "req_dataset_attrs", + "req_none_attrs", + "roundtrip", + "test_IO", + "test_cached_property", + "test_decorators", + "test_local_cache", + "test_saveable_class", + "test_with_float_list", + "test_with_floats", + "test_with_heterogenous_list", + "test_with_int_list", + "test_with_ints", + "test_with_str_list", +] diff --git a/prepper/tests/test_IO.py b/prepper/tests/test_IO.py index 2cdb2ce..19c4237 100644 --- a/prepper/tests/test_IO.py +++ b/prepper/tests/test_IO.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import inspect @@ -10,8 +9,8 @@ from hypothesis.extra import numpy as hnp from prepper import ( - cached_property, ExportableClassMixin, + cached_property, local_cache, saveable_class, ) @@ -68,9 +67,7 @@ def roundtrip(obj: ExportableClassMixin, should_not_be_saved=None): @given( hnp.arrays( - elements=strategies.floats( - allow_nan=False, allow_infinity=False, width=32 - ), + elements=strategies.floats(allow_nan=False, allow_infinity=False, width=32), shape=hnp.array_shapes(min_dims=1, max_dims=4), dtype=float, ) @@ -124,9 +121,7 @@ def test_with_str_list(x): @given( hnp.arrays( - elements=strategies.floats( - allow_nan=False, allow_infinity=False, width=32 - ), + elements=strategies.floats(allow_nan=False, allow_infinity=False, width=32), shape=hnp.array_shapes(min_dims=1, max_dims=4), dtype=float, ) diff --git a/prepper/tests/test_decorators.py b/prepper/tests/test_decorators.py index 2aeaae0..3ad3bbc 100644 --- a/prepper/tests/test_decorators.py +++ b/prepper/tests/test_decorators.py @@ -1,13 +1,12 @@ -# -*- coding: utf-8 -*- from __future__ import annotations -import pytest from hypothesis import given, strategies from hypothesis.extra import numpy as hnp +import pytest from prepper import ( - cached_property, ExportableClassMixin, + cached_property, local_cache, saveable_class, ) @@ -79,12 +78,10 @@ def test_saveable_class(): functions=["square"], ), ] - pytest.raises(ValueError, decorator, NotASaveableClass) + pytest.raises(TypeError, decorator, NotASaveableClass) decorated = decorator(SimpleSaveableClass) assert decorated._exportable_functions == ["SimpleSaveableClass.square"] - assert decorated._exportable_attributes == [ - "SimpleSaveableClass.test_string" - ] + assert decorated._exportable_attributes == ["SimpleSaveableClass.test_string"] for d in bad_decorators: pytest.raises(ValueError, d, SimpleSaveableClass) @@ -94,7 +91,7 @@ def test_cached_property(): test_class = SimpleSaveableClass2() # Make sure __dict__ doesn't have the cached property - assert not any("test_string" in k for k in test_class.__dict__.keys()) + assert not any("test_string" in k for k in test_class.__dict__) # Make sure the cached property works correctly with the super() call assert ( @@ -115,9 +112,7 @@ def test_cached_property(): @given( hnp.arrays( - elements=strategies.floats( - allow_nan=False, allow_infinity=False, width=32 - ), + elements=strategies.floats(allow_nan=False, allow_infinity=False, width=32), shape=(10,), dtype=float, ) @@ -131,8 +126,7 @@ def test_local_cache(x): # Make sure the cache is stores the parent and child calls assert ( - test_class2.__dict__["__cache_SimpleSaveableClass.square__"][key] - == x_**2 + test_class2.__dict__["__cache_SimpleSaveableClass.square__"][key] == x_**2 ) assert ( test_class2.__dict__["__cache_SimpleSaveableClass2.square__"][key] diff --git a/prepper/utils.py b/prepper/utils.py index 5206dce..d1d64b6 100644 --- a/prepper/utils.py +++ b/prepper/utils.py @@ -1,7 +1,6 @@ -# -*- coding: utf-8 -*- from __future__ import annotations -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any import loguru import numpy as np @@ -19,8 +18,10 @@ if TYPE_CHECKING: from periodictable.core import Element, Isotope +__all__ = ["check_equality", "get_element_from_number_and_weight"] -def check_equality(value1: Any, value2: Any, log: bool = False) -> bool: + +def check_equality(value1: Any, value2: Any, *, log: bool = False) -> bool: """ Check if two objects are equal """ @@ -39,12 +40,9 @@ def check_equality(value1: Any, value2: Any, log: bool = False) -> bool: same = bool(same) loguru.logger.enable("prepper") - if not same: - if log: - loguru.logger.debug( - f"Values are different: {value1} and {value2}" - ) - return same + if not same and log: + loguru.logger.debug(f"Values are different: {value1} and {value2}") + except Exception: # Maybe it's a numpy array # check if the dimensions are compatible @@ -55,9 +53,7 @@ def check_equality(value1: Any, value2: Any, log: bool = False) -> bool: f"Dims are different: {np.ndim(value1)} and {np.ndim(value2)} for values {value1} and {value2}" ) return False - if hasattr(value1, "units") and not value1.is_compatible_with( - value2 - ): + if hasattr(value1, "units") and not value1.is_compatible_with(value2): if log: loguru.logger.debug( f"Units are different: {getattr(value1,'units','')} and {getattr(value2,'units','')}" @@ -68,12 +64,10 @@ def check_equality(value1: Any, value2: Any, log: bool = False) -> bool: same = np.allclose(value1, value2) except Exception: same = all(value1 == value2) - if not same: - if log: - loguru.logger.debug( - f"Numpy check: values are different: {value1} and {value2}" - ) - return same + if not same and log: + loguru.logger.debug( + f"Numpy check: values are different: {value1} and {value2}" + ) except Exception as e: if not isinstance(value1, type(value2)): if log: @@ -82,14 +76,15 @@ def check_equality(value1: Any, value2: Any, log: bool = False) -> bool: ) return False - raise ValueError( - f"Cannot compare {value1} and {value2} of type {type(value1)} and {type(value2)}" - ) from e + msg = f"Cannot compare {value1} and {value2} of type {type(value1)} and {type(value2)}" + raise ValueError(msg) from e + else: + return same + else: + return same -def get_element_from_number_and_weight( - z: float, a: float -) -> Isotope | Element: +def get_element_from_number_and_weight(z: float, a: float) -> Isotope | Element: """ This function takes in a float value 'z' representing the atomic number, and another float value 'a' representing the atomic mass, and returns @@ -103,7 +98,8 @@ def get_element_from_number_and_weight( # Iterates over elements in the periodic table to # find the element that matches the atomic number and weight. if elements is None: - raise ImportError("Could not import periodictable") + msg = "Could not import periodictable" + raise ImportError(msg) for element in elements: for iso in element: @@ -117,9 +113,8 @@ def get_element_from_number_and_weight( elm = iso # If we didnt find an element, raise a ValueError exception if elm is None: - raise ValueError( - f"Could not find a matching element for A = {a} and Z = {z}" - ) + msg = f"Could not find a matching element for A = {a} and Z = {z}" + raise ValueError(msg) # If the found element is the base element, just return the base element if np.abs(elm.element.mass - elm.mass) < 0.3: diff --git a/pyproject.toml b/pyproject.toml index 50468e8..6e89fbf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,11 +18,12 @@ dependencies = [ "pandas >= 1.5.1", "joblib >= 1.2.0", "typing_extensions >= 4.6.3", + "lazy_loader", ] dynamic = ["version"] [project.optional-dependencies] -CI = ["pytest", "pytest-cov","hypothesis","pylint"] +CI = ["pytest", "pytest-cov","hypothesis","pylint", "xarray", "auto_uncertainties", "periodictable", "arviz", "pint"] [build-system] requires = ["setuptools >= 67.0.0", "setuptools_scm[toml]>=6.2"] @@ -34,61 +35,6 @@ write_to = "prepper/_version.py" [tool.setuptools.packages.find] exclude = ["typings*"] -[tool.ruff] -line-length = 79 -target-version = "py311" -ignore = ["E501"] -exclude = ["**/*.pyi"] -[tool.ruff.isort] -force-sort-within-sections = true -required-imports = ["from __future__ import annotations"] - -[tool.black] -line-length = 79 -include = '\.pyi?$' -exclude = ''' -/( - \.git - | \.hg - | \.mypy_cache - | \.tox - | \.venv - | _build - | buck-out - | build - | dist -)/ -''' - - -[tool.isort] -profile = "black" -line_length = 79 -force_alphabetical_sort_within_sections = true -add_imports = ["from __future__ import annotations"] - -[tool.pylint] -#Enforced by black -max-line-length = 1000 - -disable = """ -line-too-long, -missing-module-docstring, -missing-function-docstring, -broad-exception-caught, -too-many-branches, -invalid-name, -protected-access, -import-outside-toplevel, -wrong-import-position, -missing-class-docstring, -too-many-locals, -redefined-builtin, -too-few-public-methods, -global-statement -""" -ignore = "tests" - [tool.pyright] pythonVersion = "3.11" pythonPlatform = "Linux" diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..6071806 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,80 @@ + +target-version = "py311" + +[lint] +select = [ + "I", + "W291", + "W292", + "T100", + "YTT", + "UP009", + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + "FBT", + "COM", + "C4", + "DTZ", + "EM", + "EXE", + "FA", + "ISC", + "ICN", + "LOG", + "G", + "INP", + "PIE", + "T20", + "PYI", + "PT", + "Q", + "RSE", + "RET", + "SLF", + "TID", + "INT", + "PTH", + "TD", + "PD", + "TRY", + "FLY", + "NPY", + "PERF", + "RUF", + +] + +ignore = ["E501", "ISC002", "ISC001", +"COM819", "COM812", "Q003", "Q002", "Q001", "Q000", +"D300", "D206", "E117", "E114", "E111", "W191", +"B008", "SIM300", "S101", "RET505", "SLF001", "DTZ005", "RET506" +] + +exclude = ["typings/*", "prepper/tests/*"] + +[lint.per-file-ignores] +"__init__.py" = ["F401", "E402", "I001", "I002"] +"__init__.pyi" = ["F401", "E402", "I001", "I002"] + +[format] +docstring-code-format = true +line-ending = "lf" + +[lint.isort] +from-first = false +force-sort-within-sections = true +required-imports = ["from __future__ import annotations"] + +[lint.flake8-bugbear] +extend-immutable-calls = ["pylatex.utils.NoEscape"] + +[lint.flake8-self] +ignore-names = [] diff --git a/typings/periodictable/__init__.pyi b/typings/periodictable/__init__.pyi index 2b9ccda..340e88a 100644 --- a/typings/periodictable/__init__.pyi +++ b/typings/periodictable/__init__.pyi @@ -49,7 +49,6 @@ def data_files(): # -> list[Unknown]: used directly in setup(..., data_files=...) for setup.py. """ - ... __all__ += core.define_elements(elements, globals()) @@ -98,7 +97,6 @@ def formula(*args, **kw): be used as a basis for a rich text representation such as matplotlib TeX markup. """ - ... def mix_by_weight(*args, **kw): # -> Formula: """ @@ -136,7 +134,6 @@ def mix_by_weight(*args, **kw): # -> Formula: If density is not given, then it will be computed from the density of the components, assuming equal volume. """ - ... def mix_by_volume(*args, **kw): # -> Formula: """ @@ -176,7 +173,6 @@ def mix_by_volume(*args, **kw): # -> Formula: assuming the components take up no more nor less space because they are in the mixture. """ - ... def neutron_sld( *args, **kw @@ -188,7 +184,6 @@ def neutron_sld( See :class:`periodictable.nsf.neutron_sld` for details. """ - ... def neutron_scattering( *args, **kw @@ -202,7 +197,6 @@ def neutron_scattering( See :func:`periodictable.nsf.neutron_scattering` for details. """ - ... def xray_sld( *args, **kw @@ -216,4 +210,3 @@ def xray_sld( See :class:`periodictable.xsf.Xray` for details. """ - ... diff --git a/typings/periodictable/activation.pyi b/typings/periodictable/activation.pyi index 87bb90e..0c7418f 100644 --- a/typings/periodictable/activation.pyi +++ b/typings/periodictable/activation.pyi @@ -84,7 +84,6 @@ def NIST2001_isotopic_abundance(iso): Isotopic Compositions of the Elements, 2001. J. Phys. Chem. Ref. Data, Vol. 34, No. 1, 2005 """ - ... def IAEA1987_isotopic_abundance(iso): # -> Literal[0]: """ @@ -96,7 +95,6 @@ def IAEA1987_isotopic_abundance(iso): # -> Literal[0]: IAEA 273: Handbook on Nuclear Activation Data, 1987. """ - ... class Sample: """ @@ -134,13 +132,11 @@ class Sample: default it uses :func:`NIST2001_isotopic_abundance`, and there is the alternative :func:`IAEA1987_isotopic_abundance`. """ - ... def decay_time(self, target): # -> Literal[0]: """ After determining the activation, compute the number of hours required to achieve a total activation level after decay. """ - ... def show_table(self, cutoff=..., format=...): # -> None: """ Tabulate the daughter products. @@ -153,7 +149,6 @@ class Sample: The number format to use for the activation. """ - ... def find_root(x, f, df, max=..., tol=...): # -> tuple[Unknown, Unknown]: r""" @@ -164,11 +159,9 @@ def find_root(x, f, df, max=..., tol=...): # -> tuple[Unknown, Unknown]: Returns x, f(x). """ - ... def sorted_activity(activity_pair): # -> list[Unknown]: """Interator over activity pairs sorted by isotope then daughter product.""" - ... class ActivationEnvironment: """ @@ -200,26 +193,21 @@ class ActivationEnvironment: Used as a multiplier times the resonance cross section to add to the thermal cross section for all thermal induced reactions. """ - ... COLUMN_NAMES = ... INT_COLUMNS = ... BOOL_COLUMNS = ... FLOAT_COLUMNS = ... -def activity( - isotope, mass, env, exposure, rest_times -): # -> dict[Unknown, Unknown]: +def activity(isotope, mass, env, exposure, rest_times): # -> dict[Unknown, Unknown]: """ Compute isotope specific daughter products after the given exposure time and rest period. """ - ... def init(table, reload=...): # -> None: """ Add neutron activation levels to each isotope. """ - ... class ActivationResult: def __init__(self, **kw) -> None: ... diff --git a/typings/periodictable/core.pyi b/typings/periodictable/core.pyi index b7d0376..ee30de0 100644 --- a/typings/periodictable/core.pyi +++ b/typings/periodictable/core.pyi @@ -59,7 +59,7 @@ Helper functions: """ from __future__ import annotations -from typing import Dict, Iterator, List +from collections.abc import Iterator from pint._typing import QuantityOrUnitLike as Quantity @@ -81,9 +81,7 @@ __all__ = [ ] PUBLIC_TABLE_NAME = ... -def delayed_load( - all_props, loader, element=..., isotope=..., ion=... -): # -> None: +def delayed_load(all_props, loader, element=..., isotope=..., ion=...): # -> None: """ Delayed loading of an element property table. When any of property is first accessed the loader will be called to load the associated @@ -96,7 +94,6 @@ def delayed_load( keyword flags *element*, *isotope* and/or *ion* to specify which of these classes will be assigned specific information on load. """ - ... class PeriodicTable: """ @@ -114,11 +111,11 @@ class PeriodicTable: Fe >>> print(elements.Fe) Fe - >>> print(elements.symbol('Fe')) + >>> print(elements.symbol("Fe")) Fe - >>> print(elements.name('iron')) + >>> print(elements.name("iron")) Fe - >>> print(elements.isotope('Fe')) + >>> print(elements.isotope("Fe")) Fe @@ -130,7 +127,7 @@ class PeriodicTable: 56-Fe >>> print(elements.Fe[56]) 56-Fe - >>> print(elements.isotope('56-Fe')) + >>> print(elements.isotope("56-Fe")) 56-Fe @@ -143,7 +140,9 @@ class PeriodicTable: >>> from periodictable import * >>> for el in elements: # lists the element symbols - ... print("%s %s"%(el.symbol, el.name)) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE + ... print( + ... "%s %s" % (el.symbol, el.name) + ... ) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE n neutron H hydrogen He helium @@ -157,8 +156,8 @@ class PeriodicTable: See section :ref:`Adding properties ` for details. """ - properties: List[str] - _elements: Dict[int, Element] + properties: list[str] + _elements: dict[int, Element] D: Isotope T: Isotope n: Element @@ -285,14 +284,12 @@ class PeriodicTable: """ Retrieve element Z. """ - ... def __iter__( self, ) -> Iterator[Element]: # -> Generator[Unknown, Any, None]: """ Process the elements in Z order """ - ... def symbol(self, input: str) -> Element | Isotope: # -> Element | Isotope: """ Lookup the an element in the periodic table using its symbol. Symbols @@ -312,10 +309,9 @@ class PeriodicTable: .. doctest:: >>> import periodictable - >>> print(periodictable.elements.symbol('Fe')) + >>> print(periodictable.elements.symbol("Fe")) Fe """ - ... def name(self, input: str) -> Element | Isotope: # -> Element | Isotope """ Lookup an element given its name. @@ -334,13 +330,10 @@ class PeriodicTable: .. doctest:: >>> import periodictable - >>> print(periodictable.elements.name('iron')) + >>> print(periodictable.elements.name("iron")) Fe """ - ... - def isotope( - self, input: str - ) -> Element | Isotope: # -> Element | Isotope: + def isotope(self, input: str) -> Element | Isotope: # -> Element | Isotope: """ Lookup the element or isotope in the periodic table. Elements are assumed to be given by the standard element symbols. Isotopes @@ -360,10 +353,9 @@ class PeriodicTable: .. doctest:: >>> import periodictable - >>> print(periodictable.elements.isotope('58-Ni')) + >>> print(periodictable.elements.isotope("58-Ni")) 58-Ni """ - ... def list(self, *props, **kw): # -> None: """ Print a list of elements with the given set of properties. @@ -390,7 +382,6 @@ class PeriodicTable: ... Bk: 247.00 u 14.00 g/cm^3 """ - ... class IonSet: def __init__(self, element_or_isotope) -> None: ... @@ -410,7 +401,6 @@ class Ion: def __getattr__(self, attr): ... @property def mass(self): ... - def __str__(self) -> str: ... def __repr__(self): ... def __reduce__(self): ... @@ -443,7 +433,6 @@ class Isotope: def __init__(self, element, isotope_number) -> None: ... def __getattr__(self, attr): ... - def __str__(self) -> str: ... def __repr__(self): ... def __reduce__(self): ... @@ -466,7 +455,7 @@ class Element: density: float mass_units: str density_units: str - _isotopes: List[Isotope] + _isotopes: list[Isotope] # added in lotus base_symbol: str | None @@ -474,13 +463,10 @@ class Element: udensity: Quantity | None isotope_id: int | None - def __init__( - self, name: str, symbol: str, Z: int, ions, table - ) -> None: ... + def __init__(self, name: str, symbol: str, Z: int, ions, table) -> None: ... @property - def isotopes(self) -> List[int]: # -> list[Unknown]: + def isotopes(self) -> list[int]: # -> list[Unknown]: """List of all isotopes""" - ... def add_isotope(self, number): """ Add an isotope for the element. @@ -491,7 +477,6 @@ class Element: :Returns: None """ - ... def __getitem__(self, number: int) -> Isotope: ... def __iter__( self, @@ -499,29 +484,23 @@ class Element: """ Process the isotopes in order """ - ... def __repr__(self): ... def __reduce__(self): ... def isatom(val): # -> bool: """Return true if value is an element, isotope or ion""" - ... def isisotope(val): # -> bool: """Return true if value is an isotope or isotope ion.""" - ... def ision(val): # -> bool: """Return true if value is a specific ion of an element or isotope""" - ... def iselement(val): # -> bool: """Return true if value is an element or ion in natural abundance""" - ... def change_table(atom, table): """Search for the same element, isotope or ion from a different table""" - ... PRIVATE_TABLES = ... element_base = ... @@ -536,7 +515,6 @@ def default_table(table=...): # -> PeriodicTable: table = core.default_table(table) ... """ - ... def define_elements(table, namespace): # -> list[Unknown]: """ @@ -546,7 +524,7 @@ def define_elements(table, namespace): # -> list[Unknown]: This is called from *__init__* as:: elements = core.default_table() - __all__ += core.define_elements(elements, globals()) + __all__ += core.define_elements(elements, globals()) :Parameters: *table* : PeriodicTable @@ -558,7 +536,6 @@ def define_elements(table, namespace): # -> list[Unknown]: .. Note:: This will only work for *namespace* globals(), not locals()! """ - ... def get_data_path(data): # -> str: """ @@ -572,6 +549,5 @@ def get_data_path(data): # -> str: :Returns: string Path to the data. """ - ... PUBLIC_TABLE = ... diff --git a/typings/periodictable/covalent_radius.pyi b/typings/periodictable/covalent_radius.pyi index 6f7b42a..82bd725 100644 --- a/typings/periodictable/covalent_radius.pyi +++ b/typings/periodictable/covalent_radius.pyi @@ -71,7 +71,6 @@ def init(table, reload=...): # -> None: Use *reload = True* to replace the covalent radius property on an existing table. """ - ... Cordero = ... CorderoPyykko = ... diff --git a/typings/periodictable/cromermann.pyi b/typings/periodictable/cromermann.pyi index c4af07e..f0b8ba8 100644 --- a/typings/periodictable/cromermann.pyi +++ b/typings/periodictable/cromermann.pyi @@ -18,7 +18,6 @@ def getCMformula(symbol): Return instance of CromerMannFormula. """ - ... def fxrayatq(symbol, Q, charge=...): """ @@ -33,7 +32,6 @@ def fxrayatq(symbol, Q, charge=...): Return float or numpy array. """ - ... def fxrayatstol(symbol, stol, charge=...): """ @@ -48,7 +46,6 @@ def fxrayatstol(symbol, stol, charge=...): Return float or numpy.array. """ - ... class CromerMannFormula: """ @@ -79,7 +76,6 @@ class CromerMannFormula: No return value """ - ... def atstol(self, stol): # -> Any: """ Calculate x-ray scattering factors at specified sin(theta)/lambda @@ -89,6 +85,5 @@ class CromerMannFormula: Return float or numpy.array. """ - ... _cmformulas = ... diff --git a/typings/periodictable/crystal_structure.pyi b/typings/periodictable/crystal_structure.pyi index f330ead..ec8dfab 100644 --- a/typings/periodictable/crystal_structure.pyi +++ b/typings/periodictable/crystal_structure.pyi @@ -50,4 +50,3 @@ def init(table, reload=...): # -> None: """ Add crystal_structure field to the element properties. """ - ... diff --git a/typings/periodictable/density.pyi b/typings/periodictable/density.pyi index 0431be7..856ab4d 100644 --- a/typings/periodictable/density.pyi +++ b/typings/periodictable/density.pyi @@ -62,7 +62,6 @@ def density(iso_el): 80th ed. (1999).* """ - ... def interatomic_distance(element): # -> None: r""" @@ -94,7 +93,6 @@ def interatomic_distance(element): # -> None: (10^{-8} cm\cdot \AA^{-1})^3))^{1/3} = \AA """ - ... def number_density(element): # -> None: r""" @@ -123,7 +121,6 @@ def number_density(element): # -> None: = atoms\cdot cm^{-3} """ - ... def init(table, reload=...): ... diff --git a/typings/periodictable/fasta.pyi b/typings/periodictable/fasta.pyi index 0b885b0..6e49605 100644 --- a/typings/periodictable/fasta.pyi +++ b/typings/periodictable/fasta.pyi @@ -54,7 +54,6 @@ def isotope_substitution(formula, source, target, portion=...): .. deprecated:: 1.5.3 Use formula.replace(source, target, portion) instead. """ - ... class Molecule: """ @@ -117,7 +116,6 @@ class Molecule: Changed 1.5.3: fix errors in SLD calculations. """ - ... class Sequence(Molecule): """ @@ -143,13 +141,11 @@ class Sequence(Molecule): Yields one FASTA sequence each cycle. """ - ... @staticmethod def load(filename, type=...): # -> Sequence: """ Load the first FASTA sequence from a file. """ - ... def __init__(self, name, sequence, type=...) -> None: ... H2O_SLD = ... @@ -173,7 +169,6 @@ def D2Omatch(Hsld, Dsld): Change 1.5.3: corrected D2O sld, which will change the computed match point. """ - ... def read_fasta(fp): # -> Generator[tuple[Unknown, LiteralString], Any, None]: """ @@ -183,7 +178,6 @@ def read_fasta(fp): # -> Generator[tuple[Unknown, LiteralString], Any, None]: Change 1.5.3: Now uses H[1] rather than T for labile hydrogen. """ - ... def _(code, V, formula, name): ... diff --git a/typings/periodictable/formulas.pyi b/typings/periodictable/formulas.pyi index 82781f1..8c240fb 100644 --- a/typings/periodictable/formulas.pyi +++ b/typings/periodictable/formulas.pyi @@ -53,7 +53,6 @@ def mix_by_weight(*args, **kw): # -> Formula: density calculation assumes the cell volume remains constant for the original materials, which is not in general the case. """ - ... def mix_by_volume(*args, **kw): # -> Formula: """ @@ -97,11 +96,8 @@ def mix_by_volume(*args, **kw): # -> Formula: assumes the cell volume remains constant for the original materials, which is not in general the case. """ - ... -def formula( - compound=..., density=..., natural_density=..., name=..., table=... -): +def formula(compound=..., density=..., natural_density=..., name=..., table=...): r""" Construct a chemical formula representation from a string, a dictionary of atoms or another formula. @@ -129,20 +125,21 @@ def formula( After creating a formula, a rough estimate of the density can be computed using:: - formula.density = formula.molecular_mass/formula.volume(packing_factor=...) + formula.density = formula.molecular_mass / formula.volume(packing_factor=...) The volume() calculation uses the covalent radii of the components and the known packing factor or crystal structure name. If the lattice constants for the crystal are known, then they can be used instead:: - formula.density = formula.molecular_mass/formula.volume(a, b, c, alpha, beta, gamma) + formula.density = formula.molecular_mass / formula.volume( + a, b, c, alpha, beta, gamma + ) Formulas are designed for calculating quantities such as molar mass and scattering length density, not for representing bonds or atom positions. The representations are simple, but preserve some of the structure for display purposes. """ - ... class Formula: """ @@ -162,7 +159,6 @@ class Formula: the *count* as the total number of each element or isotope in the chemical formula, summed across all subgroups. """ - ... @property def hill(self): """ @@ -172,7 +168,6 @@ class Formula: first followed by hydrogen then the remaining elements in alphabetical order. """ - ... def natural_mass_ratio(self): # -> float: """ Natural mass to isotope mass ratio. @@ -186,7 +181,6 @@ class Formula: preserved with isotope substitution, then the ratio of the masses will be the ratio of the densities. """ - ... @property def natural_density(self): """ @@ -196,7 +190,6 @@ class Formula: replaced by the naturally occurring abundance of the element without changing the cell volume. """ - ... @natural_density.setter def natural_density(self, natural_density): ... @property @@ -207,7 +200,6 @@ class Formula: Molar mass of the molecule. Use molecular_mass to get the mass in grams. """ - ... @property def molecular_mass(self): # -> float: """ @@ -215,19 +207,16 @@ class Formula: Mass of the molecule in grams. """ - ... @property def charge(self): # -> int: """ Net charge of the molecule. """ - ... @property def mass_fraction(self): # -> dict[Unknown, Unknown]: """ Fractional mass representation of each element/isotope/ion. """ - ... def volume(self, *args, **kw): r""" Estimate unit cell volume. @@ -273,14 +262,13 @@ class Formula: Using the cell volume, mass density can be set with:: - formula.density = n*formula.molecular_mass/formula.volume() + formula.density = n * formula.molecular_mass / formula.volume() where n is the number of molecules per unit cell. Note: a single non-keyword argument is interpreted as a packing factor rather than a lattice spacing of 'a'. """ - ... @require_keywords def neutron_sld( self, wavelength=..., energy=... @@ -302,7 +290,6 @@ class Formula: .. deprecated:: 0.95 Use periodictable.neutron_sld(formula) instead. """ - ... @require_keywords def xray_sld( self, energy=..., wavelength=... @@ -329,12 +316,10 @@ class Formula: .. deprecated:: 0.95 Use periodictable.xray_sld(formula) instead. """ - ... def change_table(self, table): # -> Self@Formula: """ Replace the table used for the components of the formula. """ - ... def replace(self, source, target, portion=...): """ Create a new formula with one atom/isotope substituted for another. @@ -347,30 +332,24 @@ class Formula: *portion* is the proportion of source which is substituted for target. """ - ... def __eq__(self, other) -> bool: """ Return True if two formulas represent the same structure. Note that they may still have different names and densities. Note: use hill representation for an order independent comparison. """ - ... def __add__(self, other): # -> Formula: """ Join two formulas. """ - ... def __iadd__(self, other): # -> Self@Formula: """ Extend a formula with another. """ - ... def __rmul__(self, other): # -> Self@Formula: """ Provide a multiplier for formula. """ - ... - def __str__(self) -> str: ... def __repr__(self): ... LENGTH_UNITS = ... @@ -396,7 +375,6 @@ def formula_grammar(table): # -> ParserElement: an *element* or a list of pairs (*count, fragment*). """ - ... _PARSER_CACHE = ... @@ -405,4 +383,3 @@ def parse_formula(formula_str, table=...): Parse a chemical formula, returning a structure with elements from the given periodic table. """ - ... diff --git a/typings/periodictable/magnetic_ff.pyi b/typings/periodictable/magnetic_ff.pyi index c08e273..2ffa478 100644 --- a/typings/periodictable/magnetic_ff.pyi +++ b/typings/periodictable/magnetic_ff.pyi @@ -18,13 +18,11 @@ def formfactor_0(j0, q): """ Returns the scattering potential for form factor *j0* at the given *q*. """ - ... def formfactor_n(jn, q): """ Returns the scattering potential for form factor *jn* at the given *q*. """ - ... class MagneticFormFactor: """ @@ -55,8 +53,9 @@ class MagneticFormFactor: >>> import periodictable >>> ion = periodictable.Fe.ion[2] - >>> print("[%.5f, %.5f, %.5f]" - ... % tuple(ion.magnetic_ff[ion.charge].M_Q([0, 0.1, 0.2]))) + >>> print( + ... "[%.5f, %.5f, %.5f]" % tuple(ion.magnetic_ff[ion.charge].M_Q([0, 0.1, 0.2])) + ... ) [1.00000, 0.99935, 0.99741] """ @@ -64,23 +63,17 @@ class MagneticFormFactor: M = ... def j0_Q(self, Q): """Returns *j0* scattering potential at *Q* |1/Ang|""" - ... def j2_Q(self, Q): """Returns *j2* scattering potential at *Q* |1/Ang|""" - ... def j4_Q(self, Q): """Returns *j4* scattering potential at *Q* |1/Ang|""" - ... def j6_Q(self, Q): """Returns j6 scattering potential at *Q* |1/Ang|""" - ... def J_Q(self, Q): """Returns J scattering potential at *Q* |1/Ang|""" - ... M_Q = ... def init(table, reload=...): # -> None: """Add magnetic form factor properties to the periodic table""" - ... CFML_DATA = ... diff --git a/typings/periodictable/mass.pyi b/typings/periodictable/mass.pyi index 5d1522c..1375d2e 100644 --- a/typings/periodictable/mass.pyi +++ b/typings/periodictable/mass.pyi @@ -66,7 +66,6 @@ def mass(isotope): *Coursey. J. S., Schwab. D. J, and Dragoset. R. A., NIST Atomic Weights and Isotopic Composition Database.* """ - ... def abundance(isotope): """ @@ -82,10 +81,8 @@ def abundance(isotope): *Coursey. J. S., Schwab. D. J, and Dragoset. R. A., NIST Atomic Weights and Isotopic Composition Database.* """ - ... def init(table, reload=...): # -> None: """Add mass attribute to period table elements and isotopes""" - ... massdata = ... diff --git a/typings/periodictable/nsf.pyi b/typings/periodictable/nsf.pyi index f7fe311..01d9918 100644 --- a/typings/periodictable/nsf.pyi +++ b/typings/periodictable/nsf.pyi @@ -187,7 +187,6 @@ def neutron_wavelength(energy): # -> NDArray[Any]: $m_n$ = neutron mass in kg """ - ... def neutron_wavelength_from_velocity(velocity): r""" @@ -211,7 +210,6 @@ def neutron_wavelength_from_velocity(velocity): $m_n$ = neutron mass in kg """ - ... def neutron_energy(wavelength): # -> NDArray[floating[Any]]: r""" @@ -235,7 +233,6 @@ def neutron_energy(wavelength): # -> NDArray[floating[Any]]: $m_n$ = neutron mass in kg """ - ... _4PI_100 = ... @@ -327,6 +324,7 @@ class Neutron: .. Note:: 1 barn = 100 |fm^2| """ + b_c = ... b_c_units = ... b_c_i = ... @@ -352,13 +350,9 @@ class Neutron: is_energy_dependent = ... nsf_table = ... def __init__(self) -> None: ... - def __str__(self) -> str: ... def has_sld(self): # -> bool: """Returns *True* if sld is defined for this element/isotope.""" - ... - def scattering_by_wavelength( - self, wavelength - ): # -> tuple[Unknown, Unknown]: + def scattering_by_wavelength(self, wavelength): # -> tuple[Unknown, Unknown]: r""" Return scattering length and total cross section for each wavelength. @@ -377,7 +371,6 @@ class Neutron: *sigma_s* \: float(s) | barn """ - ... @require_keywords def sld( self, wavelength=... @@ -397,7 +390,6 @@ class Neutron: See :func:`neutron_scattering` for details. """ - ... @require_keywords def scattering( self, wavelength=... @@ -424,14 +416,12 @@ class Neutron: See :func:`neutron_scattering` for details. """ - ... def energy_dependent_init(table): ... def init(table, reload=...): # -> None: """ Loads the Rauch table from the neutron data book. """ - ... @require_keywords def neutron_scattering( @@ -683,7 +673,6 @@ def neutron_scattering( t_u\,({\rm cm}) &= 1/(\Sigma_{\rm s}\, 1/{\rm cm} \,+\, \Sigma_{\rm abs}\, 1/{\rm cm}) """ - ... def neutron_sld( *args, **kw @@ -714,7 +703,6 @@ def neutron_sld( Returns the scattering length density of the compound. See :func:`neutron_scattering` for details. """ - ... def neutron_sld_from_atoms( *args, **kw @@ -725,7 +713,6 @@ def neutron_sld_from_atoms( :func:`neutron_sld` accepts dictionaries of \{atom\: count\}. """ - ... def D2O_match(compound, **kw): # -> tuple[Unknown | float, Unknown]: """ @@ -744,7 +731,6 @@ def D2O_match(compound, **kw): # -> tuple[Unknown | float, Unknown]: 100% you will need an additional constrast agent in the 100% D2O solvent to increase the SLD enough to match. """ - ... def D2O_sld( compound, volume_fraction=..., D2O_fraction=..., **kw @@ -781,13 +767,11 @@ def D2O_sld( Returns (real, imag, incoh) SLD. """ - ... def mix_values(a, b, fraction): # -> tuple[Unknown, ...]: """ Mix two tuples with floating point values according to fraction of a. """ - ... def neutron_composite_sld( materials, wavelength=... @@ -817,7 +801,6 @@ def neutron_composite_sld( the calculation consists of a few simple array operations regardless of the size of the material fragments. """ - ... def sld_plot(table=...): # -> None: r""" @@ -829,7 +812,6 @@ def sld_plot(table=...): # -> None: :Returns: None """ - ... nsftable = ... nsftableI = ... @@ -840,7 +822,6 @@ def fix_number(str): # -> float | None: uncertainty. Also accepts a limited range, e.g., <1e-6, which is converted as 1e-6. Missing values are set to 0. """ - ... def sld_table(wavelength=..., table=..., isotopes=...): # -> None: r""" @@ -872,7 +853,6 @@ def sld_table(wavelength=..., table=..., isotopes=...): # -> None: 248-Cm 248.072 13.569 2.536 0.000 0.207 * Energy dependent cross sections """ - ... def energy_dependent_table(table=...): # -> None: r""" @@ -896,7 +876,6 @@ def energy_dependent_table(table=...): # -> None: Yb-168 Hg-196 Hg-199 """ - ... def compare(fn1, fn2, table=..., tol=...): ... def absorption_comparison_table(table=..., tol=...): # -> None: @@ -925,7 +904,9 @@ def absorption_comparison_table(table=..., tol=...): # -> None: Example - >>> absorption_comparison_table (tol=0.5) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE + >>> absorption_comparison_table( + ... tol=0.5 + ... ) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE Comparison of absorption and (-2000 lambda b_c_i) 3-He 5333.00 5322.08 0.2% Li 70.50 ---- @@ -936,7 +917,6 @@ def absorption_comparison_table(table=..., tol=...): # -> None: ... """ - ... def coherent_comparison_table(table=..., tol=...): # -> None: r""" @@ -956,7 +936,7 @@ def coherent_comparison_table(table=..., tol=...): # -> None: Example - >>> coherent_comparison_table (tol=0.5) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE + >>> coherent_comparison_table(tol=0.5) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE Comparison of (4 pi |b_c|^2/100) and coherent n 172.03 43.01 300.0% 1-n 172.03 43.01 300.0% @@ -968,7 +948,6 @@ def coherent_comparison_table(table=..., tol=...): # -> None: ... """ - ... def total_comparison_table(table=..., tol=...): # -> None: r""" @@ -987,7 +966,7 @@ def total_comparison_table(table=..., tol=...): # -> None: Example - >>> total_comparison_table (tol=0.1) + >>> total_comparison_table(tol=0.1) Comparison of total cross section to (coherent + incoherent) n 43.01 ---- 1-n 43.01 ---- @@ -1001,7 +980,6 @@ def total_comparison_table(table=..., tol=...): # -> None: 187-Os 13.00 13.30 -2.3% """ - ... def incoherent_comparison_table(table=..., tol=...): # -> None: r""" @@ -1018,7 +996,9 @@ def incoherent_comparison_table(table=..., tol=...): # -> None: Example - >>> incoherent_comparison_table (tol=0.5) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE + >>> incoherent_comparison_table( + ... tol=0.5 + ... ) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE Comparison of incoherent and (total - 4 pi |b_c|^2/100) Sc 4.50 5.10 -11.8% 45-Sc 4.50 5.10 -11.8% @@ -1028,13 +1008,11 @@ def incoherent_comparison_table(table=..., tol=...): # -> None: ... """ - ... def print_scattering(compound, wavelength=...): # -> None: """ Print the scattering for a single compound. """ - ... def main(): # -> None: """ @@ -1052,6 +1030,5 @@ def main(): # -> None: sigma_c: 3.37503 sigma_i: 0.000582313 sigma_a: 0.402605 1/cm 1/e penetration: 2.23871 cm """ - ... if __name__ == "__main__": ... diff --git a/typings/periodictable/nsf_resonances.pyi b/typings/periodictable/nsf_resonances.pyi index f10858f..85d6095 100644 --- a/typings/periodictable/nsf_resonances.pyi +++ b/typings/periodictable/nsf_resonances.pyi @@ -47,9 +47,7 @@ class RareEarthIsotope: *A*, *B*, *C* are the fitted real + imaginary constants for the isotope, with units 1/eV for B and 1/eV^2 for C. """ - ... def f(self, E): ... - def __str__(self) -> str: ... Sm_nat = ... Gd_155 = ... diff --git a/typings/periodictable/plot.pyi b/typings/periodictable/plot.pyi index 562283c..0b2d09b 100644 --- a/typings/periodictable/plot.pyi +++ b/typings/periodictable/plot.pyi @@ -22,4 +22,3 @@ def table_plot(data, form=..., label=..., title=...): # -> None: :Returns: None """ - ... diff --git a/typings/periodictable/util.pyi b/typings/periodictable/util.pyi index 5b3c937..528f20d 100644 --- a/typings/periodictable/util.pyi +++ b/typings/periodictable/util.pyi @@ -33,7 +33,6 @@ def cell_volume(a=..., b=..., c=..., alpha=..., beta=..., gamma=...): V = a b c \sqrt{1 - \cos^2 \alpha - \cos^2 \beta - \cos^2 \gamma + 2 \cos \alpha \cos \beta \cos \gamma} """ - ... def require_keywords( function, @@ -45,7 +44,8 @@ def require_keywords( For example: >>> @require_keywords - ... def fn(a, b, c=3): pass + ... def fn(a, b, c=3): + ... pass >>> fn(1, 2, 3) Traceback (most recent call last): ... @@ -56,7 +56,8 @@ def require_keywords( Variable arguments are not currently supported: >>> @require_keywords - ... def fn(a, b, c=6, *args, **kw): pass + ... def fn(a, b, c=6, *args, **kw): + ... pass Traceback (most recent call last): ... NotImplementedError: only named arguments for now @@ -68,4 +69,3 @@ def require_keywords( use the \*args, \*\*kw call style. Python 3+ provides the '\*' call signature element which will force all keywords after '\*' to be named. """ - ... diff --git a/typings/periodictable/xsf.pyi b/typings/periodictable/xsf.pyi index 866db19..cb7a8ac 100644 --- a/typings/periodictable/xsf.pyi +++ b/typings/periodictable/xsf.pyi @@ -217,7 +217,6 @@ def xray_wavelength(energy): # -> NDArray[floating[Any]]: $c$ = speed of light in m/s """ - ... def xray_energy(wavelength): # -> NDArray[floating[Any]]: r""" @@ -241,7 +240,6 @@ def xray_energy(wavelength): # -> NDArray[floating[Any]]: $c$ = speed of light in m/s """ - ... class Xray: """ @@ -276,7 +274,6 @@ class Xray: scattering factors database at the Lawrence Berkeley Laboratory Center for X-ray Optics. """ - ... def f0(self, Q): r""" Isotropic X-ray scattering factors *f0* for the input Q. @@ -300,7 +297,6 @@ class Xray: D. Wassmaier, A. Kerfel, Acta Crystallogr. A51 (1995) 416. http://dx.doi.org/10.1107/S0108767394013292 """ - ... @require_keywords def sld( self, wavelength=..., energy=... @@ -340,7 +336,6 @@ class Xray: Data comes from the Henke Xray scattering factors database at the Lawrence Berkeley Laboratory Center for X-ray Optics. """ - ... @require_keywords def xray_sld( @@ -368,7 +363,6 @@ def xray_sld( :Raises: *AssertionError* : *density* or *wavelength*/*energy* is missing. """ - ... @require_keywords def index_of_refraction( @@ -398,7 +392,6 @@ def index_of_refraction( Formula taken from http://xdb.lbl.gov (section 1.7) and checked against http://henke.lbl.gov/optical_constants/getdb2.html """ - ... @require_keywords def mirror_reflectivity( @@ -438,7 +431,6 @@ def mirror_reflectivity( Formula taken from http://xdb.lbl.gov (section 4.2) and checked against http://henke.lbl.gov/optical_constants/mirror2.html """ - ... def xray_sld_from_atoms( *args, **kw @@ -448,7 +440,6 @@ def xray_sld_from_atoms( :func:`xray_sld` now accepts a dictionary of *{atom: count}* directly. """ - ... spectral_lines_data = ... @@ -456,7 +447,6 @@ def init_spectral_lines(table): # -> None: """ Sets the K_alpha and K_beta1 wavelengths for select elements """ - ... def init(table, reload=...): ... def plot_xsf(el): # -> None: @@ -468,7 +458,6 @@ def plot_xsf(el): # -> None: :Returns: None """ - ... def sld_table(wavelength=..., table=...): # -> None: """ @@ -484,7 +473,7 @@ def sld_table(wavelength=..., table=...): # -> None: Example - >>> sld_table() # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE + >>> sld_table() # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE X-ray scattering length density for 1.5418 Ang El rho irho H 1.19 0.00 @@ -501,7 +490,6 @@ def sld_table(wavelength=..., table=...): # -> None: Mg 14.78 0.22 ... """ - ... def emission_table(table=...): # -> None: """ @@ -515,7 +503,7 @@ def emission_table(table=...): # -> None: Example - >>> emission_table() # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE + >>> emission_table() # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE El Kalpha Kbeta1 Ne 14.6102 14.4522 Na 11.9103 11.5752 @@ -524,4 +512,3 @@ def emission_table(table=...): # -> None: Si 7.1263 6.7531 ... """ - ...