From 885fe90ec6a355bed5fa838be952b7ce85282293 Mon Sep 17 00:00:00 2001 From: "Eric G. Kratz" Date: Tue, 5 Mar 2024 11:33:46 -0500 Subject: [PATCH] Rename have_optional_dependency (#3866) * Rename have_optional_dependency * Change log * Fix import * style: pre-commit fixes * Update pybamm/util.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Arjun Verma --- CHANGELOG.md | 1 + CONTRIBUTING.md | 8 ++++---- pybamm/__init__.py | 2 +- pybamm/citations.py | 12 +++++------ pybamm/expression_tree/functions.py | 4 ++-- pybamm/expression_tree/symbol.py | 10 +++++----- pybamm/expression_tree/unary_operators.py | 8 +++++--- pybamm/meshes/scikit_fem_submeshes.py | 4 ++-- pybamm/models/base_model.py | 4 +--- pybamm/plotting/plot.py | 4 ++-- pybamm/plotting/plot2D.py | 4 ++-- pybamm/plotting/plot_summary_variables.py | 4 ++-- pybamm/plotting/plot_voltage_components.py | 4 ++-- pybamm/plotting/quick_plot.py | 20 +++++++++---------- pybamm/simulation.py | 4 ++-- .../spatial_methods/scikit_finite_element.py | 16 +++++++-------- pybamm/util.py | 9 +++------ tests/unit/test_util.py | 4 ++-- 18 files changed, 60 insertions(+), 62 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 069093015a..41856b3fd7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ ## Breaking changes +- Renamed "have_optional_dependency" to "import_optional_dependency" ([#3866](https://github.com/pybamm-team/PyBaMM/pull/3866)) - Integrated the `[latexify]` extra into the core PyBaMM package, deprecating the `pybamm[latexify]` set of optional dependencies. SymPy is now a required dependency and will be installed upon installing PyBaMM ([#3848](https://github.com/pybamm-team/PyBaMM/pull/3848)) - Renamed "testing" argument for plots to "show_plot" and flipped its meaning (show_plot=True is now the default and shows the plot) ([#3842](https://github.com/pybamm-team/PyBaMM/pull/3842)) - Dropped support for BPX version 0.3.0 and below ([#3414](https://github.com/pybamm-team/PyBaMM/pull/3414)) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b9800dcd61..fc8e848bb5 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -104,13 +104,13 @@ Only 'core pybamm' is installed by default. The others have to be specified expl PyBaMM utilizes optional dependencies to allow users to choose which additional libraries they want to use. Managing these optional dependencies and their imports is essential to provide flexibility to PyBaMM users. -PyBaMM provides a utility function `have_optional_dependency`, to check for the availability of optional dependencies within methods. This function can be used to conditionally import optional dependencies only if they are available. Here's how to use it: +PyBaMM provides a utility function `import_optional_dependency`, to check for the availability of optional dependencies within methods. This function can be used to conditionally import optional dependencies only if they are available. Here's how to use it: Optional dependencies should never be imported at the module level, but always inside methods. For example: ``` def use_pybtex(x,y,z): - pybtex = have_optional_dependency("pybtex") + pybtex = import_optional_dependency("pybtex") ... ``` @@ -118,7 +118,7 @@ While importing a specific module instead of an entire package/library: ```python def use_parse_file(x, y, z): - parse_file = have_optional_dependency("pybtex.database", "parse_file") + parse_file = import_optional_dependency("pybtex.database", "parse_file") ... ``` @@ -143,7 +143,7 @@ class TestUtil(TestCase): pybamm.function_using_pybtex(x, y, z) # Test that the function works when pybtex is available - sys.modules["pybtex"] = pybamm.util.have_optional_dependency("pybtex") + sys.modules["pybtex"] = pybamm.util.import_optional_dependency("pybtex") pybamm.function_using_pybtex(x, y, z) ``` diff --git a/pybamm/__init__.py b/pybamm/__init__.py index ab2e72ed28..c2654ea9cf 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -47,7 +47,7 @@ get_parameters_filepath, have_jax, install_jax, - have_optional_dependency, + import_optional_dependency, is_jax_compatible, get_git_commit_info, ) diff --git a/pybamm/citations.py b/pybamm/citations.py index 9e649048e7..16b86419d6 100644 --- a/pybamm/citations.py +++ b/pybamm/citations.py @@ -7,7 +7,7 @@ import os import warnings from sys import _getframe -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency class Citations: @@ -74,7 +74,7 @@ def read_citations(self): """Reads the citations in `pybamm.CITATIONS.bib`. Other works can be cited by passing a BibTeX citation to :meth:`register`. """ - parse_file = have_optional_dependency("pybtex.database", "parse_file") + parse_file = import_optional_dependency("pybtex.database", "parse_file") citations_file = os.path.join(pybamm.root_dir(), "pybamm", "CITATIONS.bib") bib_data = parse_file(citations_file, bib_format="bibtex") for key, entry in bib_data.entries.items(): @@ -85,7 +85,7 @@ def _add_citation(self, key, entry): previous entry is overwritten """ - Entry = have_optional_dependency("pybtex.database", "Entry") + Entry = import_optional_dependency("pybtex.database", "Entry") # Check input types are correct if not isinstance(key, str) or not isinstance(entry, Entry): raise TypeError() @@ -151,8 +151,8 @@ def _parse_citation(self, key): key: str A BibTeX formatted citation """ - PybtexError = have_optional_dependency("pybtex.scanner", "PybtexError") - parse_string = have_optional_dependency("pybtex.database", "parse_string") + PybtexError = import_optional_dependency("pybtex.scanner", "PybtexError") + parse_string = import_optional_dependency("pybtex.database", "parse_string") try: # Parse string as a bibtex citation, and check that a citation was found bib_data = parse_string(key, bib_format="bibtex") @@ -219,7 +219,7 @@ def print(self, filename=None, output_format="text", verbose=False): """ # Parse citations that were not known keys at registration, but do not # fail if they cannot be parsed - pybtex = have_optional_dependency("pybtex") + pybtex = import_optional_dependency("pybtex") try: for key in self._unknown_citations: self._parse_citation(key) diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 7f0986441f..7f43f0b7c2 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -10,7 +10,7 @@ from typing_extensions import TypeVar import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency class Function(pybamm.Symbol): @@ -98,7 +98,7 @@ def _function_diff(self, children: Sequence[pybamm.Symbol], idx: float): Derivative with respect to child number 'idx'. See :meth:`pybamm.Symbol._diff()`. """ - autograd = have_optional_dependency("autograd") + autograd = import_optional_dependency("autograd") # Store differentiated function, needed in case we want to convert to CasADi if self.derivative == "autograd": return Function( diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 1d0d880ce6..7dc6acdf23 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Sequence, cast import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency from pybamm.expression_tree.printing.print_name import prettify_print_name if TYPE_CHECKING: # pragma: no cover @@ -479,7 +479,7 @@ def render(self): # pragma: no cover """ Print out a visual representation of the tree (this node and its children) """ - anytree = have_optional_dependency("anytree") + anytree = import_optional_dependency("anytree") for pre, _, node in anytree.RenderTree(self): if isinstance(node, pybamm.Scalar) and node.name != str(node.value): print(f"{pre}{node.name} = {node.value}") @@ -498,7 +498,7 @@ def visualise(self, filename: str): filename to output, must end in ".png" """ - DotExporter = have_optional_dependency("anytree.exporter", "DotExporter") + DotExporter = import_optional_dependency("anytree.exporter", "DotExporter") # check that filename ends in .png. if filename[-4:] != ".png": raise ValueError("filename should end in .png") @@ -518,7 +518,7 @@ def relabel_tree(self, symbol: Symbol, counter: int): Finds all children of a symbol and assigns them a new id so that they can be visualised properly using the graphviz output """ - anytree = have_optional_dependency("anytree") + anytree = import_optional_dependency("anytree") name = symbol.name if name == "div": name = "∇⋅" @@ -561,7 +561,7 @@ def pre_order(self): a b """ - anytree = have_optional_dependency("anytree") + anytree = import_optional_dependency("anytree") return anytree.PreOrderIter(self) def __str__(self): diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index cb5bc905fd..9669c2596a 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -7,7 +7,7 @@ from scipy.sparse import csr_matrix, issparse import sympy import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency from pybamm.type_definitions import DomainsType @@ -450,7 +450,9 @@ def _unary_new_copy(self, child): def _sympy_operator(self, child): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" - sympy_Gradient = have_optional_dependency("sympy.vector.operators", "Gradient") + sympy_Gradient = import_optional_dependency( + "sympy.vector.operators", "Gradient" + ) return sympy_Gradient(child) @@ -484,7 +486,7 @@ def _unary_new_copy(self, child): def _sympy_operator(self, child): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" - sympy_Divergence = have_optional_dependency( + sympy_Divergence = import_optional_dependency( "sympy.vector.operators", "Divergence" ) return sympy_Divergence(child) diff --git a/pybamm/meshes/scikit_fem_submeshes.py b/pybamm/meshes/scikit_fem_submeshes.py index 82a7bd72f1..e52f58f069 100644 --- a/pybamm/meshes/scikit_fem_submeshes.py +++ b/pybamm/meshes/scikit_fem_submeshes.py @@ -5,7 +5,7 @@ from .meshes import SubMesh import numpy as np -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency class ScikitSubMesh2D(SubMesh): @@ -27,7 +27,7 @@ class ScikitSubMesh2D(SubMesh): """ def __init__(self, edges, coord_sys, tabs): - skfem = have_optional_dependency("skfem") + skfem = import_optional_dependency("skfem") self.edges = edges self.nodes = dict.fromkeys(["y", "z"]) for var in self.nodes.keys(): diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index 3a27249083..6da534c783 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -13,7 +13,6 @@ import pybamm from pybamm.expression_tree.operations.serialise import Serialise -import sympy class BaseModel: @@ -1185,8 +1184,7 @@ def latexify(self, filename=None, newline=True, output_variables=None): This will return first five model equations >>> model.latexify(newline=False)[1:5] """ - if sympy: - from pybamm.expression_tree.operations.latexify import Latexify + from pybamm.expression_tree.operations.latexify import Latexify return Latexify(self, filename, newline).latexify( output_variables=output_variables diff --git a/pybamm/plotting/plot.py b/pybamm/plotting/plot.py index cf5c972a87..4037ab8fbf 100644 --- a/pybamm/plotting/plot.py +++ b/pybamm/plotting/plot.py @@ -3,7 +3,7 @@ # import pybamm from .quick_plot import ax_min, ax_max -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency def plot(x, y, ax=None, show_plot=True, **kwargs): @@ -27,7 +27,7 @@ def plot(x, y, ax=None, show_plot=True, **kwargs): Keyword arguments, passed to plt.plot """ - plt = have_optional_dependency("matplotlib.pyplot") + plt = import_optional_dependency("matplotlib.pyplot") if not isinstance(x, pybamm.Array): raise TypeError("x must be 'pybamm.Array'") diff --git a/pybamm/plotting/plot2D.py b/pybamm/plotting/plot2D.py index a37cd1e2ed..7d1f3c6bae 100644 --- a/pybamm/plotting/plot2D.py +++ b/pybamm/plotting/plot2D.py @@ -3,7 +3,7 @@ # import pybamm from .quick_plot import ax_min, ax_max -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency def plot2D(x, y, z, ax=None, show_plot=True, **kwargs): @@ -27,7 +27,7 @@ def plot2D(x, y, z, ax=None, show_plot=True, **kwargs): only display the plot after plt.show() has been called. """ - plt = have_optional_dependency("matplotlib.pyplot") + plt = import_optional_dependency("matplotlib.pyplot") if not isinstance(x, pybamm.Array): raise TypeError("x must be 'pybamm.Array'") diff --git a/pybamm/plotting/plot_summary_variables.py b/pybamm/plotting/plot_summary_variables.py index 33642c4d5a..bd4db0ee6c 100644 --- a/pybamm/plotting/plot_summary_variables.py +++ b/pybamm/plotting/plot_summary_variables.py @@ -3,7 +3,7 @@ # import numpy as np import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency def plot_summary_variables( @@ -27,7 +27,7 @@ def plot_summary_variables( Keyword arguments, passed to plt.subplots. """ - plt = have_optional_dependency("matplotlib.pyplot") + plt = import_optional_dependency("matplotlib.pyplot") if isinstance(solutions, pybamm.Solution): solutions = [solutions] diff --git a/pybamm/plotting/plot_voltage_components.py b/pybamm/plotting/plot_voltage_components.py index 3b155b71de..0d1bb7b573 100644 --- a/pybamm/plotting/plot_voltage_components.py +++ b/pybamm/plotting/plot_voltage_components.py @@ -3,7 +3,7 @@ # import numpy as np -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency from pybamm.simulation import Simulation from pybamm.solvers.solution import Solution @@ -42,7 +42,7 @@ def plot_voltage_components( solution = input_data.solution elif isinstance(input_data, Solution): solution = input_data - plt = have_optional_dependency("matplotlib.pyplot") + plt = import_optional_dependency("matplotlib.pyplot") # Set a default value for alpha, the opacity kwargs_fill = {"alpha": 0.6, **kwargs_fill} diff --git a/pybamm/plotting/quick_plot.py b/pybamm/plotting/quick_plot.py index e06f419ea4..7bbc3d2a86 100644 --- a/pybamm/plotting/quick_plot.py +++ b/pybamm/plotting/quick_plot.py @@ -5,7 +5,7 @@ import numpy as np import pybamm from collections import defaultdict -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency class LoopList(list): @@ -46,7 +46,7 @@ def split_long_string(title, max_words=None): def close_plots(): """Close all open figures""" - plt = have_optional_dependency("matplotlib.pyplot") + plt = import_optional_dependency("matplotlib.pyplot") plt.close("all") @@ -473,10 +473,10 @@ def plot(self, t, dynamic=False): Dimensional time (in 'time_units') at which to plot. """ - plt = have_optional_dependency("matplotlib.pyplot") - gridspec = have_optional_dependency("matplotlib.gridspec") - cm = have_optional_dependency("matplotlib", "cm") - colors = have_optional_dependency("matplotlib", "colors") + plt = import_optional_dependency("matplotlib.pyplot") + gridspec = import_optional_dependency("matplotlib.gridspec") + cm = import_optional_dependency("matplotlib", "cm") + colors = import_optional_dependency("matplotlib", "colors") t_in_seconds = t * self.time_scaling_factor self.fig = plt.figure(figsize=self.figsize) @@ -674,8 +674,8 @@ def dynamic_plot(self, show_plot=True, step=None): continuous_update=False, ) else: - plt = have_optional_dependency("matplotlib.pyplot") - Slider = have_optional_dependency("matplotlib.widgets", "Slider") + plt = import_optional_dependency("matplotlib.pyplot") + Slider = import_optional_dependency("matplotlib.widgets", "Slider") # create an initial plot at time self.min_t self.plot(self.min_t, dynamic=True) @@ -779,8 +779,8 @@ def create_gif(self, number_of_images=80, duration=0.1, output_filename="plot.gi Name of the generated GIF file. """ - imageio = have_optional_dependency("imageio.v2") - plt = have_optional_dependency("matplotlib.pyplot") + imageio = import_optional_dependency("imageio.v2") + plt = import_optional_dependency("matplotlib.pyplot") # time stamps at which the images/plots will be created time_array = np.linspace(self.min_t, self.max_t, num=number_of_images) diff --git a/pybamm/simulation.py b/pybamm/simulation.py index 6a02f7e2f8..c7b5efe983 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -11,7 +11,7 @@ import sys from functools import lru_cache from datetime import timedelta -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency from pybamm.expression_tree.operations.serialise import Serialise @@ -701,7 +701,7 @@ def solve( # check if a user has tqdm installed if showprogress: - tqdm = have_optional_dependency("tqdm") + tqdm = import_optional_dependency("tqdm") cycle_lengths = tqdm.tqdm( self.experiment.cycle_lengths, desc="Cycling", diff --git a/pybamm/spatial_methods/scikit_finite_element.py b/pybamm/spatial_methods/scikit_finite_element.py index 41957b10cc..e212ef71f7 100644 --- a/pybamm/spatial_methods/scikit_finite_element.py +++ b/pybamm/spatial_methods/scikit_finite_element.py @@ -7,7 +7,7 @@ from scipy.sparse.linalg import inv import numpy as np -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency class ScikitFiniteElement(pybamm.SpatialMethod): @@ -88,7 +88,7 @@ def gradient(self, symbol, discretised_symbol, boundary_conditions): to the y-component of the gradient and the second column corresponds to the z component of the gradient. """ - skfem = have_optional_dependency("skfem") + skfem = import_optional_dependency("skfem") domain = symbol.domain[0] mesh = self.mesh[domain] @@ -144,7 +144,7 @@ def gradient_matrix(self, symbol, boundary_conditions): :class:`pybamm.Matrix` The (sparse) finite element gradient matrix for the domain """ - skfem = have_optional_dependency("skfem") + skfem = import_optional_dependency("skfem") # get primary domain mesh domain = symbol.domain[0] mesh = self.mesh[domain] @@ -190,7 +190,7 @@ def laplacian(self, symbol, discretised_symbol, boundary_conditions): Contains the result of acting the discretised gradient on the child discretised_symbol """ - skfem = have_optional_dependency("skfem") + skfem = import_optional_dependency("skfem") domain = symbol.domain[0] mesh = self.mesh[domain] @@ -258,7 +258,7 @@ def stiffness_matrix(self, symbol, boundary_conditions): :class:`pybamm.Matrix` The (sparse) finite element stiffness matrix for the domain """ - skfem = have_optional_dependency("skfem") + skfem = import_optional_dependency("skfem") # get primary domain mesh domain = symbol.domain[0] mesh = self.mesh[domain] @@ -321,7 +321,7 @@ def definite_integral_matrix(self, child, vector_type="row"): :class:`pybamm.Matrix` The finite element integral vector for the domain """ - skfem = have_optional_dependency("skfem") + skfem = import_optional_dependency("skfem") # get primary domain mesh domain = child.domain[0] mesh = self.mesh[domain] @@ -383,7 +383,7 @@ def boundary_integral_vector(self, domain, region): :class:`pybamm.Matrix` The finite element integral vector for the domain """ - skfem = have_optional_dependency("skfem") + skfem = import_optional_dependency("skfem") # get primary domain mesh mesh = self.mesh[domain[0]] @@ -501,7 +501,7 @@ def assemble_mass_form(self, symbol, boundary_conditions, region="interior"): :class:`pybamm.Matrix` The (sparse) mass matrix for the spatial method. """ - skfem = have_optional_dependency("skfem") + skfem = import_optional_dependency("skfem") # get primary domain mesh domain = symbol.domain[0] mesh = self.mesh[domain] diff --git a/pybamm/util.py b/pybamm/util.py index b4972773c1..1d0c814365 100644 --- a/pybamm/util.py +++ b/pybamm/util.py @@ -356,19 +356,16 @@ def install_jax(arguments=None): # pragma: no cover # https://docs.pybamm.org/en/latest/source/user_guide/contributing.html#managing-optional-dependencies-and-their-imports -def have_optional_dependency(module_name, attribute=None): +def import_optional_dependency(module_name, attribute=None): err_msg = f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details." try: - # Attempt to import the specified module module = importlib.import_module(module_name) - if attribute: - # If an attribute is specified, check if it's available if hasattr(module, attribute): imported_attribute = getattr(module, attribute) - return imported_attribute # Return the imported attribute + # Return the imported attribute + return imported_attribute else: - # Raise an ModuleNotFoundError if the attribute is not available raise ModuleNotFoundError(err_msg) # pragma: no cover else: # Return the entire module if no attribute is specified diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index 24f204b6df..6fe43c096d 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -99,7 +99,7 @@ def test_git_commit_info(self): self.assertIsInstance(git_commit_info, str) self.assertEqual(git_commit_info[:2], "v2") - def test_have_optional_dependency(self): + def test_import_optional_dependency(self): with self.assertRaisesRegex( ModuleNotFoundError, "Optional dependency pybtex is not available." ): @@ -119,7 +119,7 @@ def test_have_optional_dependency(self): sym.visualise(test_name) sys.modules["pybtex"] = pybtex - pybamm.util.have_optional_dependency("pybtex") + pybamm.util.import_optional_dependency("pybtex") pybamm.print_citations()