diff --git a/pybamm/citations.py b/pybamm/citations.py index ca260c5cfd..da371bbd84 100644 --- a/pybamm/citations.py +++ b/pybamm/citations.py @@ -60,6 +60,7 @@ def _reset(self): self.register("Sulzer2021") self.register("Harris2020") + @staticmethod def _caller_name(): """ Returns the qualified name of classes that call :meth:`register` internally. diff --git a/pybamm/experiment/experiment.py b/pybamm/experiment/experiment.py index bc7db1da5a..a2a40981f3 100644 --- a/pybamm/experiment/experiment.py +++ b/pybamm/experiment/experiment.py @@ -39,7 +39,7 @@ class Experiment: def __init__( self, - operating_conditions: list[str], + operating_conditions: list[str | tuple[str]], period: str = "1 minute", temperature: float | None = None, termination: list[str] | None = None, diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index 7694cbc170..0bb6168a7c 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -1,11 +1,17 @@ # # NumpyArray class # +from __future__ import annotations import numpy as np from scipy.sparse import csr_matrix, issparse +from typing import TYPE_CHECKING import pybamm from pybamm.util import have_optional_dependency +from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType + +if TYPE_CHECKING: # pragma: no cover + import sympy class Array(pybamm.Symbol): @@ -36,13 +42,13 @@ class Array(pybamm.Symbol): def __init__( self, - entries, - name=None, - domain=None, - auxiliary_domains=None, - domains=None, - entries_string=None, - ): + entries: np.ndarray | list[float] | csr_matrix, + name: str | None = None, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, + entries_string: str | None = None, + ) -> None: # if if isinstance(entries, list): entries = np.array(entries) @@ -59,8 +65,6 @@ def __init__( @classmethod def _from_json(cls, snippet: dict): - instance = cls.__new__(cls) - if isinstance(snippet["entries"], dict): matrix = csr_matrix( ( @@ -73,14 +77,12 @@ def _from_json(cls, snippet: dict): else: matrix = snippet["entries"] - instance.__init__( + return cls( matrix, name=snippet["name"], domains=snippet["domains"], ) - return instance - @property def entries(self): return self._entries @@ -100,7 +102,7 @@ def entries_string(self): return self._entries_string @entries_string.setter - def entries_string(self, value): + def entries_string(self, value: None | tuple): # We must include the entries in the hash, since different arrays can be # indistinguishable by class, name and domain alone # Slightly different syntax for sparse and non-sparse matrices @@ -110,10 +112,10 @@ def entries_string(self, value): entries = self._entries if issparse(entries): dct = entries.__dict__ - self._entries_string = ["shape", str(dct["_shape"])] + entries_string = ["shape", str(dct["_shape"])] for key in ["data", "indices", "indptr"]: - self._entries_string += [key, dct[key].tobytes()] - self._entries_string = tuple(self._entries_string) + entries_string += [key, dct[key].tobytes()] + self._entries_string = tuple(entries_string) # self._entries_string = str(entries.__dict__) else: self._entries_string = (entries.tobytes(),) @@ -124,7 +126,7 @@ def set_id(self): (self.__class__, self.name, *self.entries_string, *tuple(self.domain)) ) - def _jac(self, variable): + def _jac(self, variable) -> pybamm.Matrix: """See :meth:`pybamm.Symbol._jac()`.""" # Return zeros of correct size jac = csr_matrix((self.size, variable.evaluation_array.count(True))) @@ -139,7 +141,13 @@ def create_copy(self): entries_string=self.entries_string, ) - def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def _base_evaluate( + self, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, + ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" return self._entries @@ -147,7 +155,7 @@ def is_constant(self): """See :meth:`pybamm.Symbol.is_constant()`.""" return True - def to_equation(self): + def to_equation(self) -> sympy.Array: """Returns the value returned by the node when evaluated.""" sympy = have_optional_dependency("sympy") entries_list = self.entries.tolist() @@ -178,7 +186,7 @@ def to_json(self): return json_dict -def linspace(start, stop, num=50, **kwargs): +def linspace(start: float, stop: float, num: int = 50, **kwargs) -> pybamm.Array: """ Creates a linearly spaced array by calling `numpy.linspace` with keyword arguments 'kwargs'. For a list of 'kwargs' see the @@ -187,7 +195,9 @@ def linspace(start, stop, num=50, **kwargs): return pybamm.Array(np.linspace(start, stop, num, **kwargs)) -def meshgrid(x, y, **kwargs): +def meshgrid( + x: pybamm.Array, y: pybamm.Array, **kwargs +) -> tuple[pybamm.Array, pybamm.Array]: """ Return coordinate matrices as from coordinate vectors by calling `numpy.meshgrid` with keyword arguments 'kwargs'. For a list of 'kwargs' diff --git a/pybamm/expression_tree/averages.py b/pybamm/expression_tree/averages.py index 81834d5871..d612906b71 100644 --- a/pybamm/expression_tree/averages.py +++ b/pybamm/expression_tree/averages.py @@ -1,6 +1,8 @@ # # Classes and methods for averaging # +from __future__ import annotations +from typing import Callable import pybamm @@ -14,13 +16,19 @@ class _BaseAverage(pybamm.Integral): The child node """ - def __init__(self, child, name, integration_variable): + def __init__( + self, + child: pybamm.Symbol, + name: str, + integration_variable: list[pybamm.IndependentVariable] + | pybamm.IndependentVariable, + ) -> None: super().__init__(child, integration_variable) self.name = name class XAverage(_BaseAverage): - def __init__(self, child): + def __init__(self, child: pybamm.Symbol) -> None: if all(n in child.domain[0] for n in ["negative", "particle"]): x = pybamm.standard_spatial_vars.x_n elif all(n in child.domain[0] for n in ["positive", "particle"]): @@ -30,56 +38,60 @@ def __init__(self, child): integration_variable = x super().__init__(child, "x-average", integration_variable) - def _unary_new_copy(self, child): + def _unary_new_copy(self, child: pybamm.Symbol): """See :meth:`UnaryOperator._unary_new_copy()`.""" return x_average(child) class YZAverage(_BaseAverage): - def __init__(self, child): + def __init__(self, child: pybamm.Symbol) -> None: y = pybamm.standard_spatial_vars.y z = pybamm.standard_spatial_vars.z - integration_variable = [y, z] + integration_variable: list[pybamm.IndependentVariable] = [y, z] super().__init__(child, "yz-average", integration_variable) - def _unary_new_copy(self, child): + def _unary_new_copy(self, child: pybamm.Symbol): """See :meth:`UnaryOperator._unary_new_copy()`.""" return yz_average(child) class ZAverage(_BaseAverage): - def __init__(self, child): - integration_variable = [pybamm.standard_spatial_vars.z] + def __init__(self, child: pybamm.Symbol) -> None: + integration_variable: list[pybamm.IndependentVariable] = [ + pybamm.standard_spatial_vars.z + ] super().__init__(child, "z-average", integration_variable) - def _unary_new_copy(self, child): + def _unary_new_copy(self, child: pybamm.Symbol): """See :meth:`UnaryOperator._unary_new_copy()`.""" return z_average(child) class RAverage(_BaseAverage): - def __init__(self, child): - integration_variable = [pybamm.SpatialVariable("r", child.domain)] + def __init__(self, child: pybamm.Symbol) -> None: + integration_variable: list[pybamm.IndependentVariable] = [ + pybamm.SpatialVariable("r", child.domain) + ] super().__init__(child, "r-average", integration_variable) - def _unary_new_copy(self, child): + def _unary_new_copy(self, child: pybamm.Symbol): """See :meth:`UnaryOperator._unary_new_copy()`.""" return r_average(child) class SizeAverage(_BaseAverage): - def __init__(self, child, f_a_dist): + def __init__(self, child: pybamm.Symbol, f_a_dist) -> None: R = pybamm.SpatialVariable("R", domains=child.domains, coord_sys="cartesian") - integration_variable = [R] + integration_variable: list[pybamm.IndependentVariable] = [R] super().__init__(child, "size-average", integration_variable) self.f_a_dist = f_a_dist - def _unary_new_copy(self, child): + def _unary_new_copy(self, child: pybamm.Symbol): """See :meth:`UnaryOperator._unary_new_copy()`.""" return size_average(child, f_a_dist=self.f_a_dist) -def x_average(symbol): +def x_average(symbol: pybamm.Symbol) -> pybamm.Symbol: """ Convenience function for creating an average in the x-direction. @@ -168,7 +180,7 @@ def x_average(symbol): return XAverage(symbol) -def z_average(symbol): +def z_average(symbol: pybamm.Symbol) -> pybamm.Symbol: """ Convenience function for creating an average in the z-direction. @@ -205,7 +217,7 @@ def z_average(symbol): return ZAverage(symbol) -def yz_average(symbol): +def yz_average(symbol: pybamm.Symbol) -> pybamm.Symbol: """ Convenience function for creating an average in the y-z-direction. @@ -239,11 +251,11 @@ def yz_average(symbol): return YZAverage(symbol) -def xyz_average(symbol): +def xyz_average(symbol: pybamm.Symbol) -> pybamm.Symbol: return yz_average(x_average(symbol)) -def r_average(symbol): +def r_average(symbol: pybamm.Symbol) -> pybamm.Symbol: """ Convenience function for creating an average in the r-direction. @@ -286,7 +298,9 @@ def r_average(symbol): return RAverage(symbol) -def size_average(symbol, f_a_dist=None): +def size_average( + symbol: pybamm.Symbol, f_a_dist: pybamm.Symbol | None = None +) -> pybamm.Symbol: """Convenience function for averaging over particle size R using the area-weighted particle-size distribution. @@ -339,7 +353,10 @@ def size_average(symbol, f_a_dist=None): return SizeAverage(symbol, f_a_dist) -def _sum_of_averages(symbol, average_function): +def _sum_of_averages( + symbol: pybamm.Addition | pybamm.Subtraction, + average_function: Callable[[pybamm.Symbol], pybamm.Symbol], +): if isinstance(symbol, pybamm.Addition): return average_function(symbol.left) + average_function(symbol.right) elif isinstance(symbol, pybamm.Subtraction): diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index 3d70741785..d10844798b 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -1,6 +1,7 @@ # # Binary operator classes # +from __future__ import annotations import numbers import numpy as np @@ -10,15 +11,22 @@ import pybamm from pybamm.util import have_optional_dependency +from typing import Callable, cast -def _preprocess_binary(left, right): - if isinstance(left, numbers.Number): +# create type alias(s) +from pybamm.type_definitions import ChildSymbol, ChildValue, Numeric + + +def _preprocess_binary( + left: ChildSymbol, right: ChildSymbol +) -> tuple[pybamm.Symbol, pybamm.Symbol]: + if isinstance(left, (float, int, np.number)): left = pybamm.Scalar(left) elif isinstance(left, np.ndarray): if left.ndim > 1: raise ValueError("left must be a 1D array") left = pybamm.Vector(left) - if isinstance(right, numbers.Number): + if isinstance(right, (float, int, np.number)): right = pybamm.Scalar(right) elif isinstance(right, np.ndarray): if right.ndim > 1: @@ -60,8 +68,10 @@ class BinaryOperator(pybamm.Symbol): rhs child node (converted to :class:`Scalar` if Number) """ - def __init__(self, name, left, right): - left, right = _preprocess_binary(left, right) + def __init__( + self, name: str, left_child: ChildSymbol, right_child: ChildSymbol + ) -> None: + left, right = _preprocess_binary(left_child, right_child) domains = self.get_children_domains([left, right]) super().__init__(name, children=[left, right], domains=domains) @@ -118,7 +128,7 @@ def create_copy(self): return out - def _binary_new_copy(self, left, right): + def _binary_new_copy(self, left: ChildSymbol, right: ChildSymbol): """ Default behaviour for new_copy. This copies the behaviour of `_binary_evaluate`, but since `left` and `right` @@ -126,7 +136,13 @@ def _binary_new_copy(self, left, right): """ return self._binary_evaluate(left, right) - def evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def evaluate( + self, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, + ): """See :meth:`pybamm.Symbol.evaluate()`.""" left = self.left.evaluate(t, y, y_dot, inputs) right = self.right.evaluate(t, y, y_dot, inputs) @@ -148,7 +164,7 @@ def _binary_evaluate(self, left, right): f"{self.__class__} does not implement _binary_evaluate." ) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return self.left.evaluates_on_edges(dimension) or self.right.evaluates_on_edges( dimension @@ -188,11 +204,15 @@ class Power(BinaryOperator): A node in the expression tree representing a `**` power operator. """ - def __init__(self, left, right): + def __init__( + self, + left: ChildSymbol, + right: ChildSymbol, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("**", left, right) - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply chain rule and power rule base, exponent = self.orphans @@ -217,7 +237,11 @@ def _binary_jac(self, left_jac, right_jac): right * left_jac + left * pybamm.log(left) * right_jac ) - def _binary_evaluate(self, left, right): + def _binary_evaluate( + self, + left: ChildSymbol, + right: ChildSymbol, + ): """See :meth:`pybamm.BinaryOperator._binary_evaluate()`.""" # don't raise RuntimeWarning for NaNs with np.errstate(invalid="ignore"): @@ -229,19 +253,23 @@ class Addition(BinaryOperator): A node in the expression tree representing an addition operator. """ - def __init__(self, left, right): + def __init__( + self, + left: ChildSymbol, + right: ChildSymbol, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("+", left, right) - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" return self.left.diff(variable) + self.right.diff(variable) - def _binary_jac(self, left_jac, right_jac): + def _binary_jac(self, left_jac: ChildValue, right_jac: ChildValue): """See :meth:`pybamm.BinaryOperator._binary_jac()`.""" return left_jac + right_jac - def _binary_evaluate(self, left, right): + def _binary_evaluate(self, left: ChildValue, right: ChildValue): """See :meth:`pybamm.BinaryOperator._binary_evaluate()`.""" return left + right @@ -251,12 +279,16 @@ class Subtraction(BinaryOperator): A node in the expression tree representing a subtraction operator. """ - def __init__(self, left, right): + def __init__( + self, + left: ChildSymbol, + right: ChildSymbol, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("-", left, right) - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" return self.left.diff(variable) - self.right.diff(variable) @@ -264,7 +296,7 @@ def _binary_jac(self, left_jac, right_jac): """See :meth:`pybamm.BinaryOperator._binary_jac()`.""" return left_jac - right_jac - def _binary_evaluate(self, left, right): + def _binary_evaluate(self, left: ChildValue, right: ChildValue): """See :meth:`pybamm.BinaryOperator._binary_evaluate()`.""" return left - right @@ -276,12 +308,16 @@ class Multiplication(BinaryOperator): matrix multiplication (e.g. scipy.sparse.coo.coo_matrix) """ - def __init__(self, left, right): + def __init__( + self, + left: ChildSymbol, + right: ChildSymbol, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("*", left, right) - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply product rule left, right = self.orphans @@ -313,7 +349,11 @@ class MatrixMultiplication(BinaryOperator): A node in the expression tree representing a matrix multiplication operator. """ - def __init__(self, left, right): + def __init__( + self, + left: ChildSymbol, + right: ChildSymbol, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("@", left, right) @@ -359,11 +399,15 @@ class Division(BinaryOperator): A node in the expression tree representing a division operator. """ - def __init__(self, left, right): + def __init__( + self, + left: ChildSymbol, + right: ChildSymbol, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("/", left, right) - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply quotient rule top, bottom = self.orphans @@ -403,11 +447,15 @@ class Inner(BinaryOperator): by a particular discretisation. """ - def __init__(self, left, right): + def __init__( + self, + left: ChildSymbol, + right: ChildSymbol, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("inner product", left, right) - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply product rule left, right = self.orphans @@ -435,18 +483,22 @@ def _binary_evaluate(self, left, right): else: return left * right - def _binary_new_copy(self, left, right): + def _binary_new_copy( + self, + left: ChildSymbol, + right: ChildSymbol, + ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.inner(left, right) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False -def inner(left, right): +def inner(left_child, right_child): """Return inner product of two symbols.""" - left, right = _preprocess_binary(left, right) + left, right = _preprocess_binary(left_child, right_child) # simplify multiply by scalar zero, being careful about shape if pybamm.is_scalar_zero(left): return pybamm.zeros_like(right) @@ -472,7 +524,11 @@ class Equality(BinaryOperator): nodes. Returns 1 if the two nodes evaluate to the same thing and 0 otherwise. """ - def __init__(self, left, right): + def __init__( + self, + left: ChildSymbol, + right: ChildSymbol, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("==", left, right) @@ -496,7 +552,11 @@ def _binary_evaluate(self, left, right): else: return int(left == right) - def _binary_new_copy(self, left, right): + def _binary_new_copy( + self, + left: ChildSymbol, + right: ChildSymbol, + ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.Equality(left, right) @@ -518,7 +578,12 @@ class _Heaviside(BinaryOperator): DISCONTINUITY event will automatically be added by the solver. """ - def __init__(self, name, left, right): + def __init__( + self, + name: str, + left: ChildSymbol, + right: ChildSymbol, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__(name, left, right) @@ -549,7 +614,11 @@ def _evaluate_for_shape(self): class EqualHeaviside(_Heaviside): """A heaviside function with equality (return 1 when left = right)""" - def __init__(self, left, right): + def __init__( + self, + left: ChildSymbol, + right: ChildSymbol, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("<=", left, right) @@ -567,7 +636,11 @@ def _binary_evaluate(self, left, right): class NotEqualHeaviside(_Heaviside): """A heaviside function without equality (return 0 when left = right)""" - def __init__(self, left, right): + def __init__( + self, + left: ChildSymbol, + right: ChildSymbol, + ): super().__init__("<", left, right) def __str__(self): @@ -584,10 +657,14 @@ def _binary_evaluate(self, left, right): class Modulo(BinaryOperator): """Calculates the remainder of an integer division.""" - def __init__(self, left, right): + def __init__( + self, + left: ChildSymbol, + right: ChildSymbol, + ): super().__init__("%", left, right) - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply chain rule and power rule left, right = self.orphans @@ -622,14 +699,18 @@ def _binary_evaluate(self, left, right): class Minimum(BinaryOperator): """Returns the smaller of two objects.""" - def __init__(self, left, right): + def __init__( + self, + left: ChildSymbol, + right: ChildSymbol, + ): super().__init__("minimum", left, right) def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return f"minimum({self.left!s}, {self.right!s})" - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" left, right = self.orphans return (left <= right) * left.diff(variable) + (left > right) * right.diff( @@ -646,7 +727,11 @@ def _binary_evaluate(self, left, right): # don't raise RuntimeWarning for NaNs return np.minimum(left, right) - def _binary_new_copy(self, left, right): + def _binary_new_copy( + self, + left: ChildSymbol, + right: ChildSymbol, + ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.minimum(left, right) @@ -659,14 +744,18 @@ def _sympy_operator(self, left, right): class Maximum(BinaryOperator): """Returns the greater of two objects.""" - def __init__(self, left, right): + def __init__( + self, + left: ChildSymbol, + right: ChildSymbol, + ): super().__init__("maximum", left, right) def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return f"maximum({self.left!s}, {self.right!s})" - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" left, right = self.orphans return (left >= right) * left.diff(variable) + (left < right) * right.diff( @@ -683,7 +772,11 @@ def _binary_evaluate(self, left, right): # don't raise RuntimeWarning for NaNs return np.maximum(left, right) - def _binary_new_copy(self, left, right): + def _binary_new_copy( + self, + left: ChildSymbol, + right: ChildSymbol, + ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.maximum(left, right) @@ -693,10 +786,13 @@ def _sympy_operator(self, left, right): return sympy.Max(left, right) -def _simplify_elementwise_binary_broadcasts(left, right): - left, right = _preprocess_binary(left, right) +def _simplify_elementwise_binary_broadcasts( + left_child: ChildSymbol, + right_child: ChildSymbol, +) -> tuple[pybamm.Symbol, pybamm.Symbol]: + left, right = _preprocess_binary(left_child, right_child) - def unpack_broadcast_recursive(symbol): + def unpack_broadcast_recursive(symbol: pybamm.Symbol) -> pybamm.Symbol: if isinstance(symbol, pybamm.Broadcast): if symbol.child.domain == []: return symbol.orphans[0] @@ -721,7 +817,11 @@ def unpack_broadcast_recursive(symbol): return left, right -def _simplified_binary_broadcast_concatenation(left, right, operator): +def _simplified_binary_broadcast_concatenation( + left: pybamm.Symbol, + right: pybamm.Symbol, + operator: Callable, +) -> pybamm.Broadcast | None: """ Check if there are concatenations or broadcasts that we can commute the operator with @@ -759,9 +859,13 @@ def _simplified_binary_broadcast_concatenation(left, right, operator): return right._concatenation_new_copy( [operator(left, child) for child in right.orphans] ) + return None -def simplified_power(left, right): +def simplified_power( + left: ChildSymbol, + right: ChildSymbol, +): left, right = _simplify_elementwise_binary_broadcasts(left, right) # Check for Concatenations and Broadcasts @@ -803,7 +907,7 @@ def simplified_power(left, right): return pybamm.simplify_if_constant(pybamm.Power(left, right)) -def add(left, right): +def add(left: ChildSymbol, right: ChildSymbol): """ Note ---- @@ -891,7 +995,10 @@ def add(left, right): return pybamm.simplify_if_constant(Addition(left, right)) -def subtract(left, right): +def subtract( + left: ChildSymbol, + right: ChildSymbol, +): """ Note ---- @@ -973,7 +1080,10 @@ def subtract(left, right): return pybamm.simplify_if_constant(Subtraction(left, right)) -def multiply(left, right): +def multiply( + left: ChildSymbol, + right: ChildSymbol, +): left, right = _simplify_elementwise_binary_broadcasts(left, right) # Move constant to always be on the left @@ -1098,7 +1208,10 @@ def multiply(left, right): return Multiplication(left, right) -def divide(left, right): +def divide( + left: ChildSymbol, + right: ChildSymbol, +): left, right = _simplify_elementwise_binary_broadcasts(left, right) # anything divided by zero raises error @@ -1169,8 +1282,11 @@ def divide(left, right): return pybamm.simplify_if_constant(Division(left, right)) -def matmul(left, right): - left, right = _preprocess_binary(left, right) +def matmul( + left_child: ChildSymbol, + right_child: ChildSymbol, +): + left, right = _preprocess_binary(left_child, right_child) if pybamm.is_matrix_zero(left) or pybamm.is_matrix_zero(right): return pybamm.zeros_like(MatrixMultiplication(left, right)) @@ -1228,16 +1344,19 @@ def matmul(left, right): return pybamm.simplify_if_constant(MatrixMultiplication(left, right)) -def minimum(left, right): +def minimum( + left: ChildSymbol, + right: ChildSymbol, +) -> pybamm.Symbol: """ Returns the smaller of two objects, possibly with a smoothing approximation. Not to be confused with :meth:`pybamm.min`, which returns min function of child. """ # Check for Concatenations and Broadcasts left, right = _simplify_elementwise_binary_broadcasts(left, right) - out = _simplified_binary_broadcast_concatenation(left, right, minimum) - if out is not None: - return out + concat_out = _simplified_binary_broadcast_concatenation(left, right, minimum) + if concat_out is not None: + return concat_out mode = pybamm.settings.min_max_mode k = pybamm.settings.min_max_smoothing @@ -1252,16 +1371,19 @@ def minimum(left, right): return pybamm.simplify_if_constant(out) -def maximum(left, right): +def maximum( + left: ChildSymbol, + right: ChildSymbol, +): """ Returns the larger of two objects, possibly with a smoothing approximation. Not to be confused with :meth:`pybamm.max`, which returns max function of child. """ # Check for Concatenations and Broadcasts left, right = _simplify_elementwise_binary_broadcasts(left, right) - out = _simplified_binary_broadcast_concatenation(left, right, maximum) - if out is not None: - return out + concat_out = _simplified_binary_broadcast_concatenation(left, right, maximum) + if concat_out is not None: + return concat_out mode = pybamm.settings.min_max_mode k = pybamm.settings.min_max_smoothing @@ -1276,15 +1398,15 @@ def maximum(left, right): return pybamm.simplify_if_constant(out) -def _heaviside(left, right, equal): +def _heaviside(left: ChildSymbol, right: ChildSymbol, equal): """return a :class:`EqualHeaviside` object, or a smooth approximation.""" # Check for Concatenations and Broadcasts left, right = _simplify_elementwise_binary_broadcasts(left, right) - out = _simplified_binary_broadcast_concatenation( + concat_out = _simplified_binary_broadcast_concatenation( left, right, functools.partial(_heaviside, equal=equal) ) - if out is not None: - return out + if concat_out is not None: + return concat_out if ( left.is_constant() @@ -1307,15 +1429,19 @@ def _heaviside(left, right, equal): # (i.e. no need for smoothing) if k == "exact" or (left.is_constant() and right.is_constant()): if equal is True: - out = pybamm.EqualHeaviside(left, right) + out: pybamm.EqualHeaviside = pybamm.EqualHeaviside(left, right) else: - out = pybamm.NotEqualHeaviside(left, right) + out: pybamm.NotEqualHeaviside = pybamm.NotEqualHeaviside(left, right) # type: ignore[no-redef] else: out = pybamm.sigmoid(left, right, k) return pybamm.simplify_if_constant(out) -def softminus(left, right, k): +def softminus( + left: pybamm.Symbol, + right: pybamm.Symbol, + k: float, +): """ Softminus approximation to the minimum function. k is the smoothing parameter, set by `pybamm.settings.min_max_smoothing`. The recommended value is k=10. @@ -1323,7 +1449,11 @@ def softminus(left, right, k): return pybamm.log(pybamm.exp(-k * left) + pybamm.exp(-k * right)) / -k -def softplus(left, right, k): +def softplus( + left: pybamm.Symbol, + right: pybamm.Symbol, + k: float, +): """ Softplus approximation to the maximum function. k is the smoothing parameter, set by `pybamm.settings.min_max_smoothing`. The recommended value is k=10. @@ -1349,7 +1479,11 @@ def smooth_max(left, right, k): return (pybamm.sqrt((left - right) ** 2 + sigma) + (left + right)) / 2 -def sigmoid(left, right, k): +def sigmoid( + left: pybamm.Symbol, + right: pybamm.Symbol, + k: float, +): """ Sigmoidal approximation to the heaviside function. k is the smoothing parameter, set by `pybamm.settings.heaviside_smoothing`. The recommended value is k=10. @@ -1359,7 +1493,11 @@ def sigmoid(left, right, k): return (1 + pybamm.tanh(k * (right - left))) / 2 -def source(left, right, boundary=False): +def source( + left: Numeric | pybamm.Symbol, + right: pybamm.Symbol, + boundary=False, +): """ A convenience function for creating (part of) an expression tree representing a source term. This is necessary for spatial methods where the mass matrix @@ -1375,7 +1513,7 @@ def source(left, right, boundary=False): Parameters ---------- - left : :class:`Symbol` + left : :class:`Symbol`, numeric The left child node, which represents the expression for the source term. right : :class:`Symbol` The right child node. This is the symbol whose boundary conditions are @@ -1390,6 +1528,9 @@ def source(left, right, boundary=False): if isinstance(left, numbers.Number): left = pybamm.PrimaryBroadcast(left, "current collector") + # force type cast for mypy + left = cast(pybamm.Symbol, left) + if left.domain != ["current collector"] or right.domain != ["current collector"]: raise pybamm.DomainError( f"""'source' only implemented in the 'current collector' domain, diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index d117341710..e17f375102 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -1,12 +1,19 @@ # # Unary operator classes and methods # -import numbers +from __future__ import annotations import numpy as np from scipy.sparse import csr_matrix +from typing import cast import pybamm +from pybamm.type_definitions import ( + DomainType, + AuxiliaryDomainType, + DomainsType, + Numeric, +) class Broadcast(pybamm.SpatialOperator): @@ -29,7 +36,12 @@ class Broadcast(pybamm.SpatialOperator): name of the node """ - def __init__(self, child, domains, name=None): + def __init__( + self, + child: pybamm.Symbol, + domains: dict[str, list[str] | str], + name: str | None = None, + ): if name is None: name = "broadcast" super().__init__(name, child, domains=domains) @@ -41,7 +53,7 @@ def broadcasts_to_nodes(self): else: return False - def _sympy_operator(self, child): + def _sympy_operator(self, child: pybamm.Symbol): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" return child @@ -50,6 +62,10 @@ def _diff(self, variable): # Differentiate the child and broadcast the result in the same way return self._unary_new_copy(self.child.diff(variable)) + def reduce_one_dimension(self): # pragma: no cover + """Reduce the broadcast by one dimension.""" + raise NotImplementedError + def to_json(self): raise NotImplementedError( "pybamm.Broadcast: Serialisation is only implemented for discretised models" @@ -73,7 +89,7 @@ class PrimaryBroadcast(Broadcast): Parameters ---------- - child : :class:`Symbol` + child : :class:`Symbol`, numeric child node broadcast_domain : iterable of str Primary domain for broadcast. This will become the domain of the symbol @@ -81,10 +97,19 @@ class PrimaryBroadcast(Broadcast): name of the node """ - def __init__(self, child, broadcast_domain, name=None): + def __init__( + self, + child: Numeric | pybamm.Symbol, + broadcast_domain: list[str] | str, + name: str | None = None, + ): # Convert child to scalar if it is a number - if isinstance(child, numbers.Number): + if isinstance(child, (float, int, np.number)): child = pybamm.Scalar(child) + + # cast child to Symbol for mypy + child = cast(pybamm.Symbol, child) + # Convert domain to list if it's a string if isinstance(broadcast_domain, str): broadcast_domain = [broadcast_domain] @@ -94,7 +119,7 @@ def __init__(self, child, broadcast_domain, name=None): self.broadcast_type = "primary to nodes" super().__init__(child, domains, name=name) - def check_and_set_domains(self, child, broadcast_domain): + def check_and_set_domains(self, child: pybamm.Symbol, broadcast_domain: list[str]): """See :meth:`Broadcast.check_and_set_domains`""" # Can only do primary broadcast from current collector to electrode, # particle-size or particle or from electrode to particle-size or particle. @@ -149,7 +174,7 @@ def check_and_set_domains(self, child, broadcast_domain): return domains - def _unary_new_copy(self, child): + def _unary_new_copy(self, child: pybamm.Symbol): """See :meth:`pybamm.UnaryOperator._unary_new_copy()`.""" return self.__class__(child, self.broadcast_domain) @@ -170,7 +195,12 @@ def reduce_one_dimension(self): class PrimaryBroadcastToEdges(PrimaryBroadcast): """A primary broadcast onto the edges of the domain.""" - def __init__(self, child, broadcast_domain, name=None): + def __init__( + self, + child: Numeric | pybamm.Symbol, + broadcast_domain: list[str] | str, + name: str | None = None, + ): name = name or "broadcast to edges" super().__init__(child, broadcast_domain, name) self.broadcast_type = "primary to edges" @@ -201,7 +231,12 @@ class SecondaryBroadcast(Broadcast): name of the node """ - def __init__(self, child, broadcast_domain, name=None): + def __init__( + self, + child: pybamm.Symbol, + broadcast_domain: list[str] | str, + name: str | None = None, + ): # Convert domain to list if it's a string if isinstance(broadcast_domain, str): broadcast_domain = [broadcast_domain] @@ -211,7 +246,7 @@ def __init__(self, child, broadcast_domain, name=None): self.broadcast_type = "secondary to nodes" super().__init__(child, domains, name=name) - def check_and_set_domains(self, child, broadcast_domain): + def check_and_set_domains(self, child: pybamm.Symbol, broadcast_domain: list[str]): """See :meth:`Broadcast.check_and_set_domains`""" if child.domain == []: raise TypeError( @@ -273,7 +308,7 @@ def check_and_set_domains(self, child, broadcast_domain): return domains - def _unary_new_copy(self, child): + def _unary_new_copy(self, child: pybamm.Symbol): """See :meth:`pybamm.UnaryOperator._unary_new_copy()`.""" return SecondaryBroadcast(child, self.broadcast_domain) @@ -294,7 +329,12 @@ def reduce_one_dimension(self): class SecondaryBroadcastToEdges(SecondaryBroadcast): """A secondary broadcast onto the edges of a domain.""" - def __init__(self, child, broadcast_domain, name=None): + def __init__( + self, + child: pybamm.Symbol, + broadcast_domain: list[str] | str, + name: str | None = None, + ): name = name or "broadcast to edges" super().__init__(child, broadcast_domain, name) self.broadcast_type = "secondary to edges" @@ -325,7 +365,12 @@ class TertiaryBroadcast(Broadcast): name of the node """ - def __init__(self, child, broadcast_domain, name=None): + def __init__( + self, + child: pybamm.Symbol, + broadcast_domain: list[str] | str, + name: str | None = None, + ): # Convert domain to list if it's a string if isinstance(broadcast_domain, str): broadcast_domain = [broadcast_domain] @@ -335,7 +380,9 @@ def __init__(self, child, broadcast_domain, name=None): self.broadcast_type = "tertiary to nodes" super().__init__(child, domains, name=name) - def check_and_set_domains(self, child, broadcast_domain): + def check_and_set_domains( + self, child: pybamm.Symbol, broadcast_domain: list[str] | str + ): """See :meth:`Broadcast.check_and_set_domains`""" if child.domains["secondary"] == []: raise TypeError( @@ -382,7 +429,7 @@ def check_and_set_domains(self, child, broadcast_domain): return domains - def _unary_new_copy(self, child): + def _unary_new_copy(self, child: pybamm.Symbol): """See :meth:`pybamm.UnaryOperator._unary_new_copy()`.""" return self.__class__(child, self.broadcast_domain) @@ -403,7 +450,12 @@ def reduce_one_dimension(self): class TertiaryBroadcastToEdges(TertiaryBroadcast): """A tertiary broadcast onto the edges of a domain.""" - def __init__(self, child, broadcast_domain, name=None): + def __init__( + self, + child: pybamm.Symbol, + broadcast_domain: list[str] | str, + name: str | None = None, + ): name = name or "broadcast to edges" super().__init__(child, broadcast_domain, name) self.broadcast_type = "tertiary to edges" @@ -417,15 +469,17 @@ class FullBroadcast(Broadcast): def __init__( self, - child, - broadcast_domain=None, - auxiliary_domains=None, - broadcast_domains=None, - name=None, + child_input: Numeric | pybamm.Symbol, + broadcast_domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + broadcast_domains: DomainsType = None, + name: str | None = None, ): # Convert child to scalar if it is a number - if isinstance(child, numbers.Number): - child = pybamm.Scalar(child) + if isinstance(child_input, (float, int, np.number)): + child: pybamm.Scalar = pybamm.Scalar(child_input) + else: + child: pybamm.Symbol = child_input # type: ignore[no-redef] if isinstance(auxiliary_domains, str): auxiliary_domains = {"secondary": auxiliary_domains} @@ -438,7 +492,7 @@ def __init__( self.broadcast_type = "full to nodes" super().__init__(child, domains, name=name) - def check_and_set_domains(self, child, broadcast_domains): + def check_and_set_domains(self, child: pybamm.Symbol, broadcast_domains: dict): """See :meth:`Broadcast.check_and_set_domains`""" if broadcast_domains["primary"] == []: raise pybamm.DomainError( @@ -489,11 +543,11 @@ class FullBroadcastToEdges(FullBroadcast): def __init__( self, - child, - broadcast_domain=None, - auxiliary_domains=None, - broadcast_domains=None, - name=None, + child: Numeric | pybamm.Symbol, + broadcast_domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + broadcast_domains: DomainsType = None, + name: str | None = None, ): name = name or "broadcast to edges" super().__init__( @@ -520,7 +574,7 @@ def reduce_one_dimension(self): ) -def full_like(symbols, fill_value): +def full_like(symbols: tuple[pybamm.Symbol, ...], fill_value: float) -> pybamm.Symbol: """ Returns an array with the same shape and domains as the sum of the input symbols, with a constant value given by `fill_value`. @@ -543,18 +597,18 @@ def full_like(symbols, fill_value): return pybamm.Scalar(fill_value) try: shape = sum_symbol.shape - # use vector or matrix - if shape[1] == 1: - array_type = pybamm.Vector - else: - array_type = pybamm.Matrix + # return dense array, except for a matrix of zeros if shape[1] != 1 and fill_value == 0: entries = csr_matrix(shape) else: entries = fill_value * np.ones(shape) - return array_type(entries, domains=sum_symbol.domains) + # use vector or matrix + if shape[1] == 1: + return pybamm.Vector(entries, domains=sum_symbol.domains) + else: + return pybamm.Matrix(entries, domains=sum_symbol.domains) except NotImplementedError: if ( @@ -571,7 +625,7 @@ def full_like(symbols, fill_value): return FullBroadcast(fill_value, broadcast_domains=sum_symbol.domains) -def zeros_like(*symbols): +def zeros_like(*symbols: pybamm.Symbol): """ Returns an array with the same shape and domains as the sum of the input symbols, with each entry equal to zero. @@ -584,7 +638,7 @@ def zeros_like(*symbols): return full_like(symbols, 0) -def ones_like(*symbols): +def ones_like(*symbols: pybamm.Symbol): """ Returns an array with the same shape and domains as the sum of the input symbols, with each entry equal to one. diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 40cfe617ac..06718c311d 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -1,11 +1,13 @@ # # Concatenation classes # +from __future__ import annotations import copy from collections import defaultdict import numpy as np from scipy.sparse import issparse, vstack +from typing import Sequence import pybamm from pybamm.util import have_optional_dependency @@ -21,7 +23,13 @@ class Concatenation(pybamm.Symbol): The symbols to concatenate """ - def __init__(self, *children, name=None, check_domain=True, concat_fun=None): + def __init__( + self, + *children: pybamm.Symbol, + name: str | None = None, + check_domain=True, + concat_fun=None, + ): # The second condition checks whether this is the base Concatenation class # or a subclass of Concatenation # (ConcatenationVariable, NumpyConcatenation, ...) @@ -44,13 +52,15 @@ def __init__(self, *children, name=None, check_domain=True, concat_fun=None): super().__init__(name, children, domains=domains) @classmethod - def _from_json(cls, *children, name, domains, concat_fun=None): + def _from_json(cls, snippet: dict): """Creates a new Concatenation instance from a json object""" instance = cls.__new__(cls) - instance.concatenation_function = concat_fun + instance.concatenation_function = snippet["concat_fun"] - super(Concatenation, instance).__init__(name, children, domains=domains) + super(Concatenation, instance).__init__( + snippet["name"], tuple(snippet["children"]), domains=snippet["domains"] + ) return instance @@ -62,7 +72,7 @@ def __str__(self): out = out[:-2] + ")" return out - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" children_diffs = [child.diff(variable) for child in self.children] if len(children_diffs) == 1: @@ -72,9 +82,9 @@ def _diff(self, variable): return diff - def get_children_domains(self, children): + def get_children_domains(self, children: Sequence[pybamm.Symbol]): # combine domains from children - domain = [] + domain: list = [] for child in children: if not isinstance(child, pybamm.Symbol): raise TypeError(f"{child} is not a pybamm symbol") @@ -101,19 +111,22 @@ def get_children_domains(self, children): return domains - def _concatenation_evaluate(self, children_eval): + def _concatenation_evaluate(self, children_eval: list[np.ndarray]): """See :meth:`Concatenation._concatenation_evaluate()`.""" if len(children_eval) == 0: return np.array([]) else: return self.concatenation_function(children_eval) - def evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def evaluate( + self, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, + ): """See :meth:`pybamm.Symbol.evaluate()`.""" - children = self.children - children_eval = [None] * len(children) - for idx, child in enumerate(children): - children_eval[idx] = child.evaluate(t, y, y_dot, inputs) + children_eval = [child.evaluate(t, y, y_dot, inputs) for child in self.children] return self._concatenation_evaluate(children_eval) def create_copy(self): @@ -180,7 +193,7 @@ class NumpyConcatenation(Concatenation): The equations to concatenate """ - def __init__(self, *children): + def __init__(self, *children: pybamm.Symbol): children = list(children) # Turn objects that evaluate to scalars to objects that evaluate to vectors, # so that we can concatenate them @@ -197,12 +210,11 @@ def __init__(self, *children): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.Concatenation._from_json()`.""" - instance = super()._from_json( - *snippet["children"], - name="numpy_concatenation", - domains=snippet["domains"], - concat_fun=np.concatenate, - ) + + snippet["name"] = "numpy_concatenation" + snippet["concat_fun"] = np.concatenate + + instance = super()._from_json(snippet) return instance @@ -233,7 +245,7 @@ class DomainConcatenation(Concatenation): children : iterable of :class:`pybamm.Symbol` The symbols to concatenate - full_mesh : :class:`pybamm.BaseMesh` + full_mesh : :class:`pybamm.Mesh` The underlying mesh for discretisation, used to obtain the number of mesh points in each domain. @@ -242,7 +254,12 @@ class DomainConcatenation(Concatenation): from `copy_this`. `mesh` is not used in this case """ - def __init__(self, children, full_mesh, copy_this=None): + def __init__( + self, + children: Sequence[pybamm.Symbol], + full_mesh: pybamm.Mesh, + copy_this: pybamm.DomainConcatenation | None = None, + ): # Convert any constant symbols in children to a Vector of the right size for # concatenation children = list(children) @@ -277,11 +294,11 @@ def __init__(self, children, full_mesh, copy_this=None): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.Concatenation._from_json()`.""" - instance = super()._from_json( - *snippet["children"], - name="domain_concatenation", - domains=snippet["domains"], - ) + + snippet["name"] = "domain_concatenation" + snippet["concat_fun"] = None + + instance = super()._from_json(snippet) def repack_defaultDict(slices): slices = defaultdict(list, slices) @@ -299,7 +316,7 @@ def repack_defaultDict(slices): return instance - def _get_auxiliary_domain_repeats(self, auxiliary_domains): + def _get_auxiliary_domain_repeats(self, auxiliary_domains: dict) -> int: """Helper method to read the 'auxiliary_domain' meshes.""" mesh_pts = 1 for level, dom in auxiliary_domains.items(): @@ -311,7 +328,7 @@ def _get_auxiliary_domain_repeats(self, auxiliary_domains): def full_mesh(self): return self._full_mesh - def create_slices(self, node): + def create_slices(self, node: pybamm.Symbol) -> defaultdict: slices = defaultdict(list) start = 0 end = 0 @@ -328,7 +345,7 @@ def create_slices(self, node): start = end return slices - def _concatenation_evaluate(self, children_eval): + def _concatenation_evaluate(self, children_eval: list[np.ndarray]): """See :meth:`Concatenation._concatenation_evaluate()`.""" # preallocate vector vector = np.empty((self._size, 1)) @@ -357,7 +374,7 @@ def _concatenation_jac(self, children_jacs): jacs.append(pybamm.Index(child_jac, child_slice[i])) return SparseStack(*jacs) - def _concatenation_new_copy(self, children): + def _concatenation_new_copy(self, children: list[pybamm.Symbol]): """See :meth:`pybamm.Symbol.new_copy()`.""" new_symbol = simplified_domain_concatenation( children, self.full_mesh, copy_this=self @@ -464,13 +481,13 @@ def __init__(self, *children): self.print_name = print_name -def substrings(s): +def substrings(s: str): for i in range(len(s)): for j in range(i, len(s)): yield s[i : j + 1] -def intersect(s1, s2): +def intersect(s1: str, s2: str): # find all the common strings between two strings all_intersects = set(substrings(s1)) & set(substrings(s2)) # intersect is the longest such intercept @@ -536,24 +553,29 @@ def numpy_concatenation(*children): return simplified_numpy_concatenation(*children) -def simplified_domain_concatenation(children, mesh, copy_this=None): +def simplified_domain_concatenation( + children: list[pybamm.Symbol], + mesh: pybamm.Mesh, + copy_this: DomainConcatenation | None = None, +): """Perform simplifications on a domain concatenation.""" # Create the DomainConcatenation to read domain and child domain concat = DomainConcatenation(children, mesh, copy_this=copy_this) # Simplify Concatenation of StateVectors to a single StateVector # The sum of the evalation arrays of the StateVectors must be exactly 1 if all(isinstance(child, pybamm.StateVector) for child in children): - longest_eval_array = len(children[-1]._evaluation_array) + sv_children: list[pybamm.StateVector] = children # type: ignore[assignment] + longest_eval_array = len(sv_children[-1]._evaluation_array) eval_arrays = {} - for child in children: + for child in sv_children: eval_arrays[child] = np.concatenate( [ child.evaluation_array, np.zeros(longest_eval_array - len(child.evaluation_array)), ] ) - first_start = children[0].y_slices[0].start - last_stop = children[-1].y_slices[-1].stop + first_start = sv_children[0].y_slices[0].start + last_stop = sv_children[-1].y_slices[-1].stop if all( sum(array for array in eval_arrays.values())[first_start:last_stop] == 1 ): @@ -564,7 +586,7 @@ def simplified_domain_concatenation(children, mesh, copy_this=None): return pybamm.simplify_if_constant(concat) -def domain_concatenation(children, mesh): +def domain_concatenation(children: list[pybamm.Symbol], mesh: pybamm.Mesh): """Helper function to create domain concatenations.""" # TODO: add option to turn off simplifications return simplified_domain_concatenation(children, mesh) diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index f95f190b43..72c9d4074a 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -1,11 +1,12 @@ # # Function classes and methods # -import numbers +from __future__ import annotations import numpy as np from scipy import special -from typing import Callable +from typing import Sequence, Callable +from typing_extensions import TypeVar import pybamm from pybamm.util import have_optional_dependency @@ -32,16 +33,16 @@ class Function(pybamm.Symbol): def __init__( self, - function, - *children, - name=None, - derivative="autograd", - differentiated_function=None, + function: Callable, + *children: pybamm.Symbol, + name: str | None = None, + derivative: str | None = "autograd", + differentiated_function: Callable | None = None, ): # Turn numbers into scalars children = list(children) for idx, child in enumerate(children): - if isinstance(child, numbers.Number): + if isinstance(child, (float, int, np.number)): children[idx] = pybamm.Scalar(child) if name is not None: @@ -67,13 +68,13 @@ def __str__(self): out = out[:-2] + ")" return out - def diff(self, variable): + def diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol.diff()`.""" if variable == self: return pybamm.Scalar(1) else: children = self.orphans - partial_derivatives = [None] * len(children) + partial_derivatives: list[None | pybamm.Symbol] = [None] * len(children) for i, child in enumerate(self.children): # if variable appears in the function, differentiate # function, and apply chain rule @@ -87,11 +88,11 @@ def diff(self, variable): derivative = sum(partial_derivatives) if derivative == 0: - derivative = pybamm.Scalar(0) + return pybamm.Scalar(0) return derivative - def _function_diff(self, children, idx): + def _function_diff(self, children: Sequence[pybamm.Symbol], idx: float): """ Derivative with respect to child number 'idx'. See :meth:`pybamm.Symbol._diff()`. @@ -115,7 +116,7 @@ def _function_diff(self, children, idx): else: # keep using "derivative" as derivative return pybamm.Function( - self.function.derivative(), + self.function.derivative(), # type: ignore[attr-defined] *children, derivative="derivative", differentiated_function=self.function, @@ -142,14 +143,20 @@ def _function_jac(self, children_jacs): return jacobian - def evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def evaluate( + self, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, + ): """See :meth:`pybamm.Symbol.evaluate()`.""" evaluated_children = [ child.evaluate(t, y, y_dot, inputs) for child in self.children ] return self._function_evaluate(evaluated_children) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return any(child.evaluates_on_edges(dimension) for child in self.children) @@ -173,7 +180,7 @@ def create_copy(self): children_copy = [child.new_copy() for child in self.children] return self._function_new_copy(children_copy) - def _function_new_copy(self, children): + def _function_new_copy(self, children: list) -> Function: """ Returns a new copy of the function. @@ -225,22 +232,6 @@ def _from_json(cls, snippet): ) -def simplified_function(func_class, child): - """ - Simplifications implemented before applying the function. - Currently only implemented for one-child functions. - """ - if isinstance(child, pybamm.Broadcast): - # Move the function inside the broadcast - # Apply recursively - func_child_not_broad = pybamm.simplify_if_constant( - simplified_function(func_class, child.orphans[0]) - ) - return child._unary_new_copy(func_child_not_broad) - else: - return pybamm.simplify_if_constant(func_class(child)) - - class SpecificFunction(Function): """ Parent class for the specific functions, which implement their own `diff` @@ -254,11 +245,11 @@ class SpecificFunction(Function): The child to apply the function to """ - def __init__(self, function, child): + def __init__(self, function: Callable, child: pybamm.Symbol): super().__init__(function, child) @classmethod - def _from_json(cls, function: Callable, snippet: dict): + def _from_json(cls, snippet: dict): """ Reconstructs a SpecificFunction instance during deserialisation of a JSON file. @@ -272,7 +263,9 @@ def _from_json(cls, function: Callable, snippet: dict): instance = cls.__new__(cls) - super(SpecificFunction, instance).__init__(function, snippet["children"][0]) + super(SpecificFunction, instance).__init__( + snippet["function"], snippet["children"][0] + ) return instance @@ -301,6 +294,25 @@ def to_json(self): return json_dict +SF = TypeVar("SF", bound=SpecificFunction) + + +def simplified_function(func_class: type[SF], child: pybamm.Symbol): + """ + Simplifications implemented before applying the function. + Currently only implemented for one-child functions. + """ + if isinstance(child, pybamm.Broadcast): + # Move the function inside the broadcast + # Apply recursively + func_child_not_broad = pybamm.simplify_if_constant( + simplified_function(func_class, child.orphans[0]) + ) + return child._unary_new_copy(func_child_not_broad) + else: + return pybamm.simplify_if_constant(func_class(child)) # type: ignore[call-arg, arg-type] + + class Arcsinh(SpecificFunction): """Arcsinh function.""" @@ -310,7 +322,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.arcsinh, snippet) + snippet["function"] = np.arcsinh + instance = super()._from_json(snippet) return instance def _function_diff(self, children, idx): @@ -323,7 +336,7 @@ def _sympy_operator(self, child): return sympy.asinh(child) -def arcsinh(child): +def arcsinh(child: pybamm.Symbol): """Returns arcsinh function of child.""" return simplified_function(Arcsinh, child) @@ -337,7 +350,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.arctan, snippet) + snippet["function"] = np.arctan + instance = super()._from_json(snippet) return instance def _function_diff(self, children, idx): @@ -350,7 +364,7 @@ def _sympy_operator(self, child): return sympy.atan(child) -def arctan(child): +def arctan(child: pybamm.Symbol): """Returns hyperbolic tan function of child.""" return simplified_function(Arctan, child) @@ -364,7 +378,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.cos, snippet) + snippet["function"] = np.cos + instance = super()._from_json(snippet) return instance def _function_diff(self, children, idx): @@ -372,7 +387,7 @@ def _function_diff(self, children, idx): return -sin(children[0]) -def cos(child): +def cos(child: pybamm.Symbol): """Returns cosine function of child.""" return simplified_function(Cos, child) @@ -386,7 +401,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.cosh, snippet) + snippet["function"] = np.cosh + instance = super()._from_json(snippet) return instance def _function_diff(self, children, idx): @@ -394,7 +410,7 @@ def _function_diff(self, children, idx): return sinh(children[0]) -def cosh(child): +def cosh(child: pybamm.Symbol): """Returns hyperbolic cosine function of child.""" return simplified_function(Cosh, child) @@ -408,7 +424,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(special.erf, snippet) + snippet["function"] = special.erf + instance = super()._from_json(snippet) return instance def _function_diff(self, children, idx): @@ -416,12 +433,12 @@ def _function_diff(self, children, idx): return 2 / np.sqrt(np.pi) * exp(-(children[0] ** 2)) -def erf(child): +def erf(child: pybamm.Symbol): """Returns error function of child.""" return simplified_function(Erf, child) -def erfc(child): +def erfc(child: pybamm.Symbol): """Returns complementary error function of child.""" return 1 - simplified_function(Erf, child) @@ -435,7 +452,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.exp, snippet) + snippet["function"] = np.exp + instance = super()._from_json(snippet) return instance def _function_diff(self, children, idx): @@ -443,7 +461,7 @@ def _function_diff(self, children, idx): return exp(children[0]) -def exp(child): +def exp(child: pybamm.Symbol): """Returns exponential function of child.""" return simplified_function(Exp, child) @@ -457,7 +475,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.log, snippet) + snippet["function"] = np.log + instance = super()._from_json(snippet) return instance def _function_evaluate(self, evaluated_children): @@ -479,7 +498,7 @@ def log(child, base="e"): return log_child / np.log(base) -def log10(child): +def log10(child: pybamm.Symbol): """Returns logarithmic function of child, with base 10.""" return log(child, base=10) @@ -493,7 +512,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.max, snippet) + snippet["function"] = np.max + instance = super()._from_json(snippet) return instance def _evaluate_for_shape(self): @@ -502,7 +522,7 @@ def _evaluate_for_shape(self): return np.nan * np.ones((1, 1)) -def max(child): +def max(child: pybamm.Symbol): """ Returns max function of child. Not to be confused with :meth:`pybamm.maximum`, which returns the larger of two objects. @@ -519,7 +539,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.min, snippet) + snippet["function"] = np.min + instance = super()._from_json(snippet) return instance def _evaluate_for_shape(self): @@ -528,7 +549,7 @@ def _evaluate_for_shape(self): return np.nan * np.ones((1, 1)) -def min(child): +def min(child: pybamm.Symbol): """ Returns min function of child. Not to be confused with :meth:`pybamm.minimum`, which returns the smaller of two objects. @@ -536,7 +557,7 @@ def min(child): return pybamm.simplify_if_constant(Min(child)) -def sech(child): +def sech(child: pybamm.Symbol): """Returns hyperbolic sec function of child.""" return 1 / simplified_function(Cosh, child) @@ -550,7 +571,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.sin, snippet) + snippet["function"] = np.sin + instance = super()._from_json(snippet) return instance def _function_diff(self, children, idx): @@ -558,7 +580,7 @@ def _function_diff(self, children, idx): return cos(children[0]) -def sin(child): +def sin(child: pybamm.Symbol): """Returns sine function of child.""" return simplified_function(Sin, child) @@ -572,7 +594,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.sinh, snippet) + snippet["function"] = np.sinh + instance = super()._from_json(snippet) return instance def _function_diff(self, children, idx): @@ -580,7 +603,7 @@ def _function_diff(self, children, idx): return cosh(children[0]) -def sinh(child): +def sinh(child: pybamm.Symbol): """Returns hyperbolic sine function of child.""" return simplified_function(Sinh, child) @@ -594,7 +617,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.sqrt, snippet) + snippet["function"] = np.sqrt + instance = super()._from_json(snippet) return instance def _function_evaluate(self, evaluated_children): @@ -607,7 +631,7 @@ def _function_diff(self, children, idx): return 1 / (2 * sqrt(children[0])) -def sqrt(child): +def sqrt(child: pybamm.Symbol): """Returns square root function of child.""" return simplified_function(Sqrt, child) @@ -621,7 +645,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.tanh, snippet) + snippet["function"] = np.tanh + instance = super()._from_json(snippet) return instance def _function_diff(self, children, idx): @@ -629,6 +654,6 @@ def _function_diff(self, children, idx): return sech(children[0]) ** 2 -def tanh(child): +def tanh(child: pybamm.Symbol): """Returns hyperbolic tan function of child.""" return simplified_function(Tanh, child) diff --git a/pybamm/expression_tree/independent_variable.py b/pybamm/expression_tree/independent_variable.py index dccb627eed..0dca6dba46 100644 --- a/pybamm/expression_tree/independent_variable.py +++ b/pybamm/expression_tree/independent_variable.py @@ -1,8 +1,13 @@ # # IndependentVariable class # +from __future__ import annotations +import sympy +import numpy as np + import pybamm from pybamm.util import have_optional_dependency +from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType KNOWN_COORD_SYS = ["cartesian", "cylindrical polar", "spherical polar"] @@ -28,28 +33,30 @@ class IndependentVariable(pybamm.Symbol): deprecated. """ - def __init__(self, name, domain=None, auxiliary_domains=None, domains=None): + def __init__( + self, + name: str, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, + ) -> None: super().__init__( name, domain=domain, auxiliary_domains=auxiliary_domains, domains=domains ) @classmethod def _from_json(cls, snippet: dict): - instance = cls.__new__(cls) - - instance.__init__(snippet["name"], domains=snippet["domains"]) - - return instance + return cls(snippet["name"], domains=snippet["domains"]) def _evaluate_for_shape(self): """See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`""" return pybamm.evaluate_for_shape_using_domain(self.domains) - def _jac(self, variable): + def _jac(self, variable) -> pybamm.Scalar: """See :meth:`pybamm.Symbol._jac()`.""" return pybamm.Scalar(0) - def to_equation(self): + def to_equation(self) -> sympy.Symbol: """Convert the node and its subtree into a SymPy equation.""" sympy = have_optional_dependency("sympy") if self.print_name is not None: @@ -68,17 +75,19 @@ def __init__(self): @classmethod def _from_json(cls, snippet: dict): - instance = cls.__new__(cls) - - instance.__init__() - - return instance + return cls() def create_copy(self): """See :meth:`pybamm.Symbol.new_copy()`.""" return Time() - def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def _base_evaluate( + self, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, + ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" if t is None: raise ValueError("t must be provided") @@ -118,8 +127,13 @@ class SpatialVariable(IndependentVariable): """ def __init__( - self, name, domain=None, auxiliary_domains=None, domains=None, coord_sys=None - ): + self, + name: str, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, + coord_sys=None, + ) -> None: self.coord_sys = coord_sys super().__init__( name, domain=domain, auxiliary_domains=auxiliary_domains, domains=domains @@ -174,8 +188,13 @@ class SpatialVariableEdge(SpatialVariable): """ def __init__( - self, name, domain=None, auxiliary_domains=None, domains=None, coord_sys=None - ): + self, + name: str, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, + coord_sys=None, + ) -> None: super().__init__(name, domain, auxiliary_domains, domains, coord_sys) def _evaluates_on_edges(self, dimension): diff --git a/pybamm/expression_tree/input_parameter.py b/pybamm/expression_tree/input_parameter.py index 2680276c60..a1ee00a47b 100644 --- a/pybamm/expression_tree/input_parameter.py +++ b/pybamm/expression_tree/input_parameter.py @@ -1,11 +1,14 @@ # # Parameter classes # +from __future__ import annotations import numbers import numpy as np import scipy.sparse import pybamm +from pybamm.type_definitions import DomainType + class InputParameter(pybamm.Symbol): """ @@ -25,7 +28,12 @@ class InputParameter(pybamm.Symbol): The size of the input parameter expected, defaults to 1 (scalar input) """ - def __init__(self, name, domain=None, expected_size=None): + def __init__( + self, + name: str, + domain: DomainType = None, + expected_size: int | None = None, + ) -> None: # Expected size defaults to 1 if no domain else None (gets set later) if expected_size is None: if domain is None: @@ -37,17 +45,13 @@ def __init__(self, name, domain=None, expected_size=None): @classmethod def _from_json(cls, snippet: dict): - instance = cls.__new__(cls) - - instance.__init__( + return cls( snippet["name"], domain=snippet["domain"], expected_size=snippet["expected_size"], ) - return instance - - def create_copy(self): + def create_copy(self) -> pybamm.InputParameter: """See :meth:`pybamm.Symbol.new_copy()`.""" new_input_parameter = InputParameter( self.name, self.domain, expected_size=self._expected_size @@ -66,7 +70,7 @@ def _evaluate_for_shape(self): else: return np.nan * np.ones((self._expected_size, 1)) - def _jac(self, variable): + def _jac(self, variable: pybamm.StateVector) -> pybamm.Matrix: """See :meth:`pybamm.Symbol._jac()`.""" n_variable = variable.evaluation_array.count(True) nan_vector = self._evaluate_for_shape() @@ -77,7 +81,13 @@ def _jac(self, variable): zero_matrix = scipy.sparse.csr_matrix((n_self, n_variable)) return pybamm.Matrix(zero_matrix) - def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def _base_evaluate( + self, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, + ): # inputs should be a dictionary # convert 'None' to empty dictionary for more informative error if inputs is None: diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index f709c03a62..7d0b0c6eb2 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -1,8 +1,11 @@ # # Interpolating class # +from __future__ import annotations import numpy as np from scipy import interpolate +from typing import Sequence + import pybamm @@ -40,13 +43,13 @@ class Interpolant(pybamm.Function): def __init__( self, - x, - y, - children, - name=None, - interpolator="linear", - extrapolate=True, - entries_string=None, + x: np.ndarray | Sequence[np.ndarray], + y: np.ndarray, + children: Sequence[pybamm.Symbol] | pybamm.Time, + name: str | None = None, + interpolator: str | None = "linear", + extrapolate: bool = True, + entries_string: str | None = None, ): # Check interpolator is valid if interpolator not in ["linear", "cubic", "pchip"]: @@ -92,7 +95,7 @@ def __init__( x1 = x[0] else: x1 = x - x = [x] + x: list[np.ndarray] = [x] # type: ignore[no-redef] x2 = None if x1.shape[0] != y.shape[0]: raise ValueError( @@ -118,14 +121,14 @@ def __init__( self.dimension = 1 if interpolator == "linear": if extrapolate is False: - fill_value = np.nan + fill_value_1: float | str = np.nan elif extrapolate is True: - fill_value = "extrapolate" + fill_value_1 = "extrapolate" interpolating_function = interpolate.interp1d( x1, y.T, bounds_error=False, - fill_value=fill_value, + fill_value=fill_value_1, ) elif interpolator == "cubic": interpolating_function = interpolate.CubicSpline( @@ -196,15 +199,14 @@ def __init__( @classmethod def _from_json(cls, snippet: dict): """Create an Interpolant object from JSON data""" - instance = cls.__new__(cls) + + x1 = [] if len(snippet["x"]) == 1: - x = [np.array(x) for x in snippet["x"]] - else: - x = tuple(np.array(x) for x in snippet["x"]) + x1 = [np.array(x) for x in snippet["x"]] - instance.__init__( - x, + return cls( + x1 if x1 else tuple(np.array(x) for x in snippet["x"]), np.array(snippet["y"]), snippet["children"], name=snippet["name"], @@ -212,8 +214,6 @@ def _from_json(cls, snippet: dict): extrapolate=snippet["extrapolate"], ) - return instance - @property def entries_string(self): return self._entries_string diff --git a/pybamm/expression_tree/matrix.py b/pybamm/expression_tree/matrix.py index 8b36bca53e..435d06d84f 100644 --- a/pybamm/expression_tree/matrix.py +++ b/pybamm/expression_tree/matrix.py @@ -1,10 +1,12 @@ # # Matrix class # +from __future__ import annotations import numpy as np from scipy.sparse import csr_matrix, issparse import pybamm +from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType class Matrix(pybamm.Array): @@ -14,13 +16,13 @@ class Matrix(pybamm.Array): def __init__( self, - entries, - name=None, - domain=None, - auxiliary_domains=None, - domains=None, - entries_string=None, - ): + entries: np.ndarray | list[float] | csr_matrix, + name: str | None = None, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, + entries_string: str | None = None, + ) -> None: if isinstance(entries, list): entries = np.array(entries) if name is None: diff --git a/pybamm/expression_tree/operations/convert_to_casadi.py b/pybamm/expression_tree/operations/convert_to_casadi.py index d29ae994f2..084b587721 100644 --- a/pybamm/expression_tree/operations/convert_to_casadi.py +++ b/pybamm/expression_tree/operations/convert_to_casadi.py @@ -1,6 +1,8 @@ # # Convert a PyBaMM expression tree to a CasADi expression tree # +from __future__ import annotations + import pybamm import casadi import numpy as np @@ -13,7 +15,14 @@ def __init__(self, casadi_symbols=None): pybamm.citations.register("Andersson2019") - def convert(self, symbol, t, y, y_dot, inputs): + def convert( + self, + symbol: pybamm.Symbol, + t: casadi.MX, + y: casadi.MX, + y_dot: casadi.MX, + inputs: dict | None, + ) -> casadi.MX: """ This function recurses down the tree, converting the PyBaMM expression tree to a CasADi expression tree diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index d7a43486f0..3582db13aa 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -1,8 +1,10 @@ # # Write a symbol to python # +from __future__ import annotations import numbers from collections import OrderedDict +from numpy.typing import ArrayLike import numpy as np import scipy.sparse @@ -38,7 +40,9 @@ class JaxCooMatrix: where x is the number of rows, and y the number of columns of the matrix """ - def __init__(self, row, col, data, shape): + def __init__( + self, row: ArrayLike, col: ArrayLike, data: ArrayLike, shape: tuple[int, int] + ): if not pybamm.have_jax(): # pragma: no cover raise ModuleNotFoundError( "Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver" @@ -68,7 +72,7 @@ def dot_product(self, b): result = jax.numpy.zeros((self.shape[0], 1), dtype=b.dtype) return result.at[self.row].add(self.data.reshape(-1, 1) * b[self.col]) - def scalar_multiply(self, b): + def scalar_multiply(self, b: float): """ multiply of matrix with a scalar b @@ -91,7 +95,7 @@ def __matmul__(self, b): return self.dot_product(b) -def create_jax_coo_matrix(value): +def create_jax_coo_matrix(value: scipy.sparse): """ Creates a JaxCooMatrix from a scipy.sparse matrix @@ -131,7 +135,12 @@ def is_scalar(arg): return np.all(np.array(arg.shape) == 1) -def find_symbols(symbol, constant_symbols, variable_symbols, output_jax=False): +def find_symbols( + symbol: pybamm.Symbol, + constant_symbols: OrderedDict, + variable_symbols: OrderedDict, + output_jax=False, +): """ This function converts an expression tree to a dictionary of node id's and strings specifying valid python code to calculate that nodes value, given y and t. @@ -361,7 +370,9 @@ def find_symbols(symbol, constant_symbols, variable_symbols, output_jax=False): variable_symbols[symbol.id] = symbol_str -def to_python(symbol, debug=False, output_jax=False): +def to_python( + symbol: pybamm.Symbol, debug=False, output_jax=False +) -> tuple[OrderedDict, str]: """ This function converts an expression tree into a dict of constant input values, and valid python code that acts like the tree's :func:`pybamm.Symbol.evaluate` function @@ -387,8 +398,8 @@ def to_python(symbol, debug=False, output_jax=False): operations are used """ - constant_values = OrderedDict() - variable_symbols = OrderedDict() + constant_values: OrderedDict = OrderedDict() + variable_symbols: OrderedDict = OrderedDict() find_symbols(symbol, constant_values, variable_symbols, output_jax) line_format = "{} = {}" @@ -427,7 +438,7 @@ class EvaluatorPython: """ - def __init__(self, symbol): + def __init__(self, symbol: pybamm.Symbol): constants, python_str = pybamm.to_python(symbol, debug=False) # extract constants in generated function @@ -519,7 +530,7 @@ class EvaluatorJax: """ - def __init__(self, symbol): + def __init__(self, symbol: pybamm.Symbol): if not pybamm.have_jax(): # pragma: no cover raise ModuleNotFoundError( "Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver" @@ -585,7 +596,8 @@ def __init__(self, symbol): self._static_argnums = tuple(static_argnums) self._jit_evaluate = jax.jit( - self._evaluate_jax, static_argnums=self._static_argnums + self._evaluate_jax, # type:ignore[attr-defined] + static_argnums=self._static_argnums, ) def get_jacobian(self): diff --git a/pybamm/expression_tree/operations/jacobian.py b/pybamm/expression_tree/operations/jacobian.py index a191e2c74d..6348e1fdc0 100644 --- a/pybamm/expression_tree/operations/jacobian.py +++ b/pybamm/expression_tree/operations/jacobian.py @@ -1,6 +1,7 @@ # # Calculate the Jacobian of a symbol # +from __future__ import annotations import pybamm @@ -18,11 +19,15 @@ class Jacobian: whether or not the Jacobian clears the domain (default True) """ - def __init__(self, known_jacs=None, clear_domain=True): + def __init__( + self, + known_jacs: dict[pybamm.Symbol, pybamm.Symbol] | None = None, + clear_domain: bool = True, + ): self._known_jacs = known_jacs or {} self._clear_domain = clear_domain - def jac(self, symbol, variable): + def jac(self, symbol: pybamm.Symbol, variable: pybamm.Symbol) -> pybamm.Symbol: """ This function recurses down the tree, computing the Jacobian using the Jacobians defined in classes derived from pybamm.Symbol. E.g. the @@ -52,7 +57,7 @@ def jac(self, symbol, variable): self._known_jacs[symbol] = jac return jac - def _jac(self, symbol, variable): + def _jac(self, symbol: pybamm.Symbol, variable: pybamm.Symbol): """See :meth:`Jacobian.jac()`.""" if isinstance(symbol, pybamm.BinaryOperator): @@ -64,12 +69,12 @@ def _jac(self, symbol, variable): jac = symbol._binary_jac(left_jac, right_jac) elif isinstance(symbol, pybamm.UnaryOperator): - child_jac = self.jac(symbol.child, variable) + child_jac = self.jac(symbol.child, variable) # type: ignore[has-type] # _unary_jac defined in derived classes for specific rules jac = symbol._unary_jac(child_jac) elif isinstance(symbol, pybamm.Function): - children_jacs = [None] * len(symbol.children) + children_jacs: list[None | pybamm.Symbol] = [None] * len(symbol.children) for i, child in enumerate(symbol.children): children_jacs[i] = self.jac(child, variable) # _function_jac defined in function class diff --git a/pybamm/expression_tree/operations/latexify.py b/pybamm/expression_tree/operations/latexify.py index 572f01a560..c16ab4b83d 100644 --- a/pybamm/expression_tree/operations/latexify.py +++ b/pybamm/expression_tree/operations/latexify.py @@ -1,6 +1,8 @@ # # Latexify class # +from __future__ import annotations + import copy import re import warnings @@ -48,7 +50,7 @@ class Latexify: >>> model.latexify(newline=False)[1:5] """ - def __init__(self, model, filename=None, newline=True): + def __init__(self, model, filename: str | None = None, newline: bool = True): self.model = model self.filename = filename self.newline = newline diff --git a/pybamm/expression_tree/operations/serialise.py b/pybamm/expression_tree/operations/serialise.py index 53505dbb1f..0507b3304e 100644 --- a/pybamm/expression_tree/operations/serialise.py +++ b/pybamm/expression_tree/operations/serialise.py @@ -7,8 +7,6 @@ import numpy as np import re -from typing import Optional - class Serialise: """ @@ -48,15 +46,17 @@ def default(self, node: dict): class _MeshEncoder(json.JSONEncoder): """Converts PyBaMM meshes into a JSON-serialisable format""" - def default(self, node: dict): + def default(self, node: pybamm.Mesh): node_dict = {"py/object": str(type(node))[8:-2], "py/id": id(node)} if isinstance(node, pybamm.Mesh): node_dict.update(node.to_json()) - node_dict["sub_meshes"] = {} + submeshes = {} for k, v in node.items(): if len(k) == 1 and "ghost cell" not in k[0]: - node_dict["sub_meshes"][k[0]] = self.default(v) + submeshes[k[0]] = self.default(v) + + node_dict["sub_meshes"] = submeshes return node_dict @@ -80,9 +80,9 @@ class _EmptyDict(dict): def save_model( self, model: pybamm.BaseModel, - mesh: Optional[pybamm.Mesh] = None, - variables: Optional[pybamm.FuzzyDict] = None, - filename: Optional[str] = None, + mesh: pybamm.Mesh | None = None, + variables: pybamm.FuzzyDict | None = None, + filename: str | None = None, ): """Saves a discretised model to a JSON file. @@ -114,7 +114,7 @@ def save_model( "pybamm_version": pybamm.__version__, "name": model.name, "options": model.options, - "bounds": [bound.tolist() for bound in model.bounds], + "bounds": [bound.tolist() for bound in model.bounds], # type: ignore[attr-defined] "concatenated_rhs": self._SymbolEncoder().default(model._concatenated_rhs), "concatenated_algebraic": self._SymbolEncoder().default( model._concatenated_algebraic @@ -144,7 +144,7 @@ def save_model( json.dump(model_json, f) def load_model( - self, filename: str, battery_model: Optional[pybamm.BaseModel] = None + self, filename: str, battery_model: pybamm.BaseModel | None = None ) -> pybamm.BaseModel: """ Loads a discretised, ready to solve model into PyBaMM. @@ -247,12 +247,15 @@ def _get_pybamm_class(self, snippet: dict): try: empty_class = self._Empty() empty_class.__class__ = class_ + + return empty_class + except TypeError: # Mesh objects have a different layouts - empty_class = self._EmptyDict() - empty_class.__class__ = class_ + empty_dict_class = self._EmptyDict() + empty_dict_class.__class__ = class_ - return empty_class + return empty_dict_class def _deconstruct_pybamm_dicts(self, dct: dict): """ diff --git a/pybamm/expression_tree/operations/unpack_symbols.py b/pybamm/expression_tree/operations/unpack_symbols.py index 825cb2db40..1933eada76 100644 --- a/pybamm/expression_tree/operations/unpack_symbols.py +++ b/pybamm/expression_tree/operations/unpack_symbols.py @@ -1,6 +1,11 @@ # # Helper function to unpack a symbol # +from __future__ import annotations +from typing import TYPE_CHECKING, Sequence + +if TYPE_CHECKING: # pragma: no cover + import pybamm class SymbolUnpacker: @@ -16,11 +21,17 @@ class SymbolUnpacker: cached unpacked equations """ - def __init__(self, classes_to_find, unpacked_symbols=None): + def __init__( + self, + classes_to_find: Sequence[pybamm.Symbol] | pybamm.Symbol, + unpacked_symbols: dict | None = None, + ): self.classes_to_find = classes_to_find - self._unpacked_symbols = unpacked_symbols or {} + self._unpacked_symbols: dict = unpacked_symbols or {} - def unpack_list_of_symbols(self, list_of_symbols): + def unpack_list_of_symbols( + self, list_of_symbols: Sequence[pybamm.Symbol] + ) -> set[pybamm.Symbol]: """ Unpack a list of symbols. See :meth:`SymbolUnpacker.unpack()` @@ -41,7 +52,9 @@ def unpack_list_of_symbols(self, list_of_symbols): return all_instances - def unpack_symbol(self, symbol): + def unpack_symbol( + self, symbol: Sequence[pybamm.Symbol] | pybamm.Symbol + ) -> list[pybamm.Symbol]: """ This function recurses down the tree, unpacking the symbols and saving the ones that have a class in `self.classes_to_find`. diff --git a/pybamm/expression_tree/parameter.py b/pybamm/expression_tree/parameter.py index 787b7b5007..e646ff234d 100644 --- a/pybamm/expression_tree/parameter.py +++ b/pybamm/expression_tree/parameter.py @@ -1,10 +1,14 @@ # # Parameter classes # -import numbers +from __future__ import annotations import sys import numpy as np +from typing import TYPE_CHECKING, Literal + +if TYPE_CHECKING: # pragma: no cover + import sympy import pybamm from pybamm.util import have_optional_dependency @@ -23,26 +27,26 @@ class Parameter(pybamm.Symbol): name of the node """ - def __init__(self, name): + def __init__(self, name: str) -> None: super().__init__(name) - def create_copy(self): + def create_copy(self) -> pybamm.Parameter: """See :meth:`pybamm.Symbol.new_copy()`.""" return Parameter(self.name) - def _evaluate_for_shape(self): + def _evaluate_for_shape(self) -> float: """ Returns the scalar 'NaN' to represent the shape of a parameter. See :meth:`pybamm.Symbol.evaluate_for_shape()` """ return np.nan - def is_constant(self): + def is_constant(self) -> Literal[False]: """See :meth:`pybamm.Symbol.is_constant()`.""" # Parameter is not constant since it can become an InputParameter return False - def to_equation(self): + def to_equation(self) -> sympy.Symbol: """Convert the node and its subtree into a SymPy equation.""" sympy = have_optional_dependency("sympy") if self.print_name is not None: @@ -91,18 +95,18 @@ class FunctionParameter(pybamm.Symbol): def __init__( self, - name, - inputs, - diff_variable=None, + name: str, + inputs: dict[str, pybamm.Symbol], + diff_variable: pybamm.Symbol | None = None, print_name="calculate", - ): + ) -> None: # assign diff variable self.diff_variable = diff_variable children_list = list(inputs.values()) # Turn numbers into scalars for idx, child in enumerate(children_list): - if isinstance(child, numbers.Number): + if isinstance(child, (float, int, np.number)): children_list[idx] = pybamm.Scalar(child) domains = self.get_children_domains(children_list) @@ -116,30 +120,34 @@ def __init__( self.print_name = print_name else: frame = sys._getframe().f_back - print_name = frame.f_code.co_name - if print_name.startswith("_"): - self.print_name = None - else: - try: - parent_param = frame.f_locals["self"] - except KeyError: - parent_param = None - if hasattr(parent_param, "domain") and parent_param.domain is not None: - # add "_n" or "_s" or "_p" if this comes from a Parameter class with - # a domain - d = parent_param.domain[0] - print_name += f"_{d}" - self.print_name = print_name - - @property - def input_names(self): - return self._input_names + if frame is not None: + print_name = frame.f_code.co_name + if print_name.startswith("_"): + self.print_name = None + else: + try: + parent_param = frame.f_locals["self"] + except KeyError: + parent_param = None + if ( + hasattr(parent_param, "domain") + and parent_param.domain is not None + ): + # add "_n" or "_s" or "_p" if this comes from a Parameter class with + # a domain + d = parent_param.domain[0] + print_name += f"_{d}" + self.print_name = print_name def print_input_names(self): if self._input_names: for inp in self._input_names: print(inp) + @property + def input_names(self): + return self._input_names + @input_names.setter def input_names(self, inp=None): if inp: @@ -172,7 +180,7 @@ def set_id(self): ) ) - def diff(self, variable): + def diff(self, variable: pybamm.Symbol) -> pybamm.FunctionParameter: """See :meth:`pybamm.Symbol.diff()`.""" # return a new FunctionParameter, that knows it will need to be differentiated # when the parameters are set @@ -196,8 +204,11 @@ def create_copy(self): return out def _function_parameter_new_copy( - self, input_names, children, print_name="calculate" - ): + self, + input_names: list[str], + children: list[pybamm.Symbol], + print_name="calculate", + ) -> pybamm.FunctionParameter: """ Returns a new copy of the function parameter. @@ -231,7 +242,7 @@ def _evaluate_for_shape(self): # add 1e-16 to avoid division by zero return sum(child.evaluate_for_shape() for child in self.children) + 1e-16 - def to_equation(self): + def to_equation(self) -> sympy.Symbol: """Convert the node and its subtree into a SymPy equation.""" sympy = have_optional_dependency("sympy") if self.print_name is not None: diff --git a/pybamm/expression_tree/scalar.py b/pybamm/expression_tree/scalar.py index 73dccf7d6c..26bbabfcf0 100644 --- a/pybamm/expression_tree/scalar.py +++ b/pybamm/expression_tree/scalar.py @@ -1,10 +1,13 @@ # # Scalar class # +from __future__ import annotations import numpy as np +from typing import Literal import pybamm from pybamm.util import have_optional_dependency +from pybamm.type_definitions import Numeric class Scalar(pybamm.Symbol): @@ -21,7 +24,11 @@ class Scalar(pybamm.Symbol): """ - def __init__(self, value, name=None): + def __init__( + self, + value: Numeric, + name: str | None = None, + ) -> None: # set default name if not provided self.value = value if name is None: @@ -31,11 +38,7 @@ def __init__(self, value, name=None): @classmethod def _from_json(cls, snippet: dict): - instance = cls.__new__(cls) - - instance.__init__(snippet["value"], name=snippet["name"]) - - return instance + return cls(snippet["value"], name=snippet["name"]) def __str__(self): return str(self.value) @@ -60,11 +63,17 @@ def set_id(self): # indistinguishable by class and name alone self._id = hash((self.__class__, str(self.value))) - def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def _base_evaluate( + self, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, + ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" return self._value - def _jac(self, variable): + def _jac(self, variable: pybamm.Variable) -> pybamm.Scalar: """See :meth:`pybamm.Symbol._jac()`.""" return pybamm.Scalar(0) @@ -72,7 +81,7 @@ def create_copy(self): """See :meth:`pybamm.Symbol.new_copy()`.""" return Scalar(self.value, self.name) - def is_constant(self): + def is_constant(self) -> Literal[True]: """See :meth:`pybamm.Symbol.is_constant()`.""" return True diff --git a/pybamm/expression_tree/state_vector.py b/pybamm/expression_tree/state_vector.py index 348f908b45..f52b41a99c 100644 --- a/pybamm/expression_tree/state_vector.py +++ b/pybamm/expression_tree/state_vector.py @@ -1,10 +1,12 @@ # # State Vector class # +from __future__ import annotations import numpy as np from scipy.sparse import csr_matrix, vstack import pybamm +from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType class StateVectorBase(pybamm.Symbol): @@ -34,13 +36,13 @@ class StateVectorBase(pybamm.Symbol): def __init__( self, - *y_slices, + *y_slices: slice, base_name="y", - name=None, - domain=None, - auxiliary_domains=None, - domains=None, - evaluation_array=None, + name: str | None = None, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, + evaluation_array: list[bool] | None = None, ): for y_slice in y_slices: if not isinstance(y_slice, slice): @@ -71,19 +73,15 @@ def __init__( @classmethod def _from_json(cls, snippet: dict): - instance = cls.__new__(cls) - y_slices = [slice(s["start"], s["stop"], s["step"]) for s in snippet["y_slice"]] - instance.__init__( + return cls( *y_slices, name=snippet["name"], domains=snippet["domains"], evaluation_array=snippet["evaluation_array"], ) - return instance - @property def y_slices(self): return self._y_slices @@ -126,7 +124,7 @@ def set_id(self): ) ) - def _jac_diff_vector(self, variable): + def _jac_diff_vector(self, variable: pybamm.StateVectorBase): """ Differentiate a slice of a StateVector of size m with respect to another slice of a different StateVector of size n. This returns a (sparse) zero matrix of @@ -147,7 +145,7 @@ def _jac_diff_vector(self, variable): # Return zeros of correct size since no entries match return pybamm.Matrix(csr_matrix((slices_size, variable_size))) - def _jac_same_vector(self, variable): + def _jac_same_vector(self, variable: pybamm.StateVectorBase): """ Differentiate a slice of a StateVector of size m with respect to another slice of a StateVector of size n. This returns a (sparse) matrix of size @@ -259,12 +257,12 @@ class StateVector(StateVectorBase): def __init__( self, - *y_slices, - name=None, - domain=None, - auxiliary_domains=None, - domains=None, - evaluation_array=None, + *y_slices: slice, + name: str | None = None, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, + evaluation_array: list[bool] | None = None, ): super().__init__( *y_slices, @@ -276,7 +274,13 @@ def __init__( evaluation_array=evaluation_array, ) - def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def _base_evaluate( + self, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, + ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" if y is None: raise TypeError("StateVector cannot evaluate input 'y=None'") @@ -290,7 +294,7 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): out = out[:, np.newaxis] return out - def diff(self, variable): + def diff(self, variable: pybamm.Symbol): if variable == self: return pybamm.Scalar(1) if variable == pybamm.t: @@ -303,7 +307,7 @@ def diff(self, variable): else: return pybamm.Scalar(0) - def _jac(self, variable): + def _jac(self, variable: pybamm.StateVector | pybamm.StateVectorDot): if isinstance(variable, pybamm.StateVector): return self._jac_same_vector(variable) elif isinstance(variable, pybamm.StateVectorDot): @@ -337,12 +341,12 @@ class StateVectorDot(StateVectorBase): def __init__( self, - *y_slices, - name=None, - domain=None, - auxiliary_domains=None, - domains=None, - evaluation_array=None, + *y_slices: slice, + name: str | None = None, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, + evaluation_array: list[bool] | None = None, ): super().__init__( *y_slices, @@ -354,7 +358,13 @@ def __init__( evaluation_array=evaluation_array, ) - def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def _base_evaluate( + self, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, + ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" if y_dot is None: raise TypeError("StateVectorDot cannot evaluate input 'y_dot=None'") @@ -368,7 +378,7 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): out = out[:, np.newaxis] return out - def diff(self, variable): + def diff(self, variable: pybamm.Symbol): if variable == self: return pybamm.Scalar(1) elif variable == pybamm.t: @@ -378,7 +388,7 @@ def diff(self, variable): else: return pybamm.Scalar(0) - def _jac(self, variable): + def _jac(self, variable: pybamm.StateVector | pybamm.StateVectorDot): if isinstance(variable, pybamm.StateVectorDot): return self._jac_same_vector(variable) elif isinstance(variable, pybamm.StateVector): diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 9c719159c7..70f4e82db6 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -1,21 +1,33 @@ # # Base Symbol Class for the expression tree # +from __future__ import annotations import numbers import numpy as np from scipy.sparse import csr_matrix, issparse from functools import lru_cache, cached_property +from typing import TYPE_CHECKING, Sequence, cast import pybamm from pybamm.util import have_optional_dependency from pybamm.expression_tree.printing.print_name import prettify_print_name +if TYPE_CHECKING: # pragma: no cover + import casadi + from pybamm.type_definitions import ( + ChildSymbol, + ChildValue, + DomainType, + AuxiliaryDomainType, + DomainsType, + ) + DOMAIN_LEVELS = ["primary", "secondary", "tertiary", "quaternary"] -EMPTY_DOMAINS = {k: [] for k in DOMAIN_LEVELS} +EMPTY_DOMAINS: dict[str, list] = {k: [] for k in DOMAIN_LEVELS} -def domain_size(domain): +def domain_size(domain: list[str] | str): """ Get the domain size. @@ -43,7 +55,7 @@ def domain_size(domain): return size -def create_object_of_size(size, typ="vector"): +def create_object_of_size(size: int, typ="vector"): """Return object, consisting of NaNs, of the right shape.""" if typ == "vector": return np.nan * np.ones((size, 1)) @@ -51,7 +63,7 @@ def create_object_of_size(size, typ="vector"): return np.nan * np.ones((size, size)) -def evaluate_for_shape_using_domain(domains, typ="vector"): +def evaluate_for_shape_using_domain(domains: dict[str, list[str] | str], typ="vector"): """ Return a vector of the appropriate shape, based on the domains. Domain 'sizes' can clash, but are unlikely to, and won't cause failures if they do. @@ -63,11 +75,11 @@ def evaluate_for_shape_using_domain(domains, typ="vector"): return create_object_of_size(_domain_sizes, typ) -def is_constant(symbol): +def is_constant(symbol: Symbol): return isinstance(symbol, numbers.Number) or symbol.is_constant() -def is_scalar_x(expr, x): +def is_scalar_x(expr: Symbol, x: int): """ Utility function to test if an expression evaluates to a constant scalar value """ @@ -78,28 +90,28 @@ def is_scalar_x(expr, x): return False -def is_scalar_zero(expr): +def is_scalar_zero(expr: Symbol): """ Utility function to test if an expression evaluates to a constant scalar zero """ return is_scalar_x(expr, 0) -def is_scalar_one(expr): +def is_scalar_one(expr: Symbol): """ Utility function to test if an expression evaluates to a constant scalar one """ return is_scalar_x(expr, 1) -def is_scalar_minus_one(expr): +def is_scalar_minus_one(expr: Symbol): """ Utility function to test if an expression evaluates to a constant scalar minus one """ return is_scalar_x(expr, -1) -def is_matrix_x(expr, x): +def is_matrix_x(expr: Symbol, x: int): """ Utility function to test if an expression evaluates to a constant matrix value """ @@ -122,28 +134,28 @@ def is_matrix_x(expr, x): return False -def is_matrix_zero(expr): +def is_matrix_zero(expr: Symbol): """ Utility function to test if an expression evaluates to a constant matrix zero """ return is_matrix_x(expr, 0) -def is_matrix_one(expr): +def is_matrix_one(expr: Symbol): """ Utility function to test if an expression evaluates to a constant matrix one """ return is_matrix_x(expr, 1) -def is_matrix_minus_one(expr): +def is_matrix_minus_one(expr: Symbol): """ Utility function to test if an expression evaluates to a constant matrix minus one """ return is_matrix_x(expr, -1) -def simplify_if_constant(symbol): +def simplify_if_constant(symbol: pybamm.Symbol): """ Utility function to simplify an expression tree if it evalutes to a constant scalar, vector or matrix @@ -156,7 +168,9 @@ def simplify_if_constant(symbol): or (isinstance(result, np.ndarray) and result.ndim == 0) or isinstance(result, np.bool_) ): - return pybamm.Scalar(result) + # type-narrow for Scalar + new_result = cast(float, result) + return pybamm.Scalar(new_result) elif isinstance(result, np.ndarray) or issparse(result): if result.ndim == 1 or result.shape[1] == 1: return pybamm.Vector(result, domains=symbol.domains) @@ -200,11 +214,11 @@ class Symbol: def __init__( self, - name, - children=None, - domain=None, - auxiliary_domains=None, - domains=None, + name: str, + children: Sequence[Symbol] | None = None, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, ): super().__init__() self.name = name @@ -213,13 +227,13 @@ def __init__( children = [] self._children = children - # Keep a separate "oprhans" attribute for backwards compatibility + # Keep a separate "orphans" attribute for backwards compatibility self._orphans = children # Set domains (and hence id) self.domains = self.read_domain_or_domains(domain, auxiliary_domains, domains) - self._saved_evaluates_on_edges = {} + self._saved_evaluates_on_edges: dict = {} self._print_name = None # Test shape on everything but nodes that contain the base Symbol class or @@ -244,14 +258,10 @@ def _from_json(cls, snippet: dict): At minimum, should contain "name", "children" and "domains". """ - instance = cls.__new__(cls) - - instance.__init__( + return cls( snippet["name"], children=snippet["children"], domains=snippet["domains"] ) - return instance - @property def children(self): """ @@ -268,7 +278,7 @@ def name(self): return self._name @name.setter - def name(self, value): + def name(self, value: str): assert isinstance(value, str) self._name = value @@ -276,30 +286,6 @@ def name(self, value): def domains(self): return self._domains - @property - def domain(self): - """ - list of applicable domains. - - Returns - ------- - iterable of str - """ - return self._domains["primary"] - - @domain.setter - def domain(self, domain): - raise NotImplementedError( - "Cannot set domain directly, use domains={'primary': domain} instead" - ) - - @property - def auxiliary_domains(self): - """Returns auxiliary domains.""" - raise NotImplementedError( - "symbol.auxiliary_domains has been deprecated, use symbol.domains instead" - ) - @domains.setter def domains(self, domains): try: @@ -345,6 +331,30 @@ def domains(self, domains): self._domains = domains self.set_id() + @property + def domain(self): + """ + list of applicable domains. + + Returns + ------- + iterable of str + """ + return self._domains["primary"] + + @domain.setter + def domain(self, domain): + raise NotImplementedError( + "Cannot set domain directly, use domains={'primary': domain} instead" + ) + + @property + def auxiliary_domains(self): + """Returns auxiliary domains.""" + raise NotImplementedError( + "symbol.auxiliary_domains has been deprecated, use symbol.domains instead" + ) + @property def secondary_domain(self): """Helper function to get the secondary domain of a symbol.""" @@ -360,7 +370,7 @@ def quaternary_domain(self): """Helper function to get the quaternary domain of a symbol.""" return self._domains["quaternary"] - def copy_domains(self, symbol): + def copy_domains(self, symbol: Symbol): """Copy the domains from a given symbol, bypassing checks.""" if self._domains != symbol._domains: self._domains = symbol._domains @@ -372,9 +382,9 @@ def clear_domains(self): self._domains = EMPTY_DOMAINS self.set_id() - def get_children_domains(self, children): + def get_children_domains(self, children: Sequence[Symbol]): """Combine domains from children, at all levels.""" - domains = {} + domains: dict = {} for child in children: for level in child.domains.keys(): if child.domains[level] == []: @@ -393,7 +403,12 @@ def get_children_domains(self, children): return domains - def read_domain_or_domains(self, domain, auxiliary_domains, domains): + def read_domain_or_domains( + self, + domain: DomainType, + auxiliary_domains: AuxiliaryDomainType, + domains: DomainsType, + ): if domains is None: if isinstance(domain, str): domain = [domain] @@ -470,7 +485,7 @@ def render(self): # pragma: no cover else: print(f"{pre}{node.name}") - def visualise(self, filename): + def visualise(self, filename: str): """ Produces a .png file of the tree (this node and its children) with the name filename @@ -497,7 +512,7 @@ def visualise(self, filename): # raise error but only through logger so that test passes pybamm.logger.error("Please install graphviz>=2.42.2 to use dot exporter") - def relabel_tree(self, symbol, counter): + def relabel_tree(self, symbol: Symbol, counter: int): """ Finds all children of a symbol and assigns them a new id so that they can be visualised properly using the graphviz output @@ -562,71 +577,71 @@ def __repr__(self): {k: v for k, v in self.domains.items() if v != []}, ) - def __add__(self, other): + def __add__(self, other: ChildSymbol) -> pybamm.Addition: """return an :class:`Addition` object.""" return pybamm.add(self, other) - def __radd__(self, other): + def __radd__(self, other: ChildSymbol) -> pybamm.Addition: """return an :class:`Addition` object.""" return pybamm.add(other, self) - def __sub__(self, other): + def __sub__(self, other: ChildSymbol) -> pybamm.Subtraction: """return a :class:`Subtraction` object.""" return pybamm.subtract(self, other) - def __rsub__(self, other): + def __rsub__(self, other: ChildSymbol) -> pybamm.Subtraction: """return a :class:`Subtraction` object.""" return pybamm.subtract(other, self) - def __mul__(self, other): + def __mul__(self, other: ChildSymbol) -> pybamm.Multiplication: """return a :class:`Multiplication` object.""" return pybamm.multiply(self, other) - def __rmul__(self, other): + def __rmul__(self, other: ChildSymbol) -> pybamm.Multiplication: """return a :class:`Multiplication` object.""" return pybamm.multiply(other, self) - def __matmul__(self, other): + def __matmul__(self, other: ChildSymbol) -> pybamm.MatrixMultiplication: """return a :class:`MatrixMultiplication` object.""" return pybamm.matmul(self, other) - def __rmatmul__(self, other): + def __rmatmul__(self, other: ChildSymbol) -> pybamm.MatrixMultiplication: """return a :class:`MatrixMultiplication` object.""" return pybamm.matmul(other, self) - def __truediv__(self, other): + def __truediv__(self, other: ChildSymbol) -> pybamm.Division: """return a :class:`Division` object.""" return pybamm.divide(self, other) - def __rtruediv__(self, other): + def __rtruediv__(self, other: ChildSymbol) -> pybamm.Division: """return a :class:`Division` object.""" return pybamm.divide(other, self) - def __pow__(self, other): + def __pow__(self, other: ChildSymbol) -> pybamm.Power: """return a :class:`Power` object.""" return pybamm.simplified_power(self, other) - def __rpow__(self, other): + def __rpow__(self, other: Symbol) -> pybamm.Power: """return a :class:`Power` object.""" return pybamm.simplified_power(other, self) - def __lt__(self, other): + def __lt__(self, other: Symbol | float) -> pybamm.NotEqualHeaviside: """return a :class:`NotEqualHeaviside` object, or a smooth approximation.""" return pybamm.expression_tree.binary_operators._heaviside(self, other, False) - def __le__(self, other): + def __le__(self, other: Symbol) -> pybamm.EqualHeaviside: """return a :class:`EqualHeaviside` object, or a smooth approximation.""" return pybamm.expression_tree.binary_operators._heaviside(self, other, True) - def __gt__(self, other): + def __gt__(self, other: Symbol) -> pybamm.NotEqualHeaviside: """return a :class:`NotEqualHeaviside` object, or a smooth approximation.""" return pybamm.expression_tree.binary_operators._heaviside(other, self, False) - def __ge__(self, other): + def __ge__(self, other: Symbol) -> pybamm.EqualHeaviside: """return a :class:`EqualHeaviside` object, or a smooth approximation.""" return pybamm.expression_tree.binary_operators._heaviside(other, self, True) - def __neg__(self): + def __neg__(self) -> pybamm.Negate: """return a :class:`Negate` object.""" if isinstance(self, pybamm.Negate): # Double negative is a positive @@ -645,7 +660,7 @@ def __neg__(self): else: return pybamm.simplify_if_constant(pybamm.Negate(self)) - def __abs__(self): + def __abs__(self) -> pybamm.AbsoluteValue: """return an :class:`AbsoluteValue` object, or a smooth approximation.""" if isinstance(self, pybamm.AbsoluteValue): # No need to apply abs a second time @@ -665,7 +680,7 @@ def __abs__(self): out = pybamm.smooth_absolute_value(self, k) return pybamm.simplify_if_constant(out) - def __mod__(self, other): + def __mod__(self, other: Symbol) -> pybamm.Modulo: """return an :class:`Modulo` object.""" return pybamm.simplify_if_constant(pybamm.Modulo(self, other)) @@ -688,7 +703,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): """ return getattr(pybamm, ufunc.__name__)(*inputs, **kwargs) - def diff(self, variable): + def diff(self, variable: Symbol): """ Differentiate a symbol with respect to a variable. For any symbol that can be differentiated, return `1` if differentiating with respect to yourself, @@ -717,7 +732,12 @@ def _diff(self, variable): """ raise NotImplementedError - def jac(self, variable, known_jacs=None, clear_domain=True): + def jac( + self, + variable: pybamm.Symbol, + known_jacs: dict[pybamm.Symbol, pybamm.Symbol] | None = None, + clear_domain=True, + ): """ Differentiate a symbol with respect to a (slice of) a StateVector or StateVectorDot. @@ -738,7 +758,13 @@ def _jac(self, variable): """ raise NotImplementedError - def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def _base_evaluate( + self, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, + ): """ evaluate expression tree. @@ -764,7 +790,13 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): f"{self!s} of type {type(self)}" ) - def evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def evaluate( + self, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, + ) -> ChildValue: """Evaluate expression tree (wrapper to allow using dict of known values). Parameters @@ -816,7 +848,7 @@ def is_constant(self): # Default behaviour is False return False - def evaluate_ignoring_errors(self, t=0): + def evaluate_ignoring_errors(self, t: float | None = 0): """ Evaluates the expression. If a node exists in the tree that cannot be evaluated as a scalar or vector (e.g. Time, Parameter, Variable, StateVector), then None @@ -869,7 +901,7 @@ def evaluates_to_constant_number(self): return self.evaluates_to_number() and self.is_constant() @lru_cache - def evaluates_on_edges(self, dimension): + def evaluates_on_edges(self, dimension: str) -> bool: """ Returns True if a symbol evaluates on an edge, i.e. symbol contains a gradient operator, but not a divergence operator, and is not an IndefiniteIntegral. @@ -895,7 +927,9 @@ def _evaluates_on_edges(self, dimension): # Default behaviour: return False return False - def has_symbol_of_classes(self, symbol_classes): + def has_symbol_of_classes( + self, symbol_classes: tuple[type[Symbol], ...] | type[Symbol] + ): """ Returns True if equation has a term of the class(es) `symbol_class`. @@ -906,7 +940,14 @@ def has_symbol_of_classes(self, symbol_classes): """ return any(isinstance(symbol, symbol_classes) for symbol in self.pre_order()) - def to_casadi(self, t=None, y=None, y_dot=None, inputs=None, casadi_symbols=None): + def to_casadi( + self, + t: casadi.MX | None = None, + y: casadi.MX | None = None, + y_dot: casadi.MX | None = None, + inputs: dict | None = None, + casadi_symbols: Symbol | None = None, + ): """ Convert the expression tree to a CasADi expression tree. See :class:`pybamm.CasadiConverter`. diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 950ac16318..319499a9fc 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -1,12 +1,13 @@ # # Unary operator classes and methods # -import numbers +from __future__ import annotations import numpy as np from scipy.sparse import csr_matrix, issparse import pybamm from pybamm.util import have_optional_dependency +from pybamm.type_definitions import DomainsType class UnaryOperator(pybamm.Symbol): @@ -24,8 +25,13 @@ class UnaryOperator(pybamm.Symbol): child node """ - def __init__(self, name, child, domains=None): - if isinstance(child, numbers.Number): + def __init__( + self, + name: str, + child: pybamm.Symbol, + domains: DomainsType = None, + ): + if isinstance(child, (float, int, np.number)): child = pybamm.Scalar(child) domains = domains or child.domains @@ -70,7 +76,13 @@ def _unary_evaluate(self, child): f"{self.__class__} does not implement _unary_evaluate." ) - def evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def evaluate( + self, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, + ): """See :meth:`pybamm.Symbol.evaluate()`.""" child = self.child.evaluate(t, y, y_dot, inputs) return self._unary_evaluate(child) @@ -82,7 +94,7 @@ def _evaluate_for_shape(self): """ return self.children[0].evaluate_for_shape() - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return self.child.evaluates_on_edges(dimension) @@ -117,7 +129,7 @@ def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return f"{self.name}{self.child!s}" - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" return -self.child.diff(variable) @@ -293,21 +305,18 @@ def __init__(self, child, index, name=None, check_size=True): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.UnaryOperator._from_json()`.""" - instance = cls.__new__(cls) - index = slice( snippet["index"]["start"], snippet["index"]["stop"], snippet["index"]["step"], ) - instance.__init__( + return cls( snippet["children"][0], index, name=snippet["name"], check_size=snippet["check_size"], ) - return instance def _unary_jac(self, child_jac): """See :meth:`pybamm.UnaryOperator._unary_jac()`.""" @@ -349,7 +358,7 @@ def _unary_new_copy(self, child): def _evaluate_for_shape(self): return self._unary_evaluate(self.children[0].evaluate_for_shape()) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False @@ -391,7 +400,12 @@ class with a :class:`Matrix` child node """ - def __init__(self, name, child, domains=None): + def __init__( + self, + name: str, + child: pybamm.Symbol, + domains: dict[str, list[str] | str] | None = None, + ): super().__init__(name, child, domains) def to_json(self): @@ -426,7 +440,7 @@ def __init__(self, child): ) super().__init__("grad", child) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return True @@ -460,7 +474,7 @@ def __init__(self, child): ) super().__init__("div", child) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False @@ -485,7 +499,7 @@ class Laplacian(SpatialOperator): def __init__(self, child): super().__init__("laplacian", child) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False @@ -501,7 +515,7 @@ class GradientSquared(SpatialOperator): def __init__(self, child): super().__init__("grad squared", child) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False @@ -551,7 +565,13 @@ class Integral(SpatialOperator): The variable over which to integrate """ - def __init__(self, child, integration_variable): + def __init__( + self, + child, + integration_variable: ( + list[pybamm.IndependentVariable] | pybamm.IndependentVariable + ), + ): if not isinstance(integration_variable, list): integration_variable = [integration_variable] @@ -646,7 +666,7 @@ def _evaluate_for_shape(self): """See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`""" return pybamm.evaluate_for_shape_using_domain(self.domains) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False @@ -844,7 +864,7 @@ def _evaluate_for_shape(self): """See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`""" return pybamm.evaluate_for_shape_using_domain(self.domains) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False @@ -882,7 +902,7 @@ def set_id(self): ) ) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False @@ -1003,11 +1023,7 @@ def __init__(self, children, initial_condition): @classmethod def _from_json(cls, snippet: dict): - instance = cls.__new__(cls) - - instance.__init__(snippet["children"][0], snippet["initial_condition"]) - - return instance + return cls(snippet["children"][0], snippet["initial_condition"]) def _unary_new_copy(self, child): return self.__class__(child, self.initial_condition) @@ -1119,7 +1135,7 @@ def __init__(self, name, child): ) super().__init__(name, child) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return True @@ -1152,7 +1168,7 @@ def _unary_new_copy(self, child): """See :meth:`pybamm.Symbol.new_copy()`.""" return NotConstant(child) - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" return self.child.diff(variable) diff --git a/pybamm/expression_tree/variable.py b/pybamm/expression_tree/variable.py index 35193782e3..eb0d90cdb6 100644 --- a/pybamm/expression_tree/variable.py +++ b/pybamm/expression_tree/variable.py @@ -1,11 +1,17 @@ # # Variable class # - +from __future__ import annotations import numpy as np import numbers import pybamm from pybamm.util import have_optional_dependency +from pybamm.type_definitions import ( + DomainType, + AuxiliaryDomainType, + DomainsType, + Numeric, +) class VariableBase(pybamm.Symbol): @@ -49,14 +55,14 @@ class VariableBase(pybamm.Symbol): def __init__( self, - name, - domain=None, - auxiliary_domains=None, - domains=None, - bounds=None, - print_name=None, - scale=1, - reference=0, + name: str, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, + bounds: tuple[pybamm.Symbol] | None = None, + print_name: str | None = None, + scale: float | pybamm.Symbol | None = 1, + reference: float | pybamm.Symbol | None = 0, ): if isinstance(scale, numbers.Number): scale = pybamm.Scalar(scale) @@ -82,7 +88,7 @@ def bounds(self): return self._bounds @bounds.setter - def bounds(self, values): + def bounds(self, values: tuple[Numeric, Numeric]): if values is None: values = (-np.inf, np.inf) else: @@ -183,7 +189,7 @@ class Variable(VariableBase): Default is 0. """ - def diff(self, variable): + def diff(self, variable: pybamm.Symbol): if variable == self: return pybamm.Scalar(1) elif variable == pybamm.t: @@ -237,7 +243,7 @@ class VariableDot(VariableBase): Default is 0. """ - def get_variable(self): + def get_variable(self) -> pybamm.Variable: """ return a :class:`.Variable` corresponding to this VariableDot @@ -246,7 +252,7 @@ def get_variable(self): """ return Variable(self.name[:-1], domains=self.domains, scale=self.scale) - def diff(self, variable): + def diff(self, variable: pybamm.Symbol) -> pybamm.Scalar: if variable == self: return pybamm.Scalar(1) elif variable == pybamm.t: diff --git a/pybamm/expression_tree/vector.py b/pybamm/expression_tree/vector.py index 641c098f79..a1d8052c94 100644 --- a/pybamm/expression_tree/vector.py +++ b/pybamm/expression_tree/vector.py @@ -1,9 +1,11 @@ # # Vector class # +from __future__ import annotations import numpy as np import pybamm +from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType class Vector(pybamm.Array): @@ -13,13 +15,13 @@ class Vector(pybamm.Array): def __init__( self, - entries, - name=None, - domain=None, - auxiliary_domains=None, - domains=None, - entries_string=None, - ): + entries: np.ndarray | list[float] | np.matrix, + name: str | None = None, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, + entries_string: str | None = None, + ) -> None: if isinstance(entries, (list, np.matrix)): entries = np.array(entries) # make sure that entries are a vector (can be a column vector) diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index 6b45aeb083..b6b5a9b2da 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -1,6 +1,8 @@ # # Base model class # +from __future__ import annotations + import numbers import warnings from collections import OrderedDict @@ -130,10 +132,9 @@ def deserialise(cls, properties: dict): """ Create a model instance from a serialised object. """ - instance = cls.__new__(cls) # append the model name with _saved to differentiate - instance.__init__(name=properties["name"] + "_saved") + instance = cls(name=properties["name"] + "_saved") instance.options = properties["options"] @@ -1261,7 +1262,7 @@ def save_model(self, filename=None, mesh=None, variables=None): Serialise().save_model(self, filename=filename, mesh=mesh, variables=variables) -def load_model(filename, battery_model: BaseModel = None): +def load_model(filename, battery_model: BaseModel | None = None): """ Load in a saved model from a JSON file diff --git a/pybamm/models/event.py b/pybamm/models/event.py index 5bba4cd14b..8d2695160e 100644 --- a/pybamm/models/event.py +++ b/pybamm/models/event.py @@ -1,4 +1,9 @@ +from __future__ import annotations + from enum import Enum +import numpy as np + +from typing import TypeVar class EventType(Enum): @@ -24,6 +29,9 @@ class EventType(Enum): SWITCH = 3 +E = TypeVar("E", bound="Event") + + class Event: """ @@ -47,7 +55,7 @@ def __init__(self, name, expression, event_type=EventType.TERMINATION): self._event_type = event_type @classmethod - def _from_json(cls, snippet: dict): + def _from_json(cls: type[E], snippet: dict) -> E: """ Reconstructs an Event instance during deserialisation of a JSON file. @@ -58,17 +66,19 @@ def _from_json(cls, snippet: dict): Should contain "name", "expression" and "event_type". """ - instance = cls.__new__(cls) - - instance.__init__( + return cls( snippet["name"], snippet["expression"], event_type=EventType(snippet["event_type"][1]), ) - return instance - - def evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def evaluate( + self, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | None = None, + ): """ Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate` """ diff --git a/pybamm/models/full_battery_models/base_battery_model.py b/pybamm/models/full_battery_models/base_battery_model.py index cbc270653b..1ac618ddde 100644 --- a/pybamm/models/full_battery_models/base_battery_model.py +++ b/pybamm/models/full_battery_models/base_battery_model.py @@ -809,10 +809,9 @@ def deserialise(cls, properties: dict): """ Create a model instance from a serialised object. """ - instance = cls.__new__(cls) # append the model name with _saved to differentiate - instance.__init__( + instance = cls( options=properties["options"], name=properties["name"] + "_saved" ) diff --git a/pybamm/parameters/bpx.py b/pybamm/parameters/bpx.py index 65322a5b99..c6037b0f40 100644 --- a/pybamm/parameters/bpx.py +++ b/pybamm/parameters/bpx.py @@ -118,7 +118,7 @@ def _bpx_to_param_dict(bpx: BPX) -> dict: } # Loop over each component of BPX and add to pybamm dictionary - pybamm_dict = {} + pybamm_dict: dict = {} pybamm_dict = _bpx_to_domain_param_dict( bpx.parameterisation.positive_electrode, pybamm_dict, positive_electrode ) diff --git a/pybamm/parameters/parameter_sets.py b/pybamm/parameters/parameter_sets.py index 20c20de091..7a517af418 100644 --- a/pybamm/parameters/parameter_sets.py +++ b/pybamm/parameters/parameter_sets.py @@ -3,6 +3,7 @@ import importlib.metadata import textwrap from collections.abc import Mapping +from typing import Callable class ParameterSets(Mapping): @@ -56,7 +57,7 @@ def __new__(cls): def __getitem__(self, key) -> dict: return self.__load_entry_point__(key)() - def __load_entry_point__(self, key) -> callable: + def __load_entry_point__(self, key) -> Callable: """Check that ``key`` is a registered ``pybamm_parameter_sets``, and return the entry point for the parameter set, loading it needed. """ diff --git a/pybamm/simulation.py b/pybamm/simulation.py index a4c0d4bb32..5d31340a05 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -1,6 +1,8 @@ # # Simulation class # +from __future__ import annotations + import pickle import pybamm import numpy as np @@ -10,7 +12,6 @@ from functools import lru_cache from datetime import timedelta from pybamm.util import have_optional_dependency -from typing import Optional from pybamm.expression_tree.operations.serialise import Serialise @@ -1167,7 +1168,7 @@ def save(self, filename): def save_model( self, - filename: Optional[str] = None, + filename: str | None = None, mesh: bool = False, variables: bool = False, ): diff --git a/pybamm/solvers/idaklu_solver.py b/pybamm/solvers/idaklu_solver.py index 3321ace947..e9976fc28c 100644 --- a/pybamm/solvers/idaklu_solver.py +++ b/pybamm/solvers/idaklu_solver.py @@ -1,6 +1,7 @@ # # Solver class using sundials with the KLU sparse linear solver # +# mypy: ignore-errors import casadi import pybamm import numpy as np diff --git a/pybamm/solvers/jax_bdf_solver.py b/pybamm/solvers/jax_bdf_solver.py index df6714e52a..2c7bdc6d17 100644 --- a/pybamm/solvers/jax_bdf_solver.py +++ b/pybamm/solvers/jax_bdf_solver.py @@ -1,3 +1,4 @@ +# mypy: ignore-errors import collections import operator as op from functools import partial diff --git a/pybamm/solvers/scikits_dae_solver.py b/pybamm/solvers/scikits_dae_solver.py index a5bf1e5a4f..c942e8ccd7 100644 --- a/pybamm/solvers/scikits_dae_solver.py +++ b/pybamm/solvers/scikits_dae_solver.py @@ -1,6 +1,7 @@ # # Solver class using Scipy's adaptive time stepper # +# mypy: ignore-errors import casadi import pybamm diff --git a/pybamm/solvers/scikits_ode_solver.py b/pybamm/solvers/scikits_ode_solver.py index 9f5ee67604..f3a4232da9 100644 --- a/pybamm/solvers/scikits_ode_solver.py +++ b/pybamm/solvers/scikits_ode_solver.py @@ -1,6 +1,7 @@ # # Solver class using Scipy's adaptive time stepper # +# mypy: ignore-errors import casadi import pybamm diff --git a/pybamm/type_definitions.py b/pybamm/type_definitions.py new file mode 100644 index 0000000000..c3d0a56faa --- /dev/null +++ b/pybamm/type_definitions.py @@ -0,0 +1,20 @@ +# +# Common type definitions for PyBaMM +# +from __future__ import annotations + +from typing import Union, List, Dict +from typing_extensions import TypeAlias +import numpy as np +import pybamm + +# numbers.Number should not be used for type hints +Numeric: TypeAlias = Union[int, float, np.number] + +# expression tree +ChildValue: TypeAlias = Union[float, np.ndarray] +ChildSymbol: TypeAlias = Union[float, np.ndarray, pybamm.Symbol] + +DomainType: TypeAlias = Union[List[str], str, None] +AuxiliaryDomainType: TypeAlias = Union[Dict[str, str], None] +DomainsType: TypeAlias = Union[Dict[str, Union[List[str], str]], None] diff --git a/pyproject.toml b/pyproject.toml index f067e1d1ab..13ea6df786 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -268,3 +268,15 @@ concurrency = ["multiprocessing"] ignore = [ "PP003" # list wheel as a build-dep ] + +[tool.mypy] +ignore_missing_imports = true +allow_redefinition = true +disable_error_code = ["call-overload", "operator"] + +[[tool.mypy.overrides]] +module = [ + "pybamm.models.base_model.*", + "pybamm.models.full_battery_models.base_battery_model.*" + ] +disable_error_code = "attr-defined"