From 0f28844e76897a56410e2154b7620d39c3790bef Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 19 May 2023 09:57:35 +0000 Subject: [PATCH 01/32] (wip) First batch of typing --- pybamm/expression_tree/array.py | 30 ++-- pybamm/expression_tree/averages.py | 39 +++-- pybamm/expression_tree/binary_operators.py | 157 ++++++++++++++---- pybamm/expression_tree/broadcasts.py | 73 ++++++-- pybamm/expression_tree/concatenations.py | 39 ++++- pybamm/expression_tree/functions.py | 24 ++- .../expression_tree/independent_variable.py | 39 ++++- pybamm/expression_tree/input_parameter.py | 22 ++- pybamm/expression_tree/interpolant.py | 15 +- pybamm/expression_tree/matrix.py | 15 +- .../operations/convert_to_casadi.py | 9 +- .../operations/evaluate_python.py | 28 +++- pybamm/expression_tree/operations/jacobian.py | 11 +- pybamm/expression_tree/operations/latexify.py | 6 +- .../operations/unpack_symbols.py | 13 +- pybamm/expression_tree/parameter.py | 36 ++-- pybamm/expression_tree/scalar.py | 16 +- pybamm/expression_tree/state_vector.py | 66 +++++--- pybamm/expression_tree/symbol.py | 140 ++++++++++------ pybamm/expression_tree/unary_operators.py | 35 ++-- pybamm/expression_tree/variable.py | 27 +-- pybamm/expression_tree/vector.py | 16 +- pybamm/models/event.py | 3 +- 23 files changed, 589 insertions(+), 270 deletions(-) diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index a9141041b3..73d04e19aa 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -1,9 +1,11 @@ # # NumpyArray class # +from __future__ import annotations import numpy as np import sympy from scipy.sparse import csr_matrix, issparse +from typing import Union import pybamm @@ -36,13 +38,13 @@ class Array(pybamm.Symbol): def __init__( self, - entries, - name=None, - domain=None, - auxiliary_domains=None, - domains=None, - entries_string=None, - ): + entries: Union[np.array, list], + name: str = None, + domain: list[str] = None, + auxiliary_domains: dict[str, str] = None, + domains: dict = None, + entries_string: str = None, + ) -> None: # if if isinstance(entries, list): entries = np.array(entries) @@ -100,7 +102,7 @@ def set_id(self): (self.__class__, self.name) + self.entries_string + tuple(self.domain) ) - def _jac(self, variable): + def _jac(self, variable) -> pybamm.Matrix: """See :meth:`pybamm.Symbol._jac()`.""" # Return zeros of correct size jac = csr_matrix((self.size, variable.evaluation_array.count(True))) @@ -115,7 +117,13 @@ def create_copy(self): entries_string=self.entries_string, ) - def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def _base_evaluate( + self, + t: float = None, + y: np.array = None, + y_dot: np.array = None, + inputs: dict = None, + ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" return self._entries @@ -123,13 +131,13 @@ def is_constant(self): """See :meth:`pybamm.Symbol.is_constant()`.""" return True - def to_equation(self): + def to_equation(self) -> sympy.Array: """Returns the value returned by the node when evaluated.""" entries_list = self.entries.tolist() return sympy.Array(entries_list) -def linspace(start, stop, num=50, **kwargs): +def linspace(start: float, stop: float, num=50, **kwargs) -> pybamm.Array: """ Creates a linearly spaced array by calling `numpy.linspace` with keyword arguments 'kwargs'. For a list of 'kwargs' see the diff --git a/pybamm/expression_tree/averages.py b/pybamm/expression_tree/averages.py index 6ada30d47a..5bebc76592 100644 --- a/pybamm/expression_tree/averages.py +++ b/pybamm/expression_tree/averages.py @@ -1,6 +1,7 @@ # # Classes and methods for averaging # +from typing import Union import pybamm @@ -14,13 +15,13 @@ class _BaseAverage(pybamm.Integral): The child node """ - def __init__(self, child, name, integration_variable): + def __init__(self, child: pybamm.Symbol, name: str, integration_variable): super().__init__(child, integration_variable) self.name = name class XAverage(_BaseAverage): - def __init__(self, child): + def __init__(self, child: pybamm.Symbol): if all(n in child.domain[0] for n in ["negative", "particle"]): x = pybamm.standard_spatial_vars.x_n elif all(n in child.domain[0] for n in ["positive", "particle"]): @@ -30,56 +31,56 @@ def __init__(self, child): integration_variable = x super().__init__(child, "x-average", integration_variable) - def _unary_new_copy(self, child): + def _unary_new_copy(self, child: pybamm.Symbol): """See :meth:`UnaryOperator._unary_new_copy()`.""" return x_average(child) class YZAverage(_BaseAverage): - def __init__(self, child): + def __init__(self, child: pybamm.Symbol): y = pybamm.standard_spatial_vars.y z = pybamm.standard_spatial_vars.z integration_variable = [y, z] super().__init__(child, "yz-average", integration_variable) - def _unary_new_copy(self, child): + def _unary_new_copy(self, child: pybamm.Symbol): """See :meth:`UnaryOperator._unary_new_copy()`.""" return yz_average(child) class ZAverage(_BaseAverage): - def __init__(self, child): + def __init__(self, child: pybamm.Symbol): integration_variable = [pybamm.standard_spatial_vars.z] super().__init__(child, "z-average", integration_variable) - def _unary_new_copy(self, child): + def _unary_new_copy(self, child: pybamm.Symbol): """See :meth:`UnaryOperator._unary_new_copy()`.""" return z_average(child) class RAverage(_BaseAverage): - def __init__(self, child): + def __init__(self, child: pybamm.Symbol): integration_variable = [pybamm.SpatialVariable("r", child.domain)] super().__init__(child, "r-average", integration_variable) - def _unary_new_copy(self, child): + def _unary_new_copy(self, child: pybamm.Symbol): """See :meth:`UnaryOperator._unary_new_copy()`.""" return r_average(child) class SizeAverage(_BaseAverage): - def __init__(self, child, f_a_dist): + def __init__(self, child: pybamm.Symbol, f_a_dist): R = pybamm.SpatialVariable("R", domains=child.domains, coord_sys="cartesian") integration_variable = [R] super().__init__(child, "size-average", integration_variable) self.f_a_dist = f_a_dist - def _unary_new_copy(self, child): + def _unary_new_copy(self, child: pybamm.Symbol): """See :meth:`UnaryOperator._unary_new_copy()`.""" return size_average(child, f_a_dist=self.f_a_dist) -def x_average(symbol): +def x_average(symbol: pybamm.Symbol) -> pybamm.Symbol: """ Convenience function for creating an average in the x-direction. @@ -168,7 +169,7 @@ def x_average(symbol): return XAverage(symbol) -def z_average(symbol): +def z_average(symbol: pybamm.Symbol) -> pybamm.Symbol: """ Convenience function for creating an average in the z-direction. @@ -207,7 +208,7 @@ def z_average(symbol): return ZAverage(symbol) -def yz_average(symbol): +def yz_average(symbol: pybamm.Symbol) -> pybamm.Symbol: """ Convenience function for creating an average in the y-z-direction. @@ -243,11 +244,11 @@ def yz_average(symbol): return YZAverage(symbol) -def xyz_average(symbol): +def xyz_average(symbol: pybamm.Symbol) -> pybamm.Symbol: return yz_average(x_average(symbol)) -def r_average(symbol): +def r_average(symbol: pybamm.Symbol) -> pybamm.Symbol: """ Convenience function for creating an average in the r-direction. @@ -290,7 +291,7 @@ def r_average(symbol): return RAverage(symbol) -def size_average(symbol, f_a_dist=None): +def size_average(symbol: pybamm.Symbol, f_a_dist=None) -> pybamm.Symbol: """Convenience function for averaging over particle size R using the area-weighted particle-size distribution. @@ -343,7 +344,9 @@ def size_average(symbol, f_a_dist=None): return SizeAverage(symbol, f_a_dist) -def _sum_of_averages(symbol, average_function): +def _sum_of_averages( + symbol: Union[pybamm.Addition, pybamm.Subtraction], average_function +): if isinstance(symbol, pybamm.Addition): return average_function(symbol.left) + average_function(symbol.right) elif isinstance(symbol, pybamm.Subtraction): diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index 749384e9bc..5391d31c88 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -1,17 +1,22 @@ # # Binary operator classes # +from __future__ import annotations import numbers import numpy as np import sympy from scipy.sparse import csr_matrix, issparse import functools +from typing import Union, Tuple import pybamm -def _preprocess_binary(left, right): +def _preprocess_binary( + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], +) -> Tuple[pybamm.PrimaryBroadcast, pybamm.PrimaryBroadcast]: if isinstance(left, numbers.Number): left = pybamm.Scalar(left) if isinstance(right, numbers.Number): @@ -60,7 +65,12 @@ class BinaryOperator(pybamm.Symbol): rhs child node (converted to :class:`Scalar` if Number) """ - def __init__(self, name, left, right): + def __init__( + self, + name: str, + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], + ): left, right = _preprocess_binary(left, right) domains = self.get_children_domains([left, right]) @@ -101,7 +111,11 @@ def create_copy(self): return out - def _binary_new_copy(self, left, right): + def _binary_new_copy( + self, + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], + ): """ Default behaviour for new_copy. This copies the behaviour of `_binary_evaluate`, but since `left` and `right` @@ -109,7 +123,13 @@ def _binary_new_copy(self, left, right): """ return self._binary_evaluate(left, right) - def evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def evaluate( + self, + t: float = None, + y: np.array = None, + y_dot: np.array = None, + inputs: dict = None, + ): """See :meth:`pybamm.Symbol.evaluate()`.""" left = self.left.evaluate(t, y, y_dot, inputs) right = self.right.evaluate(t, y, y_dot, inputs) @@ -131,7 +151,7 @@ def _binary_evaluate(self, left, right): f"{self.__class__} does not implement _binary_evaluate." ) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return self.left.evaluates_on_edges(dimension) or self.right.evaluates_on_edges( dimension @@ -165,7 +185,7 @@ def __init__(self, left, right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("**", left, right) - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply chain rule and power rule base, exponent = self.orphans @@ -206,7 +226,7 @@ def __init__(self, left, right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("+", left, right) - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" return self.left.diff(variable) + self.right.diff(variable) @@ -229,7 +249,7 @@ def __init__(self, left, right): super().__init__("-", left, right) - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" return self.left.diff(variable) - self.right.diff(variable) @@ -254,7 +274,7 @@ def __init__(self, left, right): super().__init__("*", left, right) - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply product rule left, right = self.orphans @@ -337,7 +357,7 @@ def __init__(self, left, right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("/", left, right) - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply quotient rule top, bottom = self.orphans @@ -381,7 +401,7 @@ def __init__(self, left, right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("inner product", left, right) - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply product rule left, right = self.orphans @@ -409,11 +429,15 @@ def _binary_evaluate(self, left, right): else: return left * right - def _binary_new_copy(self, left, right): + def _binary_new_copy( + self, + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], + ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.inner(left, right) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False @@ -470,7 +494,11 @@ def _binary_evaluate(self, left, right): else: return int(left == right) - def _binary_new_copy(self, left, right): + def _binary_new_copy( + self, + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], + ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.Equality(left, right) @@ -561,7 +589,7 @@ class Modulo(BinaryOperator): def __init__(self, left, right): super().__init__("%", left, right) - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply chain rule and power rule left, right = self.orphans @@ -603,7 +631,7 @@ def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return "minimum({!s}, {!s})".format(self.left, self.right) - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" left, right = self.orphans return (left <= right) * left.diff(variable) + (left > right) * right.diff( @@ -620,7 +648,11 @@ def _binary_evaluate(self, left, right): # don't raise RuntimeWarning for NaNs return np.minimum(left, right) - def _binary_new_copy(self, left, right): + def _binary_new_copy( + self, + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], + ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.minimum(left, right) @@ -639,7 +671,7 @@ def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return "maximum({!s}, {!s})".format(self.left, self.right) - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" left, right = self.orphans return (left >= right) * left.diff(variable) + (left < right) * right.diff( @@ -656,7 +688,11 @@ def _binary_evaluate(self, left, right): # don't raise RuntimeWarning for NaNs return np.maximum(left, right) - def _binary_new_copy(self, left, right): + def _binary_new_copy( + self, + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], + ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.maximum(left, right) @@ -665,10 +701,13 @@ def _sympy_operator(self, left, right): return sympy.Max(left, right) -def _simplify_elementwise_binary_broadcasts(left, right): +def _simplify_elementwise_binary_broadcasts( + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], +): left, right = _preprocess_binary(left, right) - def unpack_broadcast_recursive(symbol): + def unpack_broadcast_recursive(symbol: pybamm.Symbol): if isinstance(symbol, pybamm.Broadcast): if symbol.child.domain == []: return symbol.orphans[0] @@ -693,7 +732,11 @@ def unpack_broadcast_recursive(symbol): return left, right -def _simplified_binary_broadcast_concatenation(left, right, operator): +def _simplified_binary_broadcast_concatenation( + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], + operator, +): """ Check if there are concatenations or broadcasts that we can commute the operator with @@ -733,7 +776,10 @@ def _simplified_binary_broadcast_concatenation(left, right, operator): ) -def simplified_power(left, right): +def simplified_power( + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], +): left, right = _simplify_elementwise_binary_broadcasts(left, right) # Check for Concatenations and Broadcasts @@ -775,7 +821,10 @@ def simplified_power(left, right): return pybamm.simplify_if_constant(pybamm.Power(left, right)) -def add(left, right): +def add( + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], +): """ Note ---- @@ -863,7 +912,10 @@ def add(left, right): return pybamm.simplify_if_constant(Addition(left, right)) -def subtract(left, right): +def subtract( + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], +): """ Note ---- @@ -945,7 +997,10 @@ def subtract(left, right): return pybamm.simplify_if_constant(Subtraction(left, right)) -def multiply(left, right): +def multiply( + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], +): left, right = _simplify_elementwise_binary_broadcasts(left, right) # Move constant to always be on the left @@ -1070,7 +1125,10 @@ def multiply(left, right): return Multiplication(left, right) -def divide(left, right): +def divide( + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], +): left, right = _simplify_elementwise_binary_broadcasts(left, right) # anything divided by zero raises error @@ -1141,7 +1199,10 @@ def divide(left, right): return pybamm.simplify_if_constant(Division(left, right)) -def matmul(left, right): +def matmul( + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], +): left, right = _preprocess_binary(left, right) if pybamm.is_matrix_zero(left) or pybamm.is_matrix_zero(right): return pybamm.zeros_like(MatrixMultiplication(left, right)) @@ -1200,7 +1261,10 @@ def matmul(left, right): return pybamm.simplify_if_constant(MatrixMultiplication(left, right)) -def minimum(left, right): +def minimum( + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], +): """ Returns the smaller of two objects, possibly with a smoothing approximation. Not to be confused with :meth:`pybamm.min`, which returns min function of child. @@ -1221,7 +1285,10 @@ def minimum(left, right): return pybamm.simplify_if_constant(out) -def maximum(left, right): +def maximum( + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], +): """ Returns the larger of two objects, possibly with a smoothing approximation. Not to be confused with :meth:`pybamm.max`, which returns max function of child. @@ -1242,7 +1309,11 @@ def maximum(left, right): return pybamm.simplify_if_constant(out) -def _heaviside(left, right, equal): +def _heaviside( + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], + equal, +): """return a :class:`EqualHeaviside` object, or a smooth approximation.""" # Check for Concatenations and Broadcasts left, right = _simplify_elementwise_binary_broadcasts(left, right) @@ -1281,7 +1352,11 @@ def _heaviside(left, right, equal): return pybamm.simplify_if_constant(out) -def softminus(left, right, k): +def softminus( + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], + k: float, +): """ Softplus approximation to the minimum function. k is the smoothing parameter, set by `pybamm.settings.min_smoothing`. The recommended value is k=10. @@ -1289,7 +1364,11 @@ def softminus(left, right, k): return pybamm.log(pybamm.exp(-k * left) + pybamm.exp(-k * right)) / -k -def softplus(left, right, k): +def softplus( + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], + k: float, +): """ Softplus approximation to the maximum function. k is the smoothing parameter, set by `pybamm.settings.max_smoothing`. The recommended value is k=10. @@ -1297,7 +1376,11 @@ def softplus(left, right, k): return pybamm.log(pybamm.exp(k * left) + pybamm.exp(k * right)) / k -def sigmoid(left, right, k): +def sigmoid( + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], + k: float, +): """ Sigmoidal approximation to the heaviside function. k is the smoothing parameter, set by `pybamm.settings.heaviside_smoothing`. The recommended value is k=10. @@ -1307,7 +1390,11 @@ def sigmoid(left, right, k): return (1 + pybamm.tanh(k * (right - left))) / 2 -def source(left, right, boundary=False): +def source( + left: Union[numbers.Number, pybamm.Symbol], + right: Union[numbers.Number, pybamm.Symbol], + boundary=False, +): """ A convenience function for creating (part of) an expression tree representing a source term. This is necessary for spatial methods where the mass matrix diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index 32cf2c002b..da0a05187b 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -5,6 +5,7 @@ import numpy as np from scipy.sparse import csr_matrix +from typing import Iterable, Optional import pybamm @@ -29,7 +30,9 @@ class Broadcast(pybamm.SpatialOperator): name of the node """ - def __init__(self, child, domains, name=None): + def __init__( + self, child: pybamm.Symbol, domains: Iterable[str], name: Optional[str] = None + ): if name is None: name = "broadcast" super().__init__(name, child, domains=domains) @@ -41,7 +44,7 @@ def broadcasts_to_nodes(self): else: return False - def _sympy_operator(self, child): + def _sympy_operator(self, child: pybamm.Symbol): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" return child @@ -70,7 +73,12 @@ class PrimaryBroadcast(Broadcast): name of the node """ - def __init__(self, child, broadcast_domain, name=None): + def __init__( + self, + child: pybamm.Symbol, + broadcast_domain: Iterable[str], + name: Optional[str] = None, + ): # Convert child to scalar if it is a number if isinstance(child, numbers.Number): child = pybamm.Scalar(child) @@ -83,7 +91,9 @@ def __init__(self, child, broadcast_domain, name=None): self.broadcast_type = "primary to nodes" super().__init__(child, domains, name=name) - def check_and_set_domains(self, child, broadcast_domain): + def check_and_set_domains( + self, child: pybamm.Symbol, broadcast_domain: Iterable[str] + ): """See :meth:`Broadcast.check_and_set_domains`""" # Can only do primary broadcast from current collector to electrode, # particle-size or particle or from electrode to particle-size or particle. @@ -138,7 +148,7 @@ def check_and_set_domains(self, child, broadcast_domain): return domains - def _unary_new_copy(self, child): + def _unary_new_copy(self, child: pybamm.Symbol): """See :meth:`pybamm.UnaryOperator._unary_new_copy()`.""" return self.__class__(child, self.broadcast_domain) @@ -159,7 +169,12 @@ def reduce_one_dimension(self): class PrimaryBroadcastToEdges(PrimaryBroadcast): """A primary broadcast onto the edges of the domain.""" - def __init__(self, child, broadcast_domain, name=None): + def __init__( + self, + child: pybamm.Symbol, + broadcast_domain: Iterable[str], + name: Optional[str] = None, + ): name = name or "broadcast to edges" super().__init__(child, broadcast_domain, name) self.broadcast_type = "primary to edges" @@ -190,7 +205,12 @@ class SecondaryBroadcast(Broadcast): name of the node """ - def __init__(self, child, broadcast_domain, name=None): + def __init__( + self, + child: pybamm.Symbol, + broadcast_domain: Iterable[str], + name: Optional[str] = None, + ): # Convert domain to list if it's a string if isinstance(broadcast_domain, str): broadcast_domain = [broadcast_domain] @@ -200,7 +220,9 @@ def __init__(self, child, broadcast_domain, name=None): self.broadcast_type = "secondary to nodes" super().__init__(child, domains, name=name) - def check_and_set_domains(self, child, broadcast_domain): + def check_and_set_domains( + self, child: pybamm.Symbol, broadcast_domain: Iterable[str] + ): """See :meth:`Broadcast.check_and_set_domains`""" if child.domain == []: raise TypeError( @@ -262,7 +284,7 @@ def check_and_set_domains(self, child, broadcast_domain): return domains - def _unary_new_copy(self, child): + def _unary_new_copy(self, child: pybamm.Symbol): """See :meth:`pybamm.UnaryOperator._unary_new_copy()`.""" return SecondaryBroadcast(child, self.broadcast_domain) @@ -283,7 +305,12 @@ def reduce_one_dimension(self): class SecondaryBroadcastToEdges(SecondaryBroadcast): """A secondary broadcast onto the edges of a domain.""" - def __init__(self, child, broadcast_domain, name=None): + def __init__( + self, + child: pybamm.Symbol, + broadcast_domain: Iterable[str], + name: Optional[str] = None, + ): name = name or "broadcast to edges" super().__init__(child, broadcast_domain, name) self.broadcast_type = "secondary to edges" @@ -314,7 +341,12 @@ class TertiaryBroadcast(Broadcast): name of the node """ - def __init__(self, child, broadcast_domain, name=None): + def __init__( + self, + child: pybamm.Symbol, + broadcast_domain: Iterable[str], + name: Optional[str] = None, + ): # Convert domain to list if it's a string if isinstance(broadcast_domain, str): broadcast_domain = [broadcast_domain] @@ -324,7 +356,9 @@ def __init__(self, child, broadcast_domain, name=None): self.broadcast_type = "tertiary to nodes" super().__init__(child, domains, name=name) - def check_and_set_domains(self, child, broadcast_domain): + def check_and_set_domains( + self, child: pybamm.Symbol, broadcast_domain: Iterable[str] + ): """See :meth:`Broadcast.check_and_set_domains`""" if child.domains["secondary"] == []: raise TypeError( @@ -371,7 +405,7 @@ def check_and_set_domains(self, child, broadcast_domain): return domains - def _unary_new_copy(self, child): + def _unary_new_copy(self, child: pybamm.Symbol): """See :meth:`pybamm.UnaryOperator._unary_new_copy()`.""" return self.__class__(child, self.broadcast_domain) @@ -392,7 +426,12 @@ def reduce_one_dimension(self): class TertiaryBroadcastToEdges(TertiaryBroadcast): """A tertiary broadcast onto the edges of a domain.""" - def __init__(self, child, broadcast_domain, name=None): + def __init__( + self, + child: pybamm.Symbol, + broadcast_domain: Iterable[str], + name: Optional[str] = None, + ): name = name or "broadcast to edges" super().__init__(child, broadcast_domain, name) self.broadcast_type = "tertiary to edges" @@ -509,7 +548,7 @@ def reduce_one_dimension(self): ) -def full_like(symbols, fill_value): +def full_like(symbols: pybamm.Symbol, fill_value: float): """ Returns an array with the same shape and domains as the sum of the input symbols, with a constant value given by `fill_value`. @@ -558,7 +597,7 @@ def full_like(symbols, fill_value): return FullBroadcast(fill_value, broadcast_domains=sum_symbol.domains) -def zeros_like(*symbols): +def zeros_like(*symbols: pybamm.Symbol): """ Returns an array with the same shape and domains as the sum of the input symbols, with each entry equal to zero. @@ -571,7 +610,7 @@ def zeros_like(*symbols): return full_like(symbols, 0) -def ones_like(*symbols): +def ones_like(*symbols: pybamm.Symbol): """ Returns an array with the same shape and domains as the sum of the input symbols, with each entry equal to one. diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 2185a0fad6..1f8a90ff52 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -7,6 +7,10 @@ import numpy as np import sympy from scipy.sparse import issparse, vstack +from typing import Optional, Iterable, TYPE_CHECKING + +if TYPE_CHECKING: + from pybamm import Concatenation, DomainConcatenation import pybamm @@ -21,7 +25,13 @@ class Concatenation(pybamm.Symbol): The symbols to concatenate """ - def __init__(self, *children, name=None, check_domain=True, concat_fun=None): + def __init__( + self, + *children: Iterable[pybamm.Symbol], + name=None, + check_domain=True, + concat_fun=None + ): # The second condition checks whether this is the base Concatenation class # or a subclass of Concatenation # (ConcatenationVariable, NumpyConcatenation, ...) @@ -51,7 +61,7 @@ def __str__(self): out = out[:-2] + ")" return out - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" children_diffs = [child.diff(variable) for child in self.children] if len(children_diffs) == 1: @@ -61,7 +71,7 @@ def _diff(self, variable): return diff - def get_children_domains(self, children): + def get_children_domains(self, children: Iterable[pybamm.Symbol]): # combine domains from children domain = [] for child in children: @@ -97,7 +107,13 @@ def _concatenation_evaluate(self, children_eval): else: return self.concatenation_function(children_eval) - def evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def evaluate( + self, + t: float = None, + y: np.array = None, + y_dot: np.array = None, + inputs: dict = None, + ): """See :meth:`pybamm.Symbol.evaluate()`.""" children = self.children children_eval = [None] * len(children) @@ -168,7 +184,7 @@ class NumpyConcatenation(Concatenation): The equations to concatenate """ - def __init__(self, *children): + def __init__(self, *children: Iterable[pybamm.Symbol]): children = list(children) # Turn objects that evaluate to scalars to objects that evaluate to vectors, # so that we can concatenate them @@ -218,7 +234,12 @@ class DomainConcatenation(Concatenation): from `copy_this`. `mesh` is not used in this case """ - def __init__(self, children, full_mesh, copy_this=None): + def __init__( + self, + children: Iterable[pybamm.Symbol], + full_mesh, # pybamm.BaseMesh + copy_this=None, #: Optional[pybamm.DomainConcatenation] + ): # Convert any constant symbols in children to a Vector of the right size for # concatenation children = list(children) @@ -330,7 +351,7 @@ class SparseStack(Concatenation): The equations to concatenate """ - def __init__(self, *children): + def __init__(self, *children): #: Iterable[pybamm.Concatenation] children = list(children) if not any(issparse(child.evaluate_for_shape()) for child in children): concatenation_function = np.vstack @@ -389,13 +410,13 @@ def __init__(self, *children): self.print_name = print_name -def substrings(s): +def substrings(s: str): for i in range(len(s)): for j in range(i, len(s)): yield s[i : j + 1] -def intersect(s1, s2): +def intersect(s1: str, s2: str): # find all the common strings between two strings all_intersects = set(substrings(s1)) & set(substrings(s2)) # intersect is the longest such intercept diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 80c2848ad9..beb8e5fcc9 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -1,12 +1,14 @@ # # Function classes and methods # +from __future__ import annotations import numbers import autograd import numpy as np import sympy from scipy import special +from typing import Optional import pybamm @@ -33,8 +35,8 @@ class Function(pybamm.Symbol): def __init__( self, function, - *children, - name=None, + *children: pybamm.Symbol, + name: Optional[str] = None, derivative="autograd", differentiated_function=None, ): @@ -67,7 +69,7 @@ def __str__(self): out = out[:-2] + ")" return out - def diff(self, variable): + def diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol.diff()`.""" if variable == self: return pybamm.Scalar(1) @@ -91,7 +93,7 @@ def diff(self, variable): return derivative - def _function_diff(self, children, idx): + def _function_diff(self, children: pybamm.Symbol, idx: float): """ Derivative with respect to child number 'idx'. See :meth:`pybamm.Symbol._diff()`. @@ -141,14 +143,20 @@ def _function_jac(self, children_jacs): return jacobian - def evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def evaluate( + self, + t: float = None, + y: np.array = None, + y_dot: np.array = None, + inputs: dict = None, + ): """See :meth:`pybamm.Symbol.evaluate()`.""" evaluated_children = [ child.evaluate(t, y, y_dot, inputs) for child in self.children ] return self._function_evaluate(evaluated_children) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return any(child.evaluates_on_edges(dimension) for child in self.children) @@ -172,7 +180,7 @@ def create_copy(self): children_copy = [child.new_copy() for child in self.children] return self._function_new_copy(children_copy) - def _function_new_copy(self, children): + def _function_new_copy(self, children: list) -> Function: """ Returns a new copy of the function. @@ -241,7 +249,7 @@ class SpecificFunction(Function): The child to apply the function to """ - def __init__(self, function, child): + def __init__(self, function, child: pybamm.Symbol): super().__init__(function, child) def _function_new_copy(self, children): diff --git a/pybamm/expression_tree/independent_variable.py b/pybamm/expression_tree/independent_variable.py index efeb73f8bc..20c05b84c4 100644 --- a/pybamm/expression_tree/independent_variable.py +++ b/pybamm/expression_tree/independent_variable.py @@ -2,6 +2,7 @@ # IndependentVariable class # import sympy +import numpy as np import pybamm @@ -29,7 +30,13 @@ class IndependentVariable(pybamm.Symbol): deprecated. """ - def __init__(self, name, domain=None, auxiliary_domains=None, domains=None): + def __init__( + self, + name: str, + domain: list[str] = None, + auxiliary_domains: dict = None, + domains: dict = None, + ) -> None: super().__init__( name, domain=domain, auxiliary_domains=auxiliary_domains, domains=domains ) @@ -38,11 +45,11 @@ def _evaluate_for_shape(self): """See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`""" return pybamm.evaluate_for_shape_using_domain(self.domains) - def _jac(self, variable): + def _jac(self, variable) -> pybamm.Scalar: """See :meth:`pybamm.Symbol._jac()`.""" return pybamm.Scalar(0) - def to_equation(self): + def to_equation(self) -> sympy.Symbol: """Convert the node and its subtree into a SymPy equation.""" if self.print_name is not None: return sympy.Symbol(self.print_name) @@ -62,7 +69,13 @@ def create_copy(self): """See :meth:`pybamm.Symbol.new_copy()`.""" return Time() - def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def _base_evaluate( + self, + t: float = None, + y: np.array = None, + y_dot: np.array = None, + inputs: dict = None, + ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" if t is None: raise ValueError("t must be provided") @@ -101,8 +114,13 @@ class SpatialVariable(IndependentVariable): """ def __init__( - self, name, domain=None, auxiliary_domains=None, domains=None, coord_sys=None - ): + self, + name: str, + domain: list[str] = None, + auxiliary_domains: dict = None, + domains: dict = None, + coord_sys=None, + ) -> None: self.coord_sys = coord_sys super().__init__( name, domain=domain, auxiliary_domains=auxiliary_domains, domains=domains @@ -159,8 +177,13 @@ class SpatialVariableEdge(SpatialVariable): """ def __init__( - self, name, domain=None, auxiliary_domains=None, domains=None, coord_sys=None - ): + self, + name: str, + domain: list[str] = None, + auxiliary_domains: dict = None, + domains: dict = None, + coord_sys=None, + ) -> None: super().__init__(name, domain, auxiliary_domains, domains, coord_sys) def _evaluates_on_edges(self, dimension): diff --git a/pybamm/expression_tree/input_parameter.py b/pybamm/expression_tree/input_parameter.py index 62c08bf0fd..f0c81fee68 100644 --- a/pybamm/expression_tree/input_parameter.py +++ b/pybamm/expression_tree/input_parameter.py @@ -1,11 +1,14 @@ # # Parameter classes # +from __future__ import annotations import numbers import numpy as np import scipy.sparse import pybamm +from typing import Union, Iterable + class InputParameter(pybamm.Symbol): """ @@ -25,7 +28,12 @@ class InputParameter(pybamm.Symbol): The size of the input parameter expected, defaults to 1 (scalar input) """ - def __init__(self, name, domain=None, expected_size=None): + def __init__( + self, + name: str, + domain: Union[Iterable[str], str] = None, + expected_size: int = None, + ) -> None: # Expected size defaults to 1 if no domain else None (gets set later) if expected_size is None: if domain is None: @@ -35,7 +43,7 @@ def __init__(self, name, domain=None, expected_size=None): self._expected_size = expected_size super().__init__(name, domain=domain) - def create_copy(self): + def create_copy(self) -> pybamm.InputParameter: """See :meth:`pybamm.Symbol.new_copy()`.""" new_input_parameter = InputParameter( self.name, self.domain, expected_size=self._expected_size @@ -54,7 +62,7 @@ def _evaluate_for_shape(self): else: return np.nan * np.ones((self._expected_size, 1)) - def _jac(self, variable): + def _jac(self, variable: pybamm.Variable) -> pybamm.Matrix: """See :meth:`pybamm.Symbol._jac()`.""" n_variable = variable.evaluation_array.count(True) nan_vector = self._evaluate_for_shape() @@ -65,7 +73,13 @@ def _jac(self, variable): zero_matrix = scipy.sparse.csr_matrix((n_self, n_variable)) return pybamm.Matrix(zero_matrix) - def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def _base_evaluate( + self, + t: float = None, + y: np.array = None, + y_dot: np.array = None, + inputs: dict = None, + ): # inputs should be a dictionary # convert 'None' to empty dictionary for more informative error if inputs is None: diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index cd0df4d077..8d4eee1bde 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -3,6 +3,7 @@ # import numpy as np from scipy import interpolate +from typing import Iterable, Optional import warnings import pybamm @@ -41,13 +42,13 @@ class Interpolant(pybamm.Function): def __init__( self, - x, - y, - children, - name=None, - interpolator="linear", - extrapolate=True, - entries_string=None, + x: np.ndarray, + y: np.ndarray, + children: Iterable[pybamm.Symbol], + name: Optional[str] = None, + interpolator: Optional[str] = "linear", + extrapolate: Optional[bool] = True, + entries_string: Optional[str] = None, ): # "cubic spline" has been renamed to "cubic" if interpolator == "cubic spline": diff --git a/pybamm/expression_tree/matrix.py b/pybamm/expression_tree/matrix.py index d491fd129d..0c23e059e3 100644 --- a/pybamm/expression_tree/matrix.py +++ b/pybamm/expression_tree/matrix.py @@ -3,6 +3,7 @@ # import numpy as np from scipy.sparse import csr_matrix, issparse +from typing import Union import pybamm @@ -14,13 +15,13 @@ class Matrix(pybamm.Array): def __init__( self, - entries, - name=None, - domain=None, - auxiliary_domains=None, - domains=None, - entries_string=None, - ): + entries: Union[np.ndarray, list], + name: str = None, + domain: list[str] = None, + auxiliary_domains: dict[str, str] = None, + domains: dict = None, + entries_string: str = None, + ) -> None: if isinstance(entries, list): entries = np.array(entries) if name is None: diff --git a/pybamm/expression_tree/operations/convert_to_casadi.py b/pybamm/expression_tree/operations/convert_to_casadi.py index b3a048b1f1..2aa1bfa720 100644 --- a/pybamm/expression_tree/operations/convert_to_casadi.py +++ b/pybamm/expression_tree/operations/convert_to_casadi.py @@ -13,7 +13,14 @@ def __init__(self, casadi_symbols=None): pybamm.citations.register("Andersson2019") - def convert(self, symbol, t, y, y_dot, inputs): + def convert( + self, + symbol: pybamm.Symbol, + t: casadi.MX, + y: casadi.MX, + y_dot: casadi.MX, + inputs: dict[casadi.MX], + ) -> casadi.MX: """ This function recurses down the tree, converting the PyBaMM expression tree to a CasADi expression tree diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index dc80961a77..7560f950fe 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -3,6 +3,8 @@ # import numbers from collections import OrderedDict +from numpy.typing import ArrayLike +from typing import Tuple import numpy as np import scipy.sparse @@ -14,6 +16,7 @@ from jax.config import config config.update("jax_enable_x64", True) + # jax.typing unavailable for supported version class JaxCooMatrix: @@ -37,7 +40,9 @@ class JaxCooMatrix: where x is the number of rows, and y the number of columns of the matrix """ - def __init__(self, row, col, data, shape): + def __init__( + self, row: ArrayLike, col: ArrayLike, data: ArrayLike, shape: tuple[int, int] + ): if not pybamm.have_jax(): # pragma: no cover raise ModuleNotFoundError( "Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/GNU-linux.html#optional-jaxsolver" # noqa: E501 @@ -54,7 +59,7 @@ def toarray(self): result = jax.numpy.zeros(self.shape, dtype=self.data.dtype) return result.at[self.row, self.col].add(self.data) - def dot_product(self, b): + def dot_product(self, b: ArrayLike): """ dot product of matrix with a dense column vector b @@ -67,7 +72,7 @@ def dot_product(self, b): result = jax.numpy.zeros((self.shape[0], 1), dtype=b.dtype) return result.at[self.row].add(self.data.reshape(-1, 1) * b[self.col]) - def scalar_multiply(self, b): + def scalar_multiply(self, b: ArrayLike): """ multiply of matrix with a scalar b @@ -90,7 +95,7 @@ def __matmul__(self, b): return self.dot_product(b) -def create_jax_coo_matrix(value): +def create_jax_coo_matrix(value: scipy.sparse): """ Creates a JaxCooMatrix from a scipy.sparse matrix @@ -130,7 +135,12 @@ def is_scalar(arg): return np.all(np.array(arg.shape) == 1) -def find_symbols(symbol, constant_symbols, variable_symbols, output_jax=False): +def find_symbols( + symbol: pybamm.Symbol, + constant_symbols: OrderedDict, + variable_symbols: OrderedDict, + output_jax=False, +): """ This function converts an expression tree to a dictionary of node id's and strings specifying valid python code to calculate that nodes value, given y and t. @@ -375,7 +385,9 @@ def find_symbols(symbol, constant_symbols, variable_symbols, output_jax=False): variable_symbols[symbol.id] = symbol_str -def to_python(symbol, debug=False, output_jax=False): +def to_python( + symbol: pybamm.Symbol, debug=False, output_jax=False +) -> Tuple[OrderedDict, str, bool]: """ This function converts an expression tree into a dict of constant input values, and valid python code that acts like the tree's :func:`pybamm.Symbol.evaluate` function @@ -441,7 +453,7 @@ class EvaluatorPython: """ - def __init__(self, symbol): + def __init__(self, symbol: pybamm.Symbol): constants, python_str = pybamm.to_python(symbol, debug=False) # extract constants in generated function @@ -533,7 +545,7 @@ class EvaluatorJax: """ - def __init__(self, symbol): + def __init__(self, symbol: pybamm.Symbol): if not pybamm.have_jax(): # pragma: no cover raise ModuleNotFoundError( "Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/GNU-linux.html#optional-jaxsolver" # noqa: E501 diff --git a/pybamm/expression_tree/operations/jacobian.py b/pybamm/expression_tree/operations/jacobian.py index 56511827b0..43d314db6e 100644 --- a/pybamm/expression_tree/operations/jacobian.py +++ b/pybamm/expression_tree/operations/jacobian.py @@ -1,6 +1,7 @@ # # Calculate the Jacobian of a symbol # +from typing import Optional import pybamm @@ -18,11 +19,15 @@ class Jacobian(object): whether or not the Jacobian clears the domain (default True) """ - def __init__(self, known_jacs=None, clear_domain=True): + def __init__( + self, + known_jacs: Optional[dict[str, pybamm.Symbol]] = None, + clear_domain: Optional[bool] = True, + ): self._known_jacs = known_jacs or {} self._clear_domain = clear_domain - def jac(self, symbol, variable): + def jac(self, symbol: pybamm.Symbol, variable: pybamm.Symbol) -> pybamm.Symbol: """ This function recurses down the tree, computing the Jacobian using the Jacobians defined in classes derived from pybamm.Symbol. E.g. the @@ -52,7 +57,7 @@ def jac(self, symbol, variable): self._known_jacs[symbol] = jac return jac - def _jac(self, symbol, variable): + def _jac(self, symbol: pybamm.Symbol, variable: pybamm.Symbol): """See :meth:`Jacobian.jac()`.""" if isinstance(symbol, pybamm.BinaryOperator): diff --git a/pybamm/expression_tree/operations/latexify.py b/pybamm/expression_tree/operations/latexify.py index 67e0199656..51195a0dcd 100644 --- a/pybamm/expression_tree/operations/latexify.py +++ b/pybamm/expression_tree/operations/latexify.py @@ -5,6 +5,8 @@ import re import warnings +from typing import Optional + import sympy import pybamm @@ -49,7 +51,9 @@ class Latexify: >>> model.latexify(newline=False)[1:5] """ - def __init__(self, model, filename=None, newline=True): + def __init__( + self, model, filename: Optional[str] = None, newline: Optional[bool] = True + ): self.model = model self.filename = filename self.newline = newline diff --git a/pybamm/expression_tree/operations/unpack_symbols.py b/pybamm/expression_tree/operations/unpack_symbols.py index 96cbca39fd..9868e55c48 100644 --- a/pybamm/expression_tree/operations/unpack_symbols.py +++ b/pybamm/expression_tree/operations/unpack_symbols.py @@ -1,6 +1,11 @@ # # Helper function to unpack a symbol # +from __future__ import annotations +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + import pybamm class SymbolUnpacker(object): @@ -16,11 +21,13 @@ class SymbolUnpacker(object): cached unpacked equations """ - def __init__(self, classes_to_find, unpacked_symbols=None): + def __init__(self, classes_to_find, unpacked_symbols: Optional[set] = None): self.classes_to_find = classes_to_find self._unpacked_symbols = unpacked_symbols or {} - def unpack_list_of_symbols(self, list_of_symbols): + def unpack_list_of_symbols( + self, list_of_symbols: list[pybamm.Symbol] + ) -> list[pybamm.Symbol]: """ Unpack a list of symbols. See :meth:`SymbolUnpacker.unpack()` @@ -41,7 +48,7 @@ def unpack_list_of_symbols(self, list_of_symbols): return all_instances - def unpack_symbol(self, symbol): + def unpack_symbol(self, symbol: list[pybamm.Symbol]) -> list[pybamm.Symbol]: """ This function recurses down the tree, unpacking the symbols and saving the ones that have a class in `self.classes_to_find`. diff --git a/pybamm/expression_tree/parameter.py b/pybamm/expression_tree/parameter.py index 10addae464..97542b28c2 100644 --- a/pybamm/expression_tree/parameter.py +++ b/pybamm/expression_tree/parameter.py @@ -1,11 +1,16 @@ # # Parameter classes # +from __future__ import annotations import numbers import sys import numpy as np import sympy +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from pybamm import FunctionParameter import pybamm @@ -23,26 +28,26 @@ class Parameter(pybamm.Symbol): name of the node """ - def __init__(self, name): + def __init__(self, name: str) -> None: super().__init__(name) - def create_copy(self): + def create_copy(self) -> pybamm.Parameter: """See :meth:`pybamm.Symbol.new_copy()`.""" return Parameter(self.name) - def _evaluate_for_shape(self): + def _evaluate_for_shape(self) -> np.nan: """ Returns the scalar 'NaN' to represent the shape of a parameter. See :meth:`pybamm.Symbol.evaluate_for_shape()` """ return np.nan - def is_constant(self): + def is_constant(self) -> False: """See :meth:`pybamm.Symbol.is_constant()`.""" # Parameter is not constant since it can become an InputParameter return False - def to_equation(self): + def to_equation(self) -> sympy.Symbol: """Convert the node and its subtree into a SymPy equation.""" if self.print_name is not None: return sympy.Symbol(self.print_name) @@ -79,11 +84,11 @@ class FunctionParameter(pybamm.Symbol): def __init__( self, - name, - inputs, - diff_variable=None, + name: str, + inputs: dict[str, pybamm.Symbol], + diff_variable: Optional[pybamm.Symbol] = None, print_name="calculate", - ): + ) -> None: # assign diff variable self.diff_variable = diff_variable children_list = list(inputs.values()) @@ -129,7 +134,7 @@ def print_input_names(self): print(inp) @input_names.setter - def input_names(self, inp=None): + def input_names(self, inp: dict[str, pybamm.Symbol] = None): if inp: if inp.__class__ is list: for i in inp: @@ -156,7 +161,7 @@ def set_id(self): + tuple(self.domain) ) - def diff(self, variable): + def diff(self, variable: pybamm.Symbol) -> pybamm.FunctionParameter: """See :meth:`pybamm.Symbol.diff()`.""" # return a new FunctionParameter, that knows it will need to be differentiated # when the parameters are set @@ -180,8 +185,11 @@ def create_copy(self): return out def _function_parameter_new_copy( - self, input_names, children, print_name="calculate" - ): + self, + input_names: list[str], + children: list[pybamm.Symbol], + print_name="calculate", + ) -> pybamm.FunctionParameter: """ Returns a new copy of the function parameter. @@ -215,7 +223,7 @@ def _evaluate_for_shape(self): # add 1e-16 to avoid division by zero return sum(child.evaluate_for_shape() for child in self.children) + 1e-16 - def to_equation(self): + def to_equation(self) -> sympy.Symbol: """Convert the node and its subtree into a SymPy equation.""" if self.print_name is not None: return sympy.Symbol(self.print_name) diff --git a/pybamm/expression_tree/scalar.py b/pybamm/expression_tree/scalar.py index 3149bf7bee..cae9e9c3c4 100644 --- a/pybamm/expression_tree/scalar.py +++ b/pybamm/expression_tree/scalar.py @@ -1,8 +1,10 @@ # # Scalar class # +from __future__ import annotations import numpy as np import sympy +from typing import Optional import pybamm @@ -21,7 +23,7 @@ class Scalar(pybamm.Symbol): """ - def __init__(self, value, name=None): + def __init__(self, value: float, name: Optional[str] = None) -> None: # set default name if not provided self.value = value if name is None: @@ -52,11 +54,17 @@ def set_id(self): # indistinguishable by class and name alone self._id = hash((self.__class__, str(self.value))) - def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def _base_evaluate( + self, + t: float = None, + y: np.array = None, + y_dot: np.array = None, + inputs: dict = None, + ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" return self._value - def _jac(self, variable): + def _jac(self, variable: pybamm.Variable) -> pybamm.Scalar: """See :meth:`pybamm.Symbol._jac()`.""" return pybamm.Scalar(0) @@ -64,7 +72,7 @@ def create_copy(self): """See :meth:`pybamm.Symbol.new_copy()`.""" return Scalar(self.value, self.name) - def is_constant(self): + def is_constant(self) -> True: """See :meth:`pybamm.Symbol.is_constant()`.""" return True diff --git a/pybamm/expression_tree/state_vector.py b/pybamm/expression_tree/state_vector.py index 6ef8bee904..fd4cc016d2 100644 --- a/pybamm/expression_tree/state_vector.py +++ b/pybamm/expression_tree/state_vector.py @@ -1,8 +1,10 @@ # # State Vector class # +from __future__ import annotations import numpy as np from scipy.sparse import csr_matrix, vstack +from typing import Optional, Iterable, Union import pybamm @@ -34,13 +36,13 @@ class StateVectorBase(pybamm.Symbol): def __init__( self, - *y_slices, + *y_slices: slice, base_name="y", - name=None, - domain=None, - auxiliary_domains=None, - domains=None, - evaluation_array=None, + name: Optional[str] = None, + domain: Iterable[str] = None, + auxiliary_domains: Optional[dict[str]] = None, + domains: Optional[dict] = None, + evaluation_array: Optional[list] = None, ): for y_slice in y_slices: if not isinstance(y_slice, slice): @@ -111,7 +113,7 @@ def set_id(self): + tuple(self.domain) ) - def _jac_diff_vector(self, variable): + def _jac_diff_vector(self, variable: pybamm.Symbol): """ Differentiate a slice of a StateVector of size m with respect to another slice of a different StateVector of size n. This returns a (sparse) zero matrix of @@ -132,7 +134,7 @@ def _jac_diff_vector(self, variable): # Return zeros of correct size since no entries match return pybamm.Matrix(csr_matrix((slices_size, variable_size))) - def _jac_same_vector(self, variable): + def _jac_same_vector(self, variable: pybamm.Symbol): """ Differentiate a slice of a StateVector of size m with respect to another slice of a StateVector of size n. This returns a (sparse) matrix of size @@ -222,12 +224,12 @@ class StateVector(StateVectorBase): def __init__( self, - *y_slices, - name=None, - domain=None, - auxiliary_domains=None, - domains=None, - evaluation_array=None, + *y_slices: slice, + name: Optional[str] = None, + domain: Iterable[str] = None, + auxiliary_domains: Optional[dict[str]] = None, + domains: Optional[dict] = None, + evaluation_array: Optional[list] = None, ): super().__init__( *y_slices, @@ -239,7 +241,13 @@ def __init__( evaluation_array=evaluation_array, ) - def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def _base_evaluate( + self, + t: float = None, + y: np.array = None, + y_dot: np.array = None, + inputs: dict = None, + ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" if y is None: raise TypeError("StateVector cannot evaluate input 'y=None'") @@ -253,7 +261,7 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): out = out[:, np.newaxis] return out - def diff(self, variable): + def diff(self, variable: pybamm.Symbol): if variable == self: return pybamm.Scalar(1) if variable == pybamm.t: @@ -266,7 +274,7 @@ def diff(self, variable): else: return pybamm.Scalar(0) - def _jac(self, variable): + def _jac(self, variable: Union[pybamm.StateVector, pybamm.StateVectorDot]): if isinstance(variable, pybamm.StateVector): return self._jac_same_vector(variable) elif isinstance(variable, pybamm.StateVectorDot): @@ -300,12 +308,12 @@ class StateVectorDot(StateVectorBase): def __init__( self, - *y_slices, - name=None, - domain=None, - auxiliary_domains=None, - domains=None, - evaluation_array=None, + *y_slices: slice, + name: Optional[str] = None, + domain: Iterable[str] = None, + auxiliary_domains: Optional[dict[str]] = None, + domains: Optional[dict] = None, + evaluation_array: Optional[list] = None, ): super().__init__( *y_slices, @@ -317,7 +325,13 @@ def __init__( evaluation_array=evaluation_array, ) - def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def _base_evaluate( + self, + t: float = None, + y: np.array = None, + y_dot: np.array = None, + inputs: dict = None, + ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" if y_dot is None: raise TypeError("StateVectorDot cannot evaluate input 'y_dot=None'") @@ -331,7 +345,7 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): out = out[:, np.newaxis] return out - def diff(self, variable): + def diff(self, variable: pybamm.Symbol): if variable == self: return pybamm.Scalar(1) elif variable == pybamm.t: @@ -341,7 +355,7 @@ def diff(self, variable): else: return pybamm.Scalar(0) - def _jac(self, variable): + def _jac(self, variable: Union[pybamm.StateVector, pybamm.StateVectorDot]): if isinstance(variable, pybamm.StateVectorDot): return self._jac_same_vector(variable) elif isinstance(variable, pybamm.StateVector): diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 5d28884ed5..39360204c3 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -1,6 +1,7 @@ # # Base Symbol Class for the expression tree # +from __future__ import annotations import numbers import anytree @@ -9,15 +10,25 @@ from anytree.exporter import DotExporter from scipy.sparse import csr_matrix, issparse from functools import lru_cache, cached_property +from typing import Union, TYPE_CHECKING, Optional import pybamm from pybamm.expression_tree.printing.print_name import prettify_print_name +if TYPE_CHECKING: + from pybamm.expression_tree.binary_operators import ( + Addition, + Subtraction, + Multiplication, + Division, + ) + import casadi + DOMAIN_LEVELS = ["primary", "secondary", "tertiary", "quaternary"] EMPTY_DOMAINS = {k: [] for k in DOMAIN_LEVELS} -def domain_size(domain): +def domain_size(domain: Union[list[str], str]): """ Get the domain size. @@ -45,7 +56,7 @@ def domain_size(domain): return size -def create_object_of_size(size, typ="vector"): +def create_object_of_size(size: int, typ="vector"): """Return object, consisting of NaNs, of the right shape.""" if typ == "vector": return np.nan * np.ones((size, 1)) @@ -53,7 +64,7 @@ def create_object_of_size(size, typ="vector"): return np.nan * np.ones((size, size)) -def evaluate_for_shape_using_domain(domains, typ="vector"): +def evaluate_for_shape_using_domain(domains: dict, typ="vector"): """ Return a vector of the appropriate shape, based on the domains. Domain 'sizes' can clash, but are unlikely to, and won't cause failures if they do. @@ -65,11 +76,11 @@ def evaluate_for_shape_using_domain(domains, typ="vector"): return create_object_of_size(_domain_sizes, typ) -def is_constant(symbol): +def is_constant(symbol: Symbol): return isinstance(symbol, numbers.Number) or symbol.is_constant() -def is_scalar_x(expr, x): +def is_scalar_x(expr: Symbol, x: int): """ Utility function to test if an expression evaluates to a constant scalar value """ @@ -80,28 +91,28 @@ def is_scalar_x(expr, x): return False -def is_scalar_zero(expr): +def is_scalar_zero(expr: Symbol): """ Utility function to test if an expression evaluates to a constant scalar zero """ return is_scalar_x(expr, 0) -def is_scalar_one(expr): +def is_scalar_one(expr: Symbol): """ Utility function to test if an expression evaluates to a constant scalar one """ return is_scalar_x(expr, 1) -def is_scalar_minus_one(expr): +def is_scalar_minus_one(expr: Symbol): """ Utility function to test if an expression evaluates to a constant scalar minus one """ return is_scalar_x(expr, -1) -def is_matrix_x(expr, x): +def is_matrix_x(expr: Symbol, x): """ Utility function to test if an expression evaluates to a constant matrix value """ @@ -124,28 +135,30 @@ def is_matrix_x(expr, x): return False -def is_matrix_zero(expr): +def is_matrix_zero(expr: Symbol): """ Utility function to test if an expression evaluates to a constant matrix zero """ return is_matrix_x(expr, 0) -def is_matrix_one(expr): +def is_matrix_one(expr: Symbol): """ Utility function to test if an expression evaluates to a constant matrix one """ return is_matrix_x(expr, 1) -def is_matrix_minus_one(expr): +def is_matrix_minus_one(expr: Symbol): """ Utility function to test if an expression evaluates to a constant matrix minus one """ return is_matrix_x(expr, -1) -def simplify_if_constant(symbol): +def simplify_if_constant( + symbol, +): # division, Negate (unary operator), Maximum (binary), multiplication, addition, subtraction """ Utility function to simplify an expression tree if it evalutes to a constant scalar, vector or matrix @@ -202,11 +215,11 @@ class Symbol: def __init__( self, - name, - children=None, - domain=None, - auxiliary_domains=None, - domains=None, + name: str, + children: list[Symbol] = None, + domain: Union[list[str], str] = None, + auxiliary_domains: dict[str, str] = None, + domains: dict = None, ): super(Symbol, self).__init__() self.name = name @@ -250,7 +263,7 @@ def name(self): return self._name @name.setter - def name(self, value): + def name(self, value: str): assert isinstance(value, str) self._name = value @@ -270,7 +283,7 @@ def domain(self): return self._domains["primary"] @domain.setter - def domain(self, domain): + def domain(self, domain: Union[list[str], str]): raise NotImplementedError( "Cannot set domain directly, use domains={'primary': domain} instead" ) @@ -283,7 +296,7 @@ def auxiliary_domains(self): ) @domains.setter - def domains(self, domains): + def domains(self, domains: dict): try: if ( self._domains == domains @@ -342,7 +355,7 @@ def quaternary_domain(self): """Helper function to get the quaternary domain of a symbol.""" return self._domains["quaternary"] - def copy_domains(self, symbol): + def copy_domains(self, symbol: Symbol): """Copy the domains from a given symbol, bypassing checks.""" if self._domains != symbol._domains: self._domains = symbol._domains @@ -354,7 +367,7 @@ def clear_domains(self): self._domains = EMPTY_DOMAINS self.set_id() - def get_children_domains(self, children): + def get_children_domains(self, children: list[Symbol]): """Combine domains from children, at all levels.""" domains = {} for child in children: @@ -375,7 +388,12 @@ def get_children_domains(self, children): return domains - def read_domain_or_domains(self, domain, auxiliary_domains, domains): + def read_domain_or_domains( + self, + domain: Union[list[str], str], + auxiliary_domains: dict[str, str], + domains: dict, + ): if domains is None: if isinstance(domain, str): domain = [domain] @@ -418,7 +436,7 @@ def scale(self): def reference(self): return self._reference - def __eq__(self, other): + def __eq__(self, other: Symbol): try: return self._id == other._id except AttributeError: @@ -448,7 +466,7 @@ def render(self): # pragma: no cover else: print("{}{}".format(pre, node.name)) - def visualise(self, filename): + def visualise(self, filename: str): """ Produces a .png file of the tree (this node and its children) with the name filename @@ -538,71 +556,71 @@ def __repr__(self): {k: v for k, v in self.domains.items() if v != []}, ) - def __add__(self, other): + def __add__(self, other: Symbol) -> Addition: """return an :class:`Addition` object.""" return pybamm.add(self, other) - def __radd__(self, other): + def __radd__(self, other: Symbol) -> Addition: """return an :class:`Addition` object.""" return pybamm.add(other, self) - def __sub__(self, other): + def __sub__(self, other: Symbol) -> Subtraction: """return a :class:`Subtraction` object.""" return pybamm.subtract(self, other) - def __rsub__(self, other): + def __rsub__(self, other: Symbol) -> Subtraction: """return a :class:`Subtraction` object.""" return pybamm.subtract(other, self) - def __mul__(self, other): + def __mul__(self, other: Symbol) -> Multiplication: """return a :class:`Multiplication` object.""" return pybamm.multiply(self, other) - def __rmul__(self, other): + def __rmul__(self, other: Symbol) -> Multiplication: """return a :class:`Multiplication` object.""" return pybamm.multiply(other, self) - def __matmul__(self, other): + def __matmul__(self, other: Symbol): """return a :class:`MatrixMultiplication` object.""" return pybamm.matmul(self, other) - def __rmatmul__(self, other): + def __rmatmul__(self, other: Symbol): """return a :class:`MatrixMultiplication` object.""" return pybamm.matmul(other, self) - def __truediv__(self, other): + def __truediv__(self, other: Symbol) -> Division: """return a :class:`Division` object.""" return pybamm.divide(self, other) - def __rtruediv__(self, other): + def __rtruediv__(self, other: Symbol) -> Division: """return a :class:`Division` object.""" return pybamm.divide(other, self) - def __pow__(self, other): + def __pow__(self, other: Symbol) -> pybamm.Power: """return a :class:`Power` object.""" return pybamm.simplified_power(self, other) - def __rpow__(self, other): + def __rpow__(self, other: Symbol) -> pybamm.Power: """return a :class:`Power` object.""" return pybamm.simplified_power(other, self) - def __lt__(self, other): + def __lt__(self, other: Symbol) -> pybamm.NotEqualHeaviside: """return a :class:`NotEqualHeaviside` object, or a smooth approximation.""" return pybamm.expression_tree.binary_operators._heaviside(self, other, False) - def __le__(self, other): + def __le__(self, other: Symbol) -> pybamm.EqualHeaviside: """return a :class:`EqualHeaviside` object, or a smooth approximation.""" return pybamm.expression_tree.binary_operators._heaviside(self, other, True) - def __gt__(self, other): + def __gt__(self, other: Symbol) -> pybamm.NotEqualHeaviside: """return a :class:`NotEqualHeaviside` object, or a smooth approximation.""" return pybamm.expression_tree.binary_operators._heaviside(other, self, False) - def __ge__(self, other): + def __ge__(self, other: Symbol) -> pybamm.EqualHeaviside: """return a :class:`EqualHeaviside` object, or a smooth approximation.""" return pybamm.expression_tree.binary_operators._heaviside(other, self, True) - def __neg__(self): + def __neg__(self) -> pybamm.Negate: """return a :class:`Negate` object.""" if isinstance(self, pybamm.Negate): # Double negative is a positive @@ -621,7 +639,7 @@ def __neg__(self): else: return pybamm.simplify_if_constant(pybamm.Negate(self)) - def __abs__(self): + def __abs__(self) -> pybamm.AbsoluteValue: """return an :class:`AbsoluteValue` object, or a smooth approximation.""" if isinstance(self, pybamm.AbsoluteValue): # No need to apply abs a second time @@ -641,7 +659,7 @@ def __abs__(self): out = pybamm.smooth_absolute_value(self, k) return pybamm.simplify_if_constant(out) - def __mod__(self, other): + def __mod__(self, other: Symbol) -> pybamm.Modulo: """return an :class:`Modulo` object.""" return pybamm.simplify_if_constant(pybamm.Modulo(self, other)) @@ -655,7 +673,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): """ return getattr(pybamm, ufunc.__name__)(*inputs, **kwargs) - def diff(self, variable): + def diff(self, variable: Symbol): """ Differentiate a symbol with respect to a variable. For any symbol that can be differentiated, return `1` if differentiating with respect to yourself, @@ -684,7 +702,12 @@ def _diff(self, variable): """ raise NotImplementedError - def jac(self, variable, known_jacs=None, clear_domain=True): + def jac( + self, + variable: pybamm.Symbol, + known_jacs: Optional[dict[str, pybamm.Symbol]] = None, + clear_domain=True, + ): """ Differentiate a symbol with respect to a (slice of) a StateVector or StateVectorDot. @@ -705,7 +728,13 @@ def _jac(self, variable): """ raise NotImplementedError - def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def _base_evaluate( + self, + t: float = None, + y: np.array = None, + y_dot: np.array = None, + inputs: dict = None, + ): """ evaluate expression tree. @@ -731,7 +760,9 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): "{!s} of type {}".format(self, type(self)) ) - def evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def evaluate( + self, t=None, y: np.array = None, y_dot: np.array = None, inputs: dict = None + ): """Evaluate expression tree (wrapper to allow using dict of known values). Parameters @@ -836,7 +867,7 @@ def evaluates_to_constant_number(self): return self.evaluates_to_number() and self.is_constant() @lru_cache - def evaluates_on_edges(self, dimension): + def evaluates_on_edges(self, dimension: str) -> bool: """ Returns True if a symbol evaluates on an edge, i.e. symbol contains a gradient operator, but not a divergence operator, and is not an IndefiniteIntegral. @@ -873,7 +904,14 @@ def has_symbol_of_classes(self, symbol_classes): """ return any(isinstance(symbol, symbol_classes) for symbol in self.pre_order()) - def to_casadi(self, t=None, y=None, y_dot=None, inputs=None, casadi_symbols=None): + def to_casadi( + self, + t: Optional[casadi.MX] = None, + y: Optional[casadi.MX] = None, + y_dot: Optional[casadi.MX] = None, + inputs: Optional[dict] = None, + casadi_symbols: Optional[Symbol] = None, + ): """ Convert the expression tree to a CasADi expression tree. See :class:`pybamm.CasadiConverter`. diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 7f9c45775c..b955b7cca4 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -1,6 +1,7 @@ # # Unary operator classes and methods # +from __future__ import annotations import numbers import numpy as np @@ -26,7 +27,7 @@ class UnaryOperator(pybamm.Symbol): child node """ - def __init__(self, name, child, domains=None): + def __init__(self, name: str, child: pybamm.Symbol, domains=None): if isinstance(child, numbers.Number): child = pybamm.Scalar(child) domains = domains or child.domains @@ -57,7 +58,13 @@ def _unary_evaluate(self, child): f"{self.__class__} does not implement _unary_evaluate." ) - def evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def evaluate( + self, + t: float = None, + y: np.array = None, + y_dot: np.array = None, + inputs: dict = None, + ): """See :meth:`pybamm.Symbol.evaluate()`.""" child = self.child.evaluate(t, y, y_dot, inputs) return self._unary_evaluate(child) @@ -69,7 +76,7 @@ def _evaluate_for_shape(self): """ return self.children[0].evaluate_for_shape() - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return self.child.evaluates_on_edges(dimension) @@ -103,7 +110,7 @@ def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return "{}{!s}".format(self.name, self.child) - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" return -self.child.diff(variable) @@ -312,7 +319,7 @@ def _unary_new_copy(self, child): def _evaluate_for_shape(self): return self._unary_evaluate(self.children[0].evaluate_for_shape()) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False @@ -358,7 +365,7 @@ def __init__(self, child): ) super().__init__("grad", child) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return True @@ -393,7 +400,7 @@ def __init__(self, child): ) super().__init__("div", child) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False @@ -415,7 +422,7 @@ class Laplacian(SpatialOperator): def __init__(self, child): super().__init__("laplacian", child) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False @@ -431,7 +438,7 @@ class GradientSquared(SpatialOperator): def __init__(self, child): super().__init__("grad squared", child) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False @@ -573,7 +580,7 @@ def _evaluate_for_shape(self): """See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`""" return pybamm.evaluate_for_shape_using_domain(self.domains) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False @@ -767,7 +774,7 @@ def _evaluate_for_shape(self): """See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`""" return pybamm.evaluate_for_shape_using_domain(self.domains) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False @@ -800,7 +807,7 @@ def set_id(self): + tuple([(k, tuple(v)) for k, v in self.domains.items()]) ) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False @@ -956,7 +963,7 @@ def __init__(self, name, child): ) super().__init__(name, child) - def _evaluates_on_edges(self, dimension): + def _evaluates_on_edges(self, dimension: str) -> bool: """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return True @@ -989,7 +996,7 @@ def _unary_new_copy(self, child): """See :meth:`pybamm.Symbol.new_copy()`.""" return NotConstant(child) - def _diff(self, variable): + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" return self.child.diff(variable) diff --git a/pybamm/expression_tree/variable.py b/pybamm/expression_tree/variable.py index f9f7d94efc..f16666bda1 100644 --- a/pybamm/expression_tree/variable.py +++ b/pybamm/expression_tree/variable.py @@ -1,11 +1,12 @@ # # Variable class # - +from __future__ import annotations import numpy as np import sympy import numbers import pybamm +from typing import Iterable, Union, Optional class VariableBase(pybamm.Symbol): @@ -49,14 +50,14 @@ class VariableBase(pybamm.Symbol): def __init__( self, - name, - domain=None, - auxiliary_domains=None, - domains=None, - bounds=None, - print_name=None, - scale=1, - reference=0, + name: str, + domain: Optional[Iterable[str]] = None, + auxiliary_domains: Optional[dict] = None, + domains: Optional[dict] = None, + bounds: Optional[tuple] = None, + print_name: Optional[str] = None, + scale: Optional[Union[float, pybamm.Symbol]] = 1, + reference: Optional[Union[float, pybamm.Symbol]] = 0, ): if isinstance(scale, numbers.Number): scale = pybamm.Scalar(scale) @@ -82,7 +83,7 @@ def bounds(self): return self._bounds @bounds.setter - def bounds(self, values): + def bounds(self, values: numbers.Number): if values is None: values = (-np.inf, np.inf) else: @@ -170,7 +171,7 @@ class Variable(VariableBase): Default is 0. """ - def diff(self, variable): + def diff(self, variable: pybamm.Symbol): if variable == self: return pybamm.Scalar(1) elif variable == pybamm.t: @@ -224,7 +225,7 @@ class VariableDot(VariableBase): Default is 0. """ - def get_variable(self): + def get_variable(self) -> pybamm.Variable: """ return a :class:`.Variable` corresponding to this VariableDot @@ -233,7 +234,7 @@ def get_variable(self): """ return Variable(self.name[:-1], domains=self.domains, scale=self.scale) - def diff(self, variable): + def diff(self, variable: pybamm.Variable) -> pybamm.Scalar: if variable == self: return pybamm.Scalar(1) elif variable == pybamm.t: diff --git a/pybamm/expression_tree/vector.py b/pybamm/expression_tree/vector.py index 758b988ca7..6e18fb2ee6 100644 --- a/pybamm/expression_tree/vector.py +++ b/pybamm/expression_tree/vector.py @@ -1,7 +1,9 @@ # # Vector class # +from __future__ import annotations import numpy as np +from typing import Union, Optional import pybamm @@ -13,13 +15,13 @@ class Vector(pybamm.Array): def __init__( self, - entries, - name=None, - domain=None, - auxiliary_domains=None, - domains=None, - entries_string=None, - ): + entries: Union[np.ndarray, list, np.matrix], + name: str = None, + domain: Optional[Union[list[str], str]] = None, + auxiliary_domains: Optional[dict[str, str]] = None, + domains: Optional[dict] = None, + entries_string: Optional[str] = None, + ) -> None: if isinstance(entries, (list, np.matrix)): entries = np.array(entries) # make sure that entries are a vector (can be a column vector) diff --git a/pybamm/models/event.py b/pybamm/models/event.py index e93262641d..3531355bd7 100644 --- a/pybamm/models/event.py +++ b/pybamm/models/event.py @@ -1,4 +1,5 @@ from enum import Enum +import numpy as np class EventType(Enum): @@ -46,7 +47,7 @@ def __init__(self, name, expression, event_type=EventType.TERMINATION): self._expression = expression self._event_type = event_type - def evaluate(self, t=None, y=None, y_dot=None, inputs=None): + def evaluate(self, t:float=None, y:np.array=None, y_dot:np.array=None, inputs:dict=None): """ Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate` """ From ea172e0d0fe757064bc5e39cfbf0640eba57c3cd Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Wed, 7 Jun 2023 14:21:26 +0000 Subject: [PATCH 02/32] (wip): Subset edits to typing after adding mypy --- mypy.ini | 34 ++++++++++++ pybamm/expression_tree/array.py | 26 ++++----- pybamm/expression_tree/averages.py | 23 ++++---- pybamm/expression_tree/binary_operators.py | 53 ++++++++++++------- pybamm/expression_tree/broadcasts.py | 24 ++++----- pybamm/expression_tree/concatenations.py | 32 ++++++----- pybamm/expression_tree/functions.py | 10 ++-- .../expression_tree/independent_variable.py | 27 +++++----- pybamm/expression_tree/input_parameter.py | 14 ++--- pybamm/expression_tree/matrix.py | 12 ++--- .../operations/convert_to_casadi.py | 2 +- .../operations/evaluate_python.py | 11 ++-- pybamm/expression_tree/operations/jacobian.py | 2 +- .../operations/unpack_symbols.py | 16 +++--- pybamm/expression_tree/parameter.py | 10 ++-- pybamm/expression_tree/scalar.py | 19 ++++--- pybamm/expression_tree/state_vector.py | 26 ++++----- pybamm/expression_tree/symbol.py | 48 +++++++++-------- pybamm/expression_tree/unary_operators.py | 11 ++-- pybamm/expression_tree/variable.py | 4 +- pybamm/expression_tree/vector.py | 2 +- pybamm/install_odes.py | 2 +- pybamm/meshes/scikit_fem_submeshes.py | 2 +- pybamm/models/event.py | 10 +++- pybamm/parameters/bpx.py | 2 +- pybamm/parameters/parameter_sets.py | 3 +- pybamm/parameters/parameter_values.py | 4 +- 27 files changed, 252 insertions(+), 177 deletions(-) create mode 100644 mypy.ini diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000000..cb47be594a --- /dev/null +++ b/mypy.ini @@ -0,0 +1,34 @@ +[mypy] + +[mypy-scipy.*] +ignore_missing_imports=True + +[mypy-casadi.*] +ignore_missing_imports=True + +[mypy-matplotlib.*] +ignore_missing_imports=True + +[mypy-pandas.*] +ignore_missing_imports=True + +[mypy-pybtex.*] +ignore_missing_imports=True + +[mypy-ipywidgets.*] +ignore_missing_imports=True + +[mypy-anytree.*] +ignore_missing_imports=True + +[mypy-pkg_resources.*] +ignore_missing_imports=True + +[mypy-tqdm.*] +ignore_missing_imports=True + +[mypy-skfem.*] +ignore_missing_imports=True + +[mypy-absl.*] +ignore_missing_imports=True \ No newline at end of file diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index 73d04e19aa..698d72a568 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -5,7 +5,7 @@ import numpy as np import sympy from scipy.sparse import csr_matrix, issparse -from typing import Union +from typing import Union, Tuple, Optional import pybamm @@ -38,12 +38,12 @@ class Array(pybamm.Symbol): def __init__( self, - entries: Union[np.array, list], - name: str = None, - domain: list[str] = None, - auxiliary_domains: dict[str, str] = None, - domains: dict = None, - entries_string: str = None, + entries: Union[np.ndarray, list], + name: Optional[str] = None, + domain: Union[list[str], str, None] = None, + auxiliary_domains: Optional[dict[str, str]] = None, + domains: Optional[dict] = None, + entries_string: Optional[str] = None, ) -> None: # if if isinstance(entries, list): @@ -119,10 +119,10 @@ def create_copy(self): def _base_evaluate( self, - t: float = None, - y: np.array = None, - y_dot: np.array = None, - inputs: dict = None, + t: Optional[float] = None, + y: Optional[np.ndarray] = None, + y_dot: Optional[np.ndarray] = None, + inputs: Optional[dict] = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" return self._entries @@ -146,7 +146,9 @@ def linspace(start: float, stop: float, num=50, **kwargs) -> pybamm.Array: return pybamm.Array(np.linspace(start, stop, num, **kwargs)) -def meshgrid(x, y, **kwargs): +def meshgrid( + x: pybamm.Array, y: pybamm.Array, **kwargs +) -> Tuple[pybamm.Array, pybamm.Array]: """ Return coordinate matrices as from coordinate vectors by calling `numpy.meshgrid` with keyword arguments 'kwargs'. For a list of 'kwargs' diff --git a/pybamm/expression_tree/averages.py b/pybamm/expression_tree/averages.py index 5bebc76592..95b58ab48b 100644 --- a/pybamm/expression_tree/averages.py +++ b/pybamm/expression_tree/averages.py @@ -1,7 +1,7 @@ # # Classes and methods for averaging # -from typing import Union +from typing import Union, Callable import pybamm @@ -15,13 +15,15 @@ class _BaseAverage(pybamm.Integral): The child node """ - def __init__(self, child: pybamm.Symbol, name: str, integration_variable): + def __init__( + self, child: pybamm.Symbol, name: str, integration_variable: list + ) -> None: super().__init__(child, integration_variable) self.name = name class XAverage(_BaseAverage): - def __init__(self, child: pybamm.Symbol): + def __init__(self, child: pybamm.Symbol) -> None: if all(n in child.domain[0] for n in ["negative", "particle"]): x = pybamm.standard_spatial_vars.x_n elif all(n in child.domain[0] for n in ["positive", "particle"]): @@ -37,7 +39,7 @@ def _unary_new_copy(self, child: pybamm.Symbol): class YZAverage(_BaseAverage): - def __init__(self, child: pybamm.Symbol): + def __init__(self, child: pybamm.Symbol) -> None: y = pybamm.standard_spatial_vars.y z = pybamm.standard_spatial_vars.z integration_variable = [y, z] @@ -49,7 +51,7 @@ def _unary_new_copy(self, child: pybamm.Symbol): class ZAverage(_BaseAverage): - def __init__(self, child: pybamm.Symbol): + def __init__(self, child: pybamm.Symbol) -> None: integration_variable = [pybamm.standard_spatial_vars.z] super().__init__(child, "z-average", integration_variable) @@ -59,7 +61,7 @@ def _unary_new_copy(self, child: pybamm.Symbol): class RAverage(_BaseAverage): - def __init__(self, child: pybamm.Symbol): + def __init__(self, child: pybamm.Symbol) -> None: integration_variable = [pybamm.SpatialVariable("r", child.domain)] super().__init__(child, "r-average", integration_variable) @@ -69,7 +71,7 @@ def _unary_new_copy(self, child: pybamm.Symbol): class SizeAverage(_BaseAverage): - def __init__(self, child: pybamm.Symbol, f_a_dist): + def __init__(self, child: pybamm.Symbol, f_a_dist) -> None: R = pybamm.SpatialVariable("R", domains=child.domains, coord_sys="cartesian") integration_variable = [R] super().__init__(child, "size-average", integration_variable) @@ -199,7 +201,7 @@ def z_average(symbol: pybamm.Symbol) -> pybamm.Symbol: return symbol # If symbol is a Broadcast, its average value is its child elif isinstance(symbol, pybamm.Broadcast): - return symbol.reduce_one_dimension() + return symbol.reduce_one_dimension() # type:ignore # Average of a sum is sum of averages elif isinstance(symbol, (pybamm.Addition, pybamm.Subtraction)): return _sum_of_averages(symbol, z_average) @@ -235,7 +237,7 @@ def yz_average(symbol: pybamm.Symbol) -> pybamm.Symbol: return symbol # If symbol is a Broadcast, its average value is its child elif isinstance(symbol, pybamm.Broadcast): - return symbol.reduce_one_dimension() + return symbol.reduce_one_dimension() # type:ignore # Average of a sum is sum of averages elif isinstance(symbol, (pybamm.Addition, pybamm.Subtraction)): return _sum_of_averages(symbol, yz_average) @@ -345,7 +347,8 @@ def size_average(symbol: pybamm.Symbol, f_a_dist=None) -> pybamm.Symbol: def _sum_of_averages( - symbol: Union[pybamm.Addition, pybamm.Subtraction], average_function + symbol: Union[pybamm.Addition, pybamm.Subtraction], + average_function: Callable[[pybamm.Symbol], pybamm.Symbol], ): if isinstance(symbol, pybamm.Addition): return average_function(symbol.left) + average_function(symbol.right) diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index 5391d31c88..79b512b33a 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -8,7 +8,7 @@ import sympy from scipy.sparse import csr_matrix, issparse import functools -from typing import Union, Tuple +from typing import Union, Tuple, Optional import pybamm @@ -16,7 +16,7 @@ def _preprocess_binary( left: Union[numbers.Number, pybamm.Symbol], right: Union[numbers.Number, pybamm.Symbol], -) -> Tuple[pybamm.PrimaryBroadcast, pybamm.PrimaryBroadcast]: +) -> Tuple[pybamm.Symbol, pybamm.Symbol]: if isinstance(left, numbers.Number): left = pybamm.Scalar(left) if isinstance(right, numbers.Number): @@ -70,7 +70,7 @@ def __init__( name: str, left: Union[numbers.Number, pybamm.Symbol], right: Union[numbers.Number, pybamm.Symbol], - ): + ) -> None: left, right = _preprocess_binary(left, right) domains = self.get_children_domains([left, right]) @@ -125,10 +125,10 @@ def _binary_new_copy( def evaluate( self, - t: float = None, - y: np.array = None, - y_dot: np.array = None, - inputs: dict = None, + t: Optional[float] = None, + y: Optional[np.ndarray] = None, + y_dot: Optional[np.ndarray] = None, + inputs: Optional[dict] = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" left = self.left.evaluate(t, y, y_dot, inputs) @@ -210,7 +210,11 @@ def _binary_jac(self, left_jac, right_jac): right * left_jac + left * pybamm.log(left) * right_jac ) - def _binary_evaluate(self, left, right): + def _binary_evaluate( + self, + left: Union[float, np.ndarray, pybamm.Symbol], + right: Union[float, np.ndarray, pybamm.Symbol], + ): """See :meth:`pybamm.BinaryOperator._binary_evaluate()`.""" # don't raise RuntimeWarning for NaNs with np.errstate(invalid="ignore"): @@ -222,7 +226,7 @@ class Addition(BinaryOperator): A node in the expression tree representing an addition operator. """ - def __init__(self, left, right): + def __init__(self, left: pybamm.Symbol, right: pybamm.Symbol): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("+", left, right) @@ -230,11 +234,15 @@ def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" return self.left.diff(variable) + self.right.diff(variable) - def _binary_jac(self, left_jac, right_jac): + def _binary_jac( + self, left_jac: Union[float, np.ndarray], right_jac: Union[float, np.ndarray] + ): """See :meth:`pybamm.BinaryOperator._binary_jac()`.""" return left_jac + right_jac - def _binary_evaluate(self, left, right): + def _binary_evaluate( + self, left: Union[float, np.ndarray], right: Union[float, np.ndarray] + ): """See :meth:`pybamm.BinaryOperator._binary_evaluate()`.""" return left + right @@ -244,7 +252,7 @@ class Subtraction(BinaryOperator): A node in the expression tree representing a subtraction operator. """ - def __init__(self, left, right): + def __init__(self, left: pybamm.Symbol, right: pybamm.Symbol): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("-", left, right) @@ -257,7 +265,9 @@ def _binary_jac(self, left_jac, right_jac): """See :meth:`pybamm.BinaryOperator._binary_jac()`.""" return left_jac - right_jac - def _binary_evaluate(self, left, right): + def _binary_evaluate( + self, left: Union[float, np.ndarray], right: Union[float, np.ndarray] + ): """See :meth:`pybamm.BinaryOperator._binary_evaluate()`.""" return left - right @@ -704,10 +714,10 @@ def _sympy_operator(self, left, right): def _simplify_elementwise_binary_broadcasts( left: Union[numbers.Number, pybamm.Symbol], right: Union[numbers.Number, pybamm.Symbol], -): +) -> Tuple[pybamm.Symbol, pybamm.Symbol]: left, right = _preprocess_binary(left, right) - def unpack_broadcast_recursive(symbol: pybamm.Symbol): + def unpack_broadcast_recursive(symbol: pybamm.Symbol) -> pybamm.Symbol: if isinstance(symbol, pybamm.Broadcast): if symbol.child.domain == []: return symbol.orphans[0] @@ -733,10 +743,10 @@ def unpack_broadcast_recursive(symbol: pybamm.Symbol): def _simplified_binary_broadcast_concatenation( - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: pybamm.Symbol, # Union[numbers.Number, pybamm.Symbol] + right: pybamm.Symbol, # Union[numbers.Number, pybamm.Symbol] operator, -): +) -> Union[None, pybamm.Broadcast]: """ Check if there are concatenations or broadcasts that we can commute the operator with @@ -774,6 +784,7 @@ def _simplified_binary_broadcast_concatenation( return right._concatenation_new_copy( [operator(left, child) for child in right.orphans] ) + return None def simplified_power( @@ -822,8 +833,8 @@ def simplified_power( def add( - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: pybamm.Symbol, + right: pybamm.Symbol, ): """ Note @@ -1003,6 +1014,8 @@ def multiply( ): left, right = _simplify_elementwise_binary_broadcasts(left, right) + assert isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol) + # Move constant to always be on the left if right.is_constant() and not left.is_constant(): left, right = right, left diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index da0a05187b..ea3df9654a 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -5,7 +5,7 @@ import numpy as np from scipy.sparse import csr_matrix -from typing import Iterable, Optional +from typing import Sequence, Optional, Union import pybamm @@ -31,7 +31,7 @@ class Broadcast(pybamm.SpatialOperator): """ def __init__( - self, child: pybamm.Symbol, domains: Iterable[str], name: Optional[str] = None + self, child: pybamm.Symbol, domains: Sequence[str], name: Optional[str] = None ): if name is None: name = "broadcast" @@ -75,8 +75,8 @@ class PrimaryBroadcast(Broadcast): def __init__( self, - child: pybamm.Symbol, - broadcast_domain: Iterable[str], + child: Union[numbers.Number, pybamm.Symbol], + broadcast_domain: Sequence[str], name: Optional[str] = None, ): # Convert child to scalar if it is a number @@ -92,7 +92,7 @@ def __init__( super().__init__(child, domains, name=name) def check_and_set_domains( - self, child: pybamm.Symbol, broadcast_domain: Iterable[str] + self, child: pybamm.Symbol, broadcast_domain: Sequence[str] ): """See :meth:`Broadcast.check_and_set_domains`""" # Can only do primary broadcast from current collector to electrode, @@ -172,7 +172,7 @@ class PrimaryBroadcastToEdges(PrimaryBroadcast): def __init__( self, child: pybamm.Symbol, - broadcast_domain: Iterable[str], + broadcast_domain: Sequence[str], name: Optional[str] = None, ): name = name or "broadcast to edges" @@ -208,7 +208,7 @@ class SecondaryBroadcast(Broadcast): def __init__( self, child: pybamm.Symbol, - broadcast_domain: Iterable[str], + broadcast_domain: Sequence[str], name: Optional[str] = None, ): # Convert domain to list if it's a string @@ -221,7 +221,7 @@ def __init__( super().__init__(child, domains, name=name) def check_and_set_domains( - self, child: pybamm.Symbol, broadcast_domain: Iterable[str] + self, child: pybamm.Symbol, broadcast_domain: Sequence[str] ): """See :meth:`Broadcast.check_and_set_domains`""" if child.domain == []: @@ -308,7 +308,7 @@ class SecondaryBroadcastToEdges(SecondaryBroadcast): def __init__( self, child: pybamm.Symbol, - broadcast_domain: Iterable[str], + broadcast_domain: Sequence[str], name: Optional[str] = None, ): name = name or "broadcast to edges" @@ -344,7 +344,7 @@ class TertiaryBroadcast(Broadcast): def __init__( self, child: pybamm.Symbol, - broadcast_domain: Iterable[str], + broadcast_domain: Sequence[str], name: Optional[str] = None, ): # Convert domain to list if it's a string @@ -357,7 +357,7 @@ def __init__( super().__init__(child, domains, name=name) def check_and_set_domains( - self, child: pybamm.Symbol, broadcast_domain: Iterable[str] + self, child: pybamm.Symbol, broadcast_domain: Sequence[str] ): """See :meth:`Broadcast.check_and_set_domains`""" if child.domains["secondary"] == []: @@ -429,7 +429,7 @@ class TertiaryBroadcastToEdges(TertiaryBroadcast): def __init__( self, child: pybamm.Symbol, - broadcast_domain: Iterable[str], + broadcast_domain: Sequence[str], name: Optional[str] = None, ): name = name or "broadcast to edges" diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 1f8a90ff52..693f9c8386 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -1,16 +1,14 @@ # # Concatenation classes # +from __future__ import annotations import copy from collections import defaultdict import numpy as np import sympy from scipy.sparse import issparse, vstack -from typing import Optional, Iterable, TYPE_CHECKING - -if TYPE_CHECKING: - from pybamm import Concatenation, DomainConcatenation +from typing import Optional, Iterable, Sequence, TYPE_CHECKING import pybamm @@ -30,7 +28,7 @@ def __init__( *children: Iterable[pybamm.Symbol], name=None, check_domain=True, - concat_fun=None + concat_fun=None, ): # The second condition checks whether this is the base Concatenation class # or a subclass of Concatenation @@ -71,7 +69,7 @@ def _diff(self, variable: pybamm.Symbol): return diff - def get_children_domains(self, children: Iterable[pybamm.Symbol]): + def get_children_domains(self, children: Sequence[pybamm.Symbol]): # combine domains from children domain = [] for child in children: @@ -109,10 +107,10 @@ def _concatenation_evaluate(self, children_eval): def evaluate( self, - t: float = None, - y: np.array = None, - y_dot: np.array = None, - inputs: dict = None, + t: Optional[float] = None, + y: Optional[np.ndarray] = None, + y_dot: Optional[np.ndarray] = None, + inputs: Optional[dict] = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" children = self.children @@ -184,18 +182,18 @@ class NumpyConcatenation(Concatenation): The equations to concatenate """ - def __init__(self, *children: Iterable[pybamm.Symbol]): + def __init__(self, *children: Sequence[pybamm.Symbol]): children = list(children) # Turn objects that evaluate to scalars to objects that evaluate to vectors, # so that we can concatenate them for i, child in enumerate(children): - if child.evaluates_to_number(): + if child.evaluates_to_number(): # type:ignore children[i] = child * pybamm.Vector([1]) super().__init__( *children, name="numpy_concatenation", check_domain=False, - concat_fun=np.concatenate + concat_fun=np.concatenate, ) def _concatenation_jac(self, children_jacs): @@ -225,7 +223,7 @@ class DomainConcatenation(Concatenation): children : iterable of :class:`pybamm.Symbol` The symbols to concatenate - full_mesh : :class:`pybamm.BaseMesh` + full_mesh : :class:`pybamm.Mesh` The underlying mesh for discretisation, used to obtain the number of mesh points in each domain. @@ -237,8 +235,8 @@ class DomainConcatenation(Concatenation): def __init__( self, children: Iterable[pybamm.Symbol], - full_mesh, # pybamm.BaseMesh - copy_this=None, #: Optional[pybamm.DomainConcatenation] + full_mesh: pybamm.Mesh, + copy_this: Optional[pybamm.DomainConcatenation] = None, ): # Convert any constant symbols in children to a Vector of the right size for # concatenation @@ -361,7 +359,7 @@ def __init__(self, *children): #: Iterable[pybamm.Concatenation] *children, name="sparse_stack", check_domain=False, - concat_fun=concatenation_function + concat_fun=concatenation_function, ) def _concatenation_new_copy(self, children): diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index beb8e5fcc9..4c67e16ba0 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -4,7 +4,7 @@ from __future__ import annotations import numbers -import autograd +import autograd # type: ignore import numpy as np import sympy from scipy import special @@ -145,10 +145,10 @@ def _function_jac(self, children_jacs): def evaluate( self, - t: float = None, - y: np.array = None, - y_dot: np.array = None, - inputs: dict = None, + t: Optional[float] = None, + y: Optional[np.ndarray] = None, + y_dot: Optional[np.ndarray] = None, + inputs: Optional[dict] = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" evaluated_children = [ diff --git a/pybamm/expression_tree/independent_variable.py b/pybamm/expression_tree/independent_variable.py index 20c05b84c4..9bb7684b82 100644 --- a/pybamm/expression_tree/independent_variable.py +++ b/pybamm/expression_tree/independent_variable.py @@ -3,6 +3,7 @@ # import sympy import numpy as np +from typing import Union, Optional import pybamm @@ -33,9 +34,9 @@ class IndependentVariable(pybamm.Symbol): def __init__( self, name: str, - domain: list[str] = None, - auxiliary_domains: dict = None, - domains: dict = None, + domain: Optional[list[str]] = None, + auxiliary_domains: Optional[dict] = None, + domains: Optional[dict] = None, ) -> None: super().__init__( name, domain=domain, auxiliary_domains=auxiliary_domains, domains=domains @@ -71,10 +72,10 @@ def create_copy(self): def _base_evaluate( self, - t: float = None, - y: np.array = None, - y_dot: np.array = None, - inputs: dict = None, + t: Optional[float] = None, + y: Optional[np.ndarray] = None, + y_dot: Optional[np.ndarray] = None, + inputs: Optional[dict] = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" if t is None: @@ -116,9 +117,9 @@ class SpatialVariable(IndependentVariable): def __init__( self, name: str, - domain: list[str] = None, - auxiliary_domains: dict = None, - domains: dict = None, + domain: Union[list[str], str, None] = None, + auxiliary_domains: Optional[dict] = None, + domains: Optional[dict] = None, coord_sys=None, ) -> None: self.coord_sys = coord_sys @@ -179,9 +180,9 @@ class SpatialVariableEdge(SpatialVariable): def __init__( self, name: str, - domain: list[str] = None, - auxiliary_domains: dict = None, - domains: dict = None, + domain: Union[list[str], str, None] = None, + auxiliary_domains: Optional[dict] = None, + domains: Optional[dict] = None, coord_sys=None, ) -> None: super().__init__(name, domain, auxiliary_domains, domains, coord_sys) diff --git a/pybamm/expression_tree/input_parameter.py b/pybamm/expression_tree/input_parameter.py index f0c81fee68..ab60892ad0 100644 --- a/pybamm/expression_tree/input_parameter.py +++ b/pybamm/expression_tree/input_parameter.py @@ -7,7 +7,7 @@ import scipy.sparse import pybamm -from typing import Union, Iterable +from typing import Union, Iterable, Optional class InputParameter(pybamm.Symbol): @@ -31,8 +31,8 @@ class InputParameter(pybamm.Symbol): def __init__( self, name: str, - domain: Union[Iterable[str], str] = None, - expected_size: int = None, + domain: Optional[Union[Iterable[str], str]] = None, + expected_size: Optional[int] = None, ) -> None: # Expected size defaults to 1 if no domain else None (gets set later) if expected_size is None: @@ -75,10 +75,10 @@ def _jac(self, variable: pybamm.Variable) -> pybamm.Matrix: def _base_evaluate( self, - t: float = None, - y: np.array = None, - y_dot: np.array = None, - inputs: dict = None, + t: Optional[float] = None, + y: Optional[np.ndarray] = None, + y_dot: Optional[np.ndarray] = None, + inputs: Optional[dict] = None, ): # inputs should be a dictionary # convert 'None' to empty dictionary for more informative error diff --git a/pybamm/expression_tree/matrix.py b/pybamm/expression_tree/matrix.py index 0c23e059e3..68225dd9a8 100644 --- a/pybamm/expression_tree/matrix.py +++ b/pybamm/expression_tree/matrix.py @@ -3,7 +3,7 @@ # import numpy as np from scipy.sparse import csr_matrix, issparse -from typing import Union +from typing import Union, Optional import pybamm @@ -16,11 +16,11 @@ class Matrix(pybamm.Array): def __init__( self, entries: Union[np.ndarray, list], - name: str = None, - domain: list[str] = None, - auxiliary_domains: dict[str, str] = None, - domains: dict = None, - entries_string: str = None, + name: Optional[str] = None, + domain: Optional[list[str]] = None, + auxiliary_domains: Optional[dict[str, str]] = None, + domains: Optional[dict] = None, + entries_string: Optional[str] = None, ) -> None: if isinstance(entries, list): entries = np.array(entries) diff --git a/pybamm/expression_tree/operations/convert_to_casadi.py b/pybamm/expression_tree/operations/convert_to_casadi.py index 2aa1bfa720..91202fc18e 100644 --- a/pybamm/expression_tree/operations/convert_to_casadi.py +++ b/pybamm/expression_tree/operations/convert_to_casadi.py @@ -19,7 +19,7 @@ def convert( t: casadi.MX, y: casadi.MX, y_dot: casadi.MX, - inputs: dict[casadi.MX], + inputs: dict, ) -> casadi.MX: """ This function recurses down the tree, converting the PyBaMM expression tree to diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index 7560f950fe..00ffeb3e23 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -13,6 +13,7 @@ if pybamm.have_jax(): import jax + import jaxlib from jax.config import config config.update("jax_enable_x64", True) @@ -59,7 +60,7 @@ def toarray(self): result = jax.numpy.zeros(self.shape, dtype=self.data.dtype) return result.at[self.row, self.col].add(self.data) - def dot_product(self, b: ArrayLike): + def dot_product(self, b: jaxlib.xla_extension.DeviceArray): """ dot product of matrix with a dense column vector b @@ -72,7 +73,7 @@ def dot_product(self, b: ArrayLike): result = jax.numpy.zeros((self.shape[0], 1), dtype=b.dtype) return result.at[self.row].add(self.data.reshape(-1, 1) * b[self.col]) - def scalar_multiply(self, b: ArrayLike): + def scalar_multiply(self, b: float): """ multiply of matrix with a scalar b @@ -387,7 +388,7 @@ def find_symbols( def to_python( symbol: pybamm.Symbol, debug=False, output_jax=False -) -> Tuple[OrderedDict, str, bool]: +) -> Tuple[OrderedDict, str]: """ This function converts an expression tree into a dict of constant input values, and valid python code that acts like the tree's :func:`pybamm.Symbol.evaluate` function @@ -413,8 +414,8 @@ def to_python( operations are used """ - constant_values = OrderedDict() - variable_symbols = OrderedDict() + constant_values: OrderedDict = OrderedDict() + variable_symbols: OrderedDict = OrderedDict() find_symbols(symbol, constant_values, variable_symbols, output_jax) line_format = "{} = {}" diff --git a/pybamm/expression_tree/operations/jacobian.py b/pybamm/expression_tree/operations/jacobian.py index 43d314db6e..c3de08c21c 100644 --- a/pybamm/expression_tree/operations/jacobian.py +++ b/pybamm/expression_tree/operations/jacobian.py @@ -22,7 +22,7 @@ class Jacobian(object): def __init__( self, known_jacs: Optional[dict[str, pybamm.Symbol]] = None, - clear_domain: Optional[bool] = True, + clear_domain: bool = True, ): self._known_jacs = known_jacs or {} self._clear_domain = clear_domain diff --git a/pybamm/expression_tree/operations/unpack_symbols.py b/pybamm/expression_tree/operations/unpack_symbols.py index 9868e55c48..61086e8a04 100644 --- a/pybamm/expression_tree/operations/unpack_symbols.py +++ b/pybamm/expression_tree/operations/unpack_symbols.py @@ -2,7 +2,7 @@ # Helper function to unpack a symbol # from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Iterable, Union, Sequence if TYPE_CHECKING: import pybamm @@ -21,13 +21,17 @@ class SymbolUnpacker(object): cached unpacked equations """ - def __init__(self, classes_to_find, unpacked_symbols: Optional[set] = None): + def __init__( + self, + classes_to_find: Union[pybamm.Symbol, Sequence[pybamm.Symbol]], + unpacked_symbols: Optional[set] = None, + ): self.classes_to_find = classes_to_find - self._unpacked_symbols = unpacked_symbols or {} + self._unpacked_symbols: set = unpacked_symbols or {} def unpack_list_of_symbols( - self, list_of_symbols: list[pybamm.Symbol] - ) -> list[pybamm.Symbol]: + self, list_of_symbols: Sequence[pybamm.Symbol] + ) -> set[pybamm.Symbol]: """ Unpack a list of symbols. See :meth:`SymbolUnpacker.unpack()` @@ -48,7 +52,7 @@ def unpack_list_of_symbols( return all_instances - def unpack_symbol(self, symbol: list[pybamm.Symbol]) -> list[pybamm.Symbol]: + def unpack_symbol(self, symbol: Sequence[pybamm.Symbol]) -> list[pybamm.Symbol]: """ This function recurses down the tree, unpacking the symbols and saving the ones that have a class in `self.classes_to_find`. diff --git a/pybamm/expression_tree/parameter.py b/pybamm/expression_tree/parameter.py index 97542b28c2..8cb60aec36 100644 --- a/pybamm/expression_tree/parameter.py +++ b/pybamm/expression_tree/parameter.py @@ -7,10 +7,10 @@ import numpy as np import sympy -from typing import Optional, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING, Literal -if TYPE_CHECKING: - from pybamm import FunctionParameter +# if TYPE_CHECKING: +# from pybamm import FunctionParameter import pybamm @@ -42,7 +42,7 @@ def _evaluate_for_shape(self) -> np.nan: """ return np.nan - def is_constant(self) -> False: + def is_constant(self) -> Literal[False]: """See :meth:`pybamm.Symbol.is_constant()`.""" # Parameter is not constant since it can become an InputParameter return False @@ -134,7 +134,7 @@ def print_input_names(self): print(inp) @input_names.setter - def input_names(self, inp: dict[str, pybamm.Symbol] = None): + def input_names(self, inp=None): if inp: if inp.__class__ is list: for i in inp: diff --git a/pybamm/expression_tree/scalar.py b/pybamm/expression_tree/scalar.py index cae9e9c3c4..6e69745743 100644 --- a/pybamm/expression_tree/scalar.py +++ b/pybamm/expression_tree/scalar.py @@ -2,9 +2,10 @@ # Scalar class # from __future__ import annotations +import numbers import numpy as np import sympy -from typing import Optional +from typing import Optional, Literal, Union import pybamm @@ -23,7 +24,11 @@ class Scalar(pybamm.Symbol): """ - def __init__(self, value: float, name: Optional[str] = None) -> None: + def __init__( + self, + value: Union[pybamm.Scalar, numbers.Number, float], + name: Optional[str] = None, + ) -> None: # set default name if not provided self.value = value if name is None: @@ -56,10 +61,10 @@ def set_id(self): def _base_evaluate( self, - t: float = None, - y: np.array = None, - y_dot: np.array = None, - inputs: dict = None, + t: Optional[float] = None, + y: Optional[np.ndarray] = None, + y_dot: Optional[np.ndarray] = None, + inputs: Optional[dict] = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" return self._value @@ -72,7 +77,7 @@ def create_copy(self): """See :meth:`pybamm.Symbol.new_copy()`.""" return Scalar(self.value, self.name) - def is_constant(self) -> True: + def is_constant(self) -> Literal[True]: """See :meth:`pybamm.Symbol.is_constant()`.""" return True diff --git a/pybamm/expression_tree/state_vector.py b/pybamm/expression_tree/state_vector.py index fd4cc016d2..897474aa6d 100644 --- a/pybamm/expression_tree/state_vector.py +++ b/pybamm/expression_tree/state_vector.py @@ -39,7 +39,7 @@ def __init__( *y_slices: slice, base_name="y", name: Optional[str] = None, - domain: Iterable[str] = None, + domain: Optional[Iterable[str]] = None, auxiliary_domains: Optional[dict[str]] = None, domains: Optional[dict] = None, evaluation_array: Optional[list] = None, @@ -113,7 +113,7 @@ def set_id(self): + tuple(self.domain) ) - def _jac_diff_vector(self, variable: pybamm.Symbol): + def _jac_diff_vector(self, variable: pybamm.StateVectorBase): """ Differentiate a slice of a StateVector of size m with respect to another slice of a different StateVector of size n. This returns a (sparse) zero matrix of @@ -134,7 +134,7 @@ def _jac_diff_vector(self, variable: pybamm.Symbol): # Return zeros of correct size since no entries match return pybamm.Matrix(csr_matrix((slices_size, variable_size))) - def _jac_same_vector(self, variable: pybamm.Symbol): + def _jac_same_vector(self, variable: pybamm.StateVectorBase): """ Differentiate a slice of a StateVector of size m with respect to another slice of a StateVector of size n. This returns a (sparse) matrix of size @@ -226,7 +226,7 @@ def __init__( self, *y_slices: slice, name: Optional[str] = None, - domain: Iterable[str] = None, + domain: Optional[Iterable[str]] = None, auxiliary_domains: Optional[dict[str]] = None, domains: Optional[dict] = None, evaluation_array: Optional[list] = None, @@ -243,10 +243,10 @@ def __init__( def _base_evaluate( self, - t: float = None, - y: np.array = None, - y_dot: np.array = None, - inputs: dict = None, + t: Optional[float] = None, + y: Optional[np.ndarray] = None, + y_dot: Optional[np.ndarray] = None, + inputs: Optional[dict] = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" if y is None: @@ -310,7 +310,7 @@ def __init__( self, *y_slices: slice, name: Optional[str] = None, - domain: Iterable[str] = None, + domain: Optional[Iterable[str]] = None, auxiliary_domains: Optional[dict[str]] = None, domains: Optional[dict] = None, evaluation_array: Optional[list] = None, @@ -327,10 +327,10 @@ def __init__( def _base_evaluate( self, - t: float = None, - y: np.array = None, - y_dot: np.array = None, - inputs: dict = None, + t: Optional[float] = None, + y: Optional[np.ndarray] = None, + y_dot: Optional[np.ndarray] = None, + inputs: Optional[dict] = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" if y_dot is None: diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 39360204c3..024c81e119 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -10,7 +10,7 @@ from anytree.exporter import DotExporter from scipy.sparse import csr_matrix, issparse from functools import lru_cache, cached_property -from typing import Union, TYPE_CHECKING, Optional +from typing import Union, TYPE_CHECKING, Optional, Iterable, Sequence import pybamm from pybamm.expression_tree.printing.print_name import prettify_print_name @@ -157,8 +157,8 @@ def is_matrix_minus_one(expr: Symbol): def simplify_if_constant( - symbol, -): # division, Negate (unary operator), Maximum (binary), multiplication, addition, subtraction + symbol: Symbol, +) -> Symbol: """ Utility function to simplify an expression tree if it evalutes to a constant scalar, vector or matrix @@ -216,10 +216,10 @@ class Symbol: def __init__( self, name: str, - children: list[Symbol] = None, - domain: Union[list[str], str] = None, - auxiliary_domains: dict[str, str] = None, - domains: dict = None, + children: Optional[list[Symbol]] = None, + domain: Optional[Union[list[str], str]] = None, + auxiliary_domains: Optional[dict[str, str]] = None, + domains: Optional[dict] = None, ): super(Symbol, self).__init__() self.name = name @@ -296,7 +296,7 @@ def auxiliary_domains(self): ) @domains.setter - def domains(self, domains: dict): + def domains(self, domains): try: if ( self._domains == domains @@ -367,9 +367,9 @@ def clear_domains(self): self._domains = EMPTY_DOMAINS self.set_id() - def get_children_domains(self, children: list[Symbol]): + def get_children_domains(self, children: Iterable[Symbol]): """Combine domains from children, at all levels.""" - domains = {} + domains: dict = {} for child in children: for level in child.domains.keys(): if child.domains[level] == []: @@ -390,9 +390,9 @@ def get_children_domains(self, children: list[Symbol]): def read_domain_or_domains( self, - domain: Union[list[str], str], - auxiliary_domains: dict[str, str], - domains: dict, + domain: Optional[Union[list[str], str]], + auxiliary_domains: Optional[dict[str, str]], + domains: Optional[dict], ): if domains is None: if isinstance(domain, str): @@ -436,7 +436,7 @@ def scale(self): def reference(self): return self._reference - def __eq__(self, other: Symbol): + def __eq__(self, other): try: return self._id == other._id except AttributeError: @@ -596,7 +596,7 @@ def __rtruediv__(self, other: Symbol) -> Division: """return a :class:`Division` object.""" return pybamm.divide(other, self) - def __pow__(self, other: Symbol) -> pybamm.Power: + def __pow__(self, other: Union[Symbol, float]) -> pybamm.Power: """return a :class:`Power` object.""" return pybamm.simplified_power(self, other) @@ -604,7 +604,7 @@ def __rpow__(self, other: Symbol) -> pybamm.Power: """return a :class:`Power` object.""" return pybamm.simplified_power(other, self) - def __lt__(self, other: Symbol) -> pybamm.NotEqualHeaviside: + def __lt__(self, other: Union[Symbol, float]) -> pybamm.NotEqualHeaviside: """return a :class:`NotEqualHeaviside` object, or a smooth approximation.""" return pybamm.expression_tree.binary_operators._heaviside(self, other, False) @@ -620,7 +620,7 @@ def __ge__(self, other: Symbol) -> pybamm.EqualHeaviside: """return a :class:`EqualHeaviside` object, or a smooth approximation.""" return pybamm.expression_tree.binary_operators._heaviside(other, self, True) - def __neg__(self) -> pybamm.Negate: + def __neg__(self) -> pybamm.Symbol: """return a :class:`Negate` object.""" if isinstance(self, pybamm.Negate): # Double negative is a positive @@ -730,10 +730,10 @@ def _jac(self, variable): def _base_evaluate( self, - t: float = None, - y: np.array = None, - y_dot: np.array = None, - inputs: dict = None, + t: Optional[float] = None, + y: Optional[np.ndarray] = None, + y_dot: Optional[np.ndarray] = None, + inputs: Optional[dict] = None, ): """ evaluate expression tree. @@ -761,7 +761,11 @@ def _base_evaluate( ) def evaluate( - self, t=None, y: np.array = None, y_dot: np.array = None, inputs: dict = None + self, + t: Optional[numbers.Number] = None, + y: Optional[np.ndarray] = None, + y_dot: Optional[np.ndarray] = None, + inputs: Optional[dict] = None, ): """Evaluate expression tree (wrapper to allow using dict of known values). diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index b955b7cca4..f493e55211 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -3,6 +3,7 @@ # from __future__ import annotations import numbers +from typing import Optional import numpy as np import sympy @@ -60,10 +61,10 @@ def _unary_evaluate(self, child): def evaluate( self, - t: float = None, - y: np.array = None, - y_dot: np.array = None, - inputs: dict = None, + t: Optional[float] = None, + y: Optional[np.ndarray] = None, + y_dot: Optional[np.ndarray] = None, + inputs: Optional[dict] = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" child = self.child.evaluate(t, y, y_dot, inputs) @@ -343,7 +344,7 @@ class with a :class:`Matrix` child node """ - def __init__(self, name, child, domains=None): + def __init__(self, name: str, child: pybamm.Symbol, domains=None): super().__init__(name, child, domains) diff --git a/pybamm/expression_tree/variable.py b/pybamm/expression_tree/variable.py index f16666bda1..d65358a97a 100644 --- a/pybamm/expression_tree/variable.py +++ b/pybamm/expression_tree/variable.py @@ -6,7 +6,7 @@ import sympy import numbers import pybamm -from typing import Iterable, Union, Optional +from typing import Iterable, Union, Optional, Sequence class VariableBase(pybamm.Symbol): @@ -83,7 +83,7 @@ def bounds(self): return self._bounds @bounds.setter - def bounds(self, values: numbers.Number): + def bounds(self, values: Sequence[numbers.Number]): if values is None: values = (-np.inf, np.inf) else: diff --git a/pybamm/expression_tree/vector.py b/pybamm/expression_tree/vector.py index 6e18fb2ee6..4070b97f6f 100644 --- a/pybamm/expression_tree/vector.py +++ b/pybamm/expression_tree/vector.py @@ -16,7 +16,7 @@ class Vector(pybamm.Array): def __init__( self, entries: Union[np.ndarray, list, np.matrix], - name: str = None, + name: Optional[str] = None, domain: Optional[Union[list[str], str]] = None, auxiliary_domains: Optional[dict[str, str]] = None, domains: Optional[dict] = None, diff --git a/pybamm/install_odes.py b/pybamm/install_odes.py index 4bf310a0f2..f06ef541c5 100644 --- a/pybamm/install_odes.py +++ b/pybamm/install_odes.py @@ -10,7 +10,7 @@ try: # wget module is required to download SUNDIALS or SuiteSparse. - import wget + import wget # type: ignore NO_WGET = False except ModuleNotFoundError: diff --git a/pybamm/meshes/scikit_fem_submeshes.py b/pybamm/meshes/scikit_fem_submeshes.py index f25dce80b1..ecb2655aa0 100644 --- a/pybamm/meshes/scikit_fem_submeshes.py +++ b/pybamm/meshes/scikit_fem_submeshes.py @@ -4,7 +4,7 @@ import pybamm from .meshes import SubMesh -import skfem +import skfem # type: ignore import numpy as np diff --git a/pybamm/models/event.py b/pybamm/models/event.py index 3531355bd7..728c9c232b 100644 --- a/pybamm/models/event.py +++ b/pybamm/models/event.py @@ -1,6 +1,8 @@ from enum import Enum import numpy as np +from typing import Optional + class EventType(Enum): """ @@ -47,7 +49,13 @@ def __init__(self, name, expression, event_type=EventType.TERMINATION): self._expression = expression self._event_type = event_type - def evaluate(self, t:float=None, y:np.array=None, y_dot:np.array=None, inputs:dict=None): + def evaluate( + self, + t: Optional[float] = None, + y: Optional[np.ndarray] = None, + y_dot: Optional[np.ndarray] = None, + inputs: Optional[dict] = None, + ): """ Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate` """ diff --git a/pybamm/parameters/bpx.py b/pybamm/parameters/bpx.py index 8f555fa9a8..cb10a87ede 100644 --- a/pybamm/parameters/bpx.py +++ b/pybamm/parameters/bpx.py @@ -1,4 +1,4 @@ -from bpx import BPX, Function, InterpolatedTable +from bpx import BPX, Function, InterpolatedTable # type: ignore import pybamm import math from dataclasses import dataclass diff --git a/pybamm/parameters/parameter_sets.py b/pybamm/parameters/parameter_sets.py index 1da7f239dd..37b0b72158 100644 --- a/pybamm/parameters/parameter_sets.py +++ b/pybamm/parameters/parameter_sets.py @@ -2,6 +2,7 @@ import importlib.metadata import textwrap from collections.abc import Mapping +from typing import Callable class ParameterSets(Mapping): @@ -49,7 +50,7 @@ def __new__(cls): def __getitem__(self, key) -> dict: return self.__load_entry_point__(key)() - def __load_entry_point__(self, key) -> callable: + def __load_entry_point__(self, key) -> Callable: """Check that ``key`` is a registered ``pybamm_parameter_sets``, and return the entry point for the parameter set, loading it needed. """ diff --git a/pybamm/parameters/parameter_values.py b/pybamm/parameters/parameter_values.py index 136d9737aa..e03be6ae29 100644 --- a/pybamm/parameters/parameter_values.py +++ b/pybamm/parameters/parameter_values.py @@ -94,8 +94,8 @@ def create_from_bpx(filename, target_soc=1): if target_soc < 0 or target_soc > 1: raise ValueError("Target SOC should be between 0 and 1") - from bpx import parse_bpx_file, get_electrode_concentrations - from .bpx import _bpx_to_param_dict + from bpx import parse_bpx_file, get_electrode_concentrations # type: ignore + from .bpx import _bpx_to_param_dict # type: ignore # parse bpx bpx = parse_bpx_file(filename) From f5dd303e37574599258307bf944485d713e1648a Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Thu, 8 Jun 2023 12:06:33 +0000 Subject: [PATCH 03/32] (wip): more mypy edits --- pybamm/expression_tree/array.py | 8 +-- pybamm/expression_tree/averages.py | 10 +++- pybamm/expression_tree/binary_operators.py | 54 ++++++++++--------- pybamm/expression_tree/broadcasts.py | 4 +- pybamm/expression_tree/concatenations.py | 12 ++--- pybamm/expression_tree/functions.py | 8 +-- .../expression_tree/independent_variable.py | 10 ++-- pybamm/expression_tree/input_parameter.py | 2 +- pybamm/expression_tree/interpolant.py | 2 +- pybamm/expression_tree/operations/jacobian.py | 11 ++-- pybamm/expression_tree/state_vector.py | 12 ++--- pybamm/expression_tree/symbol.py | 10 ++-- pybamm/expression_tree/unary_operators.py | 10 +++- pybamm/expression_tree/variable.py | 4 +- 14 files changed, 90 insertions(+), 67 deletions(-) diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index 698d72a568..2ce51f1a32 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -5,7 +5,7 @@ import numpy as np import sympy from scipy.sparse import csr_matrix, issparse -from typing import Union, Tuple, Optional +from typing import Union, Tuple, Optional, Any import pybamm @@ -155,6 +155,6 @@ def meshgrid( see the `numpy meshgrid documentation `_ """ [X, Y] = np.meshgrid(x.entries, y.entries) - X = pybamm.Array(X) - Y = pybamm.Array(Y) - return X, Y + X = pybamm.Array(X) # type:ignore[assignment] + Y = pybamm.Array(Y) # type:ignore[assignment] + return X, Y # type:ignore[return-value] diff --git a/pybamm/expression_tree/averages.py b/pybamm/expression_tree/averages.py index 95b58ab48b..3d9510cda2 100644 --- a/pybamm/expression_tree/averages.py +++ b/pybamm/expression_tree/averages.py @@ -1,7 +1,8 @@ # # Classes and methods for averaging # -from typing import Union, Callable +from __future__ import annotations +from typing import Union, Callable, Sequence import pybamm @@ -16,7 +17,12 @@ class _BaseAverage(pybamm.Integral): """ def __init__( - self, child: pybamm.Symbol, name: str, integration_variable: list + self, + child: pybamm.Symbol, + name: str, + integration_variable: Union[ + Sequence[pybamm.IndependentVariable], pybamm.IndependentVariable + ], ) -> None: super().__init__(child, integration_variable) self.name = name diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index 79b512b33a..10e29d671f 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -218,7 +218,7 @@ def _binary_evaluate( """See :meth:`pybamm.BinaryOperator._binary_evaluate()`.""" # don't raise RuntimeWarning for NaNs with np.errstate(invalid="ignore"): - return left**right + return left**right # type:ignore[operator] class Addition(BinaryOperator): @@ -1151,7 +1151,7 @@ def divide( # Move constant to always be on the left # For a division, this means (var / constant) becomes (1/constant * var) if right.is_constant() and not left.is_constant(): - return (1 / right) * left + return (1 / right) * left # type:ignore # Check for Concatenations and Broadcasts out = _simplified_binary_broadcast_concatenation(left, right, divide) @@ -1277,7 +1277,7 @@ def matmul( def minimum( left: Union[numbers.Number, pybamm.Symbol], right: Union[numbers.Number, pybamm.Symbol], -): +) -> pybamm.Symbol: """ Returns the smaller of two objects, possibly with a smoothing approximation. Not to be confused with :meth:`pybamm.min`, which returns min function of child. @@ -1292,10 +1292,10 @@ def minimum( # Return exact approximation if that is the setting or the outcome is a constant # (i.e. no need for smoothing) if k == "exact" or (left.is_constant() and right.is_constant()): - out = Minimum(left, right) + out = Minimum(left, right) # type:ignore else: - out = pybamm.softminus(left, right, k) - return pybamm.simplify_if_constant(out) + out = pybamm.softminus(left, right, k) # type:ignore + return pybamm.simplify_if_constant(out) # type:ignore def maximum( @@ -1316,10 +1316,10 @@ def maximum( # Return exact approximation if that is the setting or the outcome is a constant # (i.e. no need for smoothing) if k == "exact" or (left.is_constant() and right.is_constant()): - out = Maximum(left, right) + out = Maximum(left, right) # type:ignore else: - out = pybamm.softplus(left, right, k) - return pybamm.simplify_if_constant(out) + out = pybamm.softplus(left, right, k) # type:ignore + return pybamm.simplify_if_constant(out) # type:ignore def _heaviside( @@ -1357,41 +1357,43 @@ def _heaviside( # (i.e. no need for smoothing) if k == "exact" or (left.is_constant() and right.is_constant()): if equal is True: - out = pybamm.EqualHeaviside(left, right) + out = pybamm.EqualHeaviside(left, right) # type:ignore else: - out = pybamm.NotEqualHeaviside(left, right) + out = pybamm.NotEqualHeaviside(left, right) # type:ignore else: - out = pybamm.sigmoid(left, right, k) - return pybamm.simplify_if_constant(out) + out = pybamm.sigmoid(left, right, k) # type:ignore + return pybamm.simplify_if_constant(out) # type:ignore def softminus( - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: pybamm.Symbol, + right: pybamm.Symbol, k: float, -): +) -> pybamm.Symbol: """ Softplus approximation to the minimum function. k is the smoothing parameter, set by `pybamm.settings.min_smoothing`. The recommended value is k=10. """ - return pybamm.log(pybamm.exp(-k * left) + pybamm.exp(-k * right)) / -k + return ( + pybamm.log(pybamm.exp(-k * left) + pybamm.exp(-k * right)) / -k # type:ignore + ) def softplus( - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: pybamm.Symbol, + right: pybamm.Symbol, k: float, ): """ Softplus approximation to the maximum function. k is the smoothing parameter, set by `pybamm.settings.max_smoothing`. The recommended value is k=10. """ - return pybamm.log(pybamm.exp(k * left) + pybamm.exp(k * right)) / k + return pybamm.log(pybamm.exp(k * left) + pybamm.exp(k * right)) / k # type:ignore def sigmoid( - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: pybamm.Symbol, + right: pybamm.Symbol, k: float, ): """ @@ -1400,7 +1402,7 @@ def sigmoid( Note that the concept of deciding which side to pick when left=right does not apply for this smooth approximation. When left=right, the value is (left+right)/2. """ - return (1 + pybamm.tanh(k * (right - left))) / 2 + return (1 + pybamm.tanh(k * (right - left))) / 2 # type:ignore def source( @@ -1438,11 +1440,13 @@ def source( if isinstance(left, numbers.Number): left = pybamm.PrimaryBroadcast(left, "current collector") - if left.domain != ["current collector"] or right.domain != ["current collector"]: + if left.domain != ["current collector"] or right.domain != [ # type:ignore + "current collector" + ]: raise pybamm.DomainError( """'source' only implemented in the 'current collector' domain, but symbols have domains {} and {}""".format( - left.domain, right.domain + left.domain, right.domain # type:ignore ) ) if boundary: diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index ea3df9654a..2f148849fd 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -548,7 +548,7 @@ def reduce_one_dimension(self): ) -def full_like(symbols: pybamm.Symbol, fill_value: float): +def full_like(symbols: Sequence[pybamm.Symbol], fill_value: float) -> pybamm.Symbol: """ Returns an array with the same shape and domains as the sum of the input symbols, with a constant value given by `fill_value`. @@ -575,7 +575,7 @@ def full_like(symbols: pybamm.Symbol, fill_value: float): if shape[1] == 1: array_type = pybamm.Vector else: - array_type = pybamm.Matrix + array_type = pybamm.Matrix # type:ignore[assignment] # return dense array, except for a matrix of zeros if shape[1] != 1 and fill_value == 0: entries = csr_matrix(shape) diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 693f9c8386..b126567932 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -25,8 +25,8 @@ class Concatenation(pybamm.Symbol): def __init__( self, - *children: Iterable[pybamm.Symbol], - name=None, + *children: Sequence[pybamm.Symbol], + name: Optional[str] = None, check_domain=True, concat_fun=None, ): @@ -44,7 +44,7 @@ def __init__( if name is None: name = "concatenation" if check_domain: - domains = self.get_children_domains(children) + domains = self.get_children_domains(children) # type:ignore[arg-type] else: domains = {"primary": []} self.concatenation_function = concat_fun @@ -183,12 +183,12 @@ class NumpyConcatenation(Concatenation): """ def __init__(self, *children: Sequence[pybamm.Symbol]): - children = list(children) + children = list(children) # type:ignore[assignment] # Turn objects that evaluate to scalars to objects that evaluate to vectors, # so that we can concatenate them for i, child in enumerate(children): if child.evaluates_to_number(): # type:ignore - children[i] = child * pybamm.Vector([1]) + children[i] = child * pybamm.Vector([1]) # type:ignore super().__init__( *children, name="numpy_concatenation", @@ -234,7 +234,7 @@ class DomainConcatenation(Concatenation): def __init__( self, - children: Iterable[pybamm.Symbol], + children: Sequence[pybamm.Symbol], full_mesh: pybamm.Mesh, copy_this: Optional[pybamm.DomainConcatenation] = None, ): diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 4c67e16ba0..124b4c590e 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -87,9 +87,9 @@ def diff(self, variable: pybamm.Symbol): # remove None entries partial_derivatives = [x for x in partial_derivatives if x is not None] - derivative = sum(partial_derivatives) + derivative = sum(partial_derivatives) # type:ignore[arg-type] if derivative == 0: - derivative = pybamm.Scalar(0) + derivative = pybamm.Scalar(0) # type:ignore[assignment] return derivative @@ -106,7 +106,7 @@ def _function_diff(self, children: pybamm.Symbol, idx: float): differentiated_function=self.function, ) elif self.derivative == "derivative": - if len(children) > 1: + if len(children) > 1: # type:ignore[arg-type] raise ValueError( """ differentiation using '.derivative()' not implemented for functions @@ -194,7 +194,7 @@ def _function_new_copy(self, children: list) -> Function: : :pybamm.Function A new copy of the function """ - return pybamm.simplify_if_constant( + return pybamm.simplify_if_constant( # type:ignore[return-value] pybamm.Function( self.function, *children, diff --git a/pybamm/expression_tree/independent_variable.py b/pybamm/expression_tree/independent_variable.py index 9bb7684b82..c57c153326 100644 --- a/pybamm/expression_tree/independent_variable.py +++ b/pybamm/expression_tree/independent_variable.py @@ -117,7 +117,7 @@ class SpatialVariable(IndependentVariable): def __init__( self, name: str, - domain: Union[list[str], str, None] = None, + domain: Optional[Union[list[str], str]] = None, auxiliary_domains: Optional[dict] = None, domains: Optional[dict] = None, coord_sys=None, @@ -132,20 +132,22 @@ def __init__( raise ValueError("domain must be provided") # Check symbol name vs domain name - if name == "r_n" and not all(n in domain[0] for n in ["negative", "particle"]): + if name == "r_n" and not all( + n in domain[0] for n in ["negative", "particle"] # type:ignore[index] + ): # catches "negative particle", "negative secondary particle", etc raise pybamm.DomainError( "domain must be negative particle if name is 'r_n'" ) elif name == "r_p" and not all( - n in domain[0] for n in ["positive", "particle"] + n in domain[0] for n in ["positive", "particle"] # type:ignore[index] ): # catches "positive particle", "positive secondary particle", etc raise pybamm.DomainError( "domain must be positive particle if name is 'r_p'" ) elif name in ["x", "y", "z", "x_n", "x_s", "x_p"] and any( - ["particle" in dom for dom in domain] + ["particle" in dom for dom in domain] # type:ignore[index, union-attr] ): raise pybamm.DomainError( "domain cannot be particle if name is '{}'".format(name) diff --git a/pybamm/expression_tree/input_parameter.py b/pybamm/expression_tree/input_parameter.py index ab60892ad0..f45e9b44cd 100644 --- a/pybamm/expression_tree/input_parameter.py +++ b/pybamm/expression_tree/input_parameter.py @@ -62,7 +62,7 @@ def _evaluate_for_shape(self): else: return np.nan * np.ones((self._expected_size, 1)) - def _jac(self, variable: pybamm.Variable) -> pybamm.Matrix: + def _jac(self, variable: pybamm.StateVector) -> pybamm.Matrix: """See :meth:`pybamm.Symbol._jac()`.""" n_variable = variable.evaluation_array.count(True) nan_vector = self._evaluate_for_shape() diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index 8d4eee1bde..4c9833ef05 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -114,7 +114,7 @@ def __init__( children = [children] # Either a single x is provided and there is one child # or x is a 2-tuple and there are two children - if len(x) != len(children): + if len(x) != len(children): # type:ignore[arg-type] raise ValueError("len(x) should equal len(children)") # if there is only one x, y can be 2-dimensional but the child must have # length 1 diff --git a/pybamm/expression_tree/operations/jacobian.py b/pybamm/expression_tree/operations/jacobian.py index c3de08c21c..6eeca38838 100644 --- a/pybamm/expression_tree/operations/jacobian.py +++ b/pybamm/expression_tree/operations/jacobian.py @@ -21,7 +21,7 @@ class Jacobian(object): def __init__( self, - known_jacs: Optional[dict[str, pybamm.Symbol]] = None, + known_jacs: Optional[dict[pybamm.Symbol, pybamm.Symbol]] = None, clear_domain: bool = True, ): self._known_jacs = known_jacs or {} @@ -76,12 +76,17 @@ def _jac(self, symbol: pybamm.Symbol, variable: pybamm.Symbol): elif isinstance(symbol, pybamm.Function): children_jacs = [None] * len(symbol.children) for i, child in enumerate(symbol.children): - children_jacs[i] = self.jac(child, variable) + children_jacs[i] = self.jac( # type:ignore[call-overload] + child, variable + ) # _function_jac defined in function class jac = symbol._function_jac(children_jacs) elif isinstance(symbol, pybamm.Concatenation): - children_jacs = [self.jac(child, variable) for child in symbol.children] + children_jacs = [ + self.jac(child, variable) # type:ignore[misc] + for child in symbol.children + ] if len(children_jacs) == 1: jac = children_jacs[0] else: diff --git a/pybamm/expression_tree/state_vector.py b/pybamm/expression_tree/state_vector.py index 897474aa6d..2a9c5230e8 100644 --- a/pybamm/expression_tree/state_vector.py +++ b/pybamm/expression_tree/state_vector.py @@ -40,8 +40,8 @@ def __init__( base_name="y", name: Optional[str] = None, domain: Optional[Iterable[str]] = None, - auxiliary_domains: Optional[dict[str]] = None, - domains: Optional[dict] = None, + auxiliary_domains: Optional[dict] = None, + domains: Optional[dict[str, list[str]]] = None, evaluation_array: Optional[list] = None, ): for y_slice in y_slices: @@ -227,8 +227,8 @@ def __init__( *y_slices: slice, name: Optional[str] = None, domain: Optional[Iterable[str]] = None, - auxiliary_domains: Optional[dict[str]] = None, - domains: Optional[dict] = None, + auxiliary_domains: Optional[dict] = None, + domains: Optional[dict[str, list[str]]] = None, evaluation_array: Optional[list] = None, ): super().__init__( @@ -311,8 +311,8 @@ def __init__( *y_slices: slice, name: Optional[str] = None, domain: Optional[Iterable[str]] = None, - auxiliary_domains: Optional[dict[str]] = None, - domains: Optional[dict] = None, + auxiliary_domains: Optional[dict] = None, + domains: Optional[dict[str, list[str]]] = None, evaluation_array: Optional[list] = None, ): super().__init__( diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 024c81e119..02459340e6 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -217,7 +217,7 @@ def __init__( self, name: str, children: Optional[list[Symbol]] = None, - domain: Optional[Union[list[str], str]] = None, + domain: Optional[Union[Sequence[str], str]] = None, auxiliary_domains: Optional[dict[str, str]] = None, domains: Optional[dict] = None, ): @@ -296,7 +296,7 @@ def auxiliary_domains(self): ) @domains.setter - def domains(self, domains): + def domains(self, domains): # type:ignore try: if ( self._domains == domains @@ -367,7 +367,7 @@ def clear_domains(self): self._domains = EMPTY_DOMAINS self.set_id() - def get_children_domains(self, children: Iterable[Symbol]): + def get_children_domains(self, children: Sequence[Symbol]): """Combine domains from children, at all levels.""" domains: dict = {} for child in children: @@ -390,7 +390,7 @@ def get_children_domains(self, children: Iterable[Symbol]): def read_domain_or_domains( self, - domain: Optional[Union[list[str], str]], + domain: Optional[Union[Sequence[str], str]], auxiliary_domains: Optional[dict[str, str]], domains: Optional[dict], ): @@ -762,7 +762,7 @@ def _base_evaluate( def evaluate( self, - t: Optional[numbers.Number] = None, + t: Optional[float] = None, y: Optional[np.ndarray] = None, y_dot: Optional[np.ndarray] = None, inputs: Optional[dict] = None, diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index f493e55211..5f1ed5bee7 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -3,7 +3,7 @@ # from __future__ import annotations import numbers -from typing import Optional +from typing import Optional, Union, Sequence import numpy as np import sympy @@ -489,7 +489,13 @@ class Integral(SpatialOperator): The variable over which to integrate """ - def __init__(self, child, integration_variable): + def __init__( + self, + child, + integration_variable: Union[ + Sequence[pybamm.IndependentVariable], pybamm.IndependentVariable + ], + ): if not isinstance(integration_variable, list): integration_variable = [integration_variable] diff --git a/pybamm/expression_tree/variable.py b/pybamm/expression_tree/variable.py index d65358a97a..57a03142d0 100644 --- a/pybamm/expression_tree/variable.py +++ b/pybamm/expression_tree/variable.py @@ -51,7 +51,7 @@ class VariableBase(pybamm.Symbol): def __init__( self, name: str, - domain: Optional[Iterable[str]] = None, + domain: Optional[Union[Sequence[str], str]] = None, auxiliary_domains: Optional[dict] = None, domains: Optional[dict] = None, bounds: Optional[tuple] = None, @@ -89,7 +89,7 @@ def bounds(self, values: Sequence[numbers.Number]): else: if ( all(isinstance(b, numbers.Number) for b in values) - and values[0] >= values[1] + and values[0] >= values[1] # type:ignore ): raise ValueError( f"Invalid bounds {values}. " From 8854bdcd24501086ca21bd60552771a12e9a410b Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Thu, 8 Jun 2023 15:57:18 +0000 Subject: [PATCH 04/32] (wip) More mypy changes --- pybamm/expression_tree/concatenations.py | 4 ++-- pybamm/expression_tree/interpolant.py | 11 ++++++----- .../operations/evaluate_python.py | 7 +++++-- .../operations/unpack_symbols.py | 2 +- pybamm/expression_tree/parameter.py | 10 +++++----- pybamm/expression_tree/state_vector.py | 8 ++++---- pybamm/expression_tree/symbol.py | 18 ++++++++++-------- pybamm/parameters/bpx.py | 2 +- 8 files changed, 34 insertions(+), 28 deletions(-) diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index b126567932..698fb061bf 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -49,7 +49,7 @@ def __init__( domains = {"primary": []} self.concatenation_function = concat_fun - super().__init__(name, children, domains=domains) + super().__init__(name, children, domains=domains) # type:ignore[arg-type] def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" @@ -71,7 +71,7 @@ def _diff(self, variable: pybamm.Symbol): def get_children_domains(self, children: Sequence[pybamm.Symbol]): # combine domains from children - domain = [] + domain: list = [] for child in children: if not isinstance(child, pybamm.Symbol): raise TypeError("{} is not a pybamm symbol".format(child)) diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index 4c9833ef05..6d09382b80 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -1,9 +1,10 @@ # # Interpolating class # +from __future__ import annotations import numpy as np from scipy import interpolate -from typing import Iterable, Optional +from typing import Optional, Sequence, Union import warnings import pybamm @@ -42,9 +43,9 @@ class Interpolant(pybamm.Function): def __init__( self, - x: np.ndarray, + x: Sequence[np.ndarray], y: np.ndarray, - children: Iterable[pybamm.Symbol], + children: Union[Sequence[pybamm.Symbol], pybamm.Time], name: Optional[str] = None, interpolator: Optional[str] = "linear", extrapolate: Optional[bool] = True, @@ -130,7 +131,7 @@ def __init__( if extrapolate is False: fill_value = np.nan elif extrapolate is True: - fill_value = "extrapolate" + fill_value = "extrapolate" # type:ignore[assignment] interpolating_function = interpolate.interp1d( x1, y.T, @@ -180,7 +181,7 @@ def __init__( ) else: interpolating_function = interpolate.RegularGridInterpolator( - (x1, x2, x3), + (x1, x2, x3), # type:ignore[has-type] y, method=interpolator, bounds_error=False, diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index 00ffeb3e23..1017e28e68 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -71,7 +71,9 @@ def dot_product(self, b: jaxlib.xla_extension.DeviceArray): """ # assume b is a column vector result = jax.numpy.zeros((self.shape[0], 1), dtype=b.dtype) - return result.at[self.row].add(self.data.reshape(-1, 1) * b[self.col]) + return result.at[self.row].add( + self.data.reshape(-1, 1) * b[self.col] # type:ignore[index] + ) def scalar_multiply(self, b: float): """ @@ -612,7 +614,8 @@ def __init__(self, symbol: pybamm.Symbol): self._static_argnums = tuple(static_argnums) self._jit_evaluate = jax.jit( - self._evaluate_jax, static_argnums=self._static_argnums + self._evaluate_jax, # type:ignore[attr-defined] + static_argnums=self._static_argnums, ) def get_jacobian(self): diff --git a/pybamm/expression_tree/operations/unpack_symbols.py b/pybamm/expression_tree/operations/unpack_symbols.py index 61086e8a04..0e22081caa 100644 --- a/pybamm/expression_tree/operations/unpack_symbols.py +++ b/pybamm/expression_tree/operations/unpack_symbols.py @@ -27,7 +27,7 @@ def __init__( unpacked_symbols: Optional[set] = None, ): self.classes_to_find = classes_to_find - self._unpacked_symbols: set = unpacked_symbols or {} + self._unpacked_symbols: Union[set, dict] = unpacked_symbols or {} def unpack_list_of_symbols( self, list_of_symbols: Sequence[pybamm.Symbol] diff --git a/pybamm/expression_tree/parameter.py b/pybamm/expression_tree/parameter.py index 8cb60aec36..e4e5914263 100644 --- a/pybamm/expression_tree/parameter.py +++ b/pybamm/expression_tree/parameter.py @@ -35,7 +35,7 @@ def create_copy(self) -> pybamm.Parameter: """See :meth:`pybamm.Symbol.new_copy()`.""" return Parameter(self.name) - def _evaluate_for_shape(self) -> np.nan: + def _evaluate_for_shape(self) -> float: """ Returns the scalar 'NaN' to represent the shape of a parameter. See :meth:`pybamm.Symbol.evaluate_for_shape()` @@ -101,7 +101,7 @@ def __init__( domains = self.get_children_domains(children_list) super().__init__(name, children=children_list, domains=domains) - self.input_names = list(inputs.keys()) + self.input_names = list(inputs.keys()) # type:ignore[misc] # Use the inspect module to find the function's "short name" from the # Parameters module that called it @@ -109,12 +109,12 @@ def __init__( self.print_name = print_name else: frame = sys._getframe().f_back - print_name = frame.f_code.co_name + print_name = frame.f_code.co_name # type:ignore[union-attr] if print_name.startswith("_"): self.print_name = None else: try: - parent_param = frame.f_locals["self"] + parent_param = frame.f_locals["self"] # type:ignore[union-attr] except KeyError: parent_param = None if hasattr(parent_param, "domain") and parent_param.domain is not None: @@ -133,7 +133,7 @@ def print_input_names(self): for inp in self._input_names: print(inp) - @input_names.setter + @input_names.setter # type:ignore[no-redef, attr-defined] def input_names(self, inp=None): if inp: if inp.__class__ is list: diff --git a/pybamm/expression_tree/state_vector.py b/pybamm/expression_tree/state_vector.py index 2a9c5230e8..11eaa104c2 100644 --- a/pybamm/expression_tree/state_vector.py +++ b/pybamm/expression_tree/state_vector.py @@ -4,7 +4,7 @@ from __future__ import annotations import numpy as np from scipy.sparse import csr_matrix, vstack -from typing import Optional, Iterable, Union +from typing import Optional, Iterable, Union, Sequence import pybamm @@ -39,7 +39,7 @@ def __init__( *y_slices: slice, base_name="y", name: Optional[str] = None, - domain: Optional[Iterable[str]] = None, + domain: Optional[Sequence[str]] = None, auxiliary_domains: Optional[dict] = None, domains: Optional[dict[str, list[str]]] = None, evaluation_array: Optional[list] = None, @@ -226,7 +226,7 @@ def __init__( self, *y_slices: slice, name: Optional[str] = None, - domain: Optional[Iterable[str]] = None, + domain: Optional[Sequence[str]] = None, auxiliary_domains: Optional[dict] = None, domains: Optional[dict[str, list[str]]] = None, evaluation_array: Optional[list] = None, @@ -310,7 +310,7 @@ def __init__( self, *y_slices: slice, name: Optional[str] = None, - domain: Optional[Iterable[str]] = None, + domain: Optional[Sequence[str]] = None, auxiliary_domains: Optional[dict] = None, domains: Optional[dict[str, list[str]]] = None, evaluation_array: Optional[list] = None, diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 02459340e6..e0fcf5fe36 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -10,7 +10,7 @@ from anytree.exporter import DotExporter from scipy.sparse import csr_matrix, issparse from functools import lru_cache, cached_property -from typing import Union, TYPE_CHECKING, Optional, Iterable, Sequence +from typing import Union, TYPE_CHECKING, Optional, Iterable, Sequence, TypeVar import pybamm from pybamm.expression_tree.printing.print_name import prettify_print_name @@ -24,6 +24,8 @@ ) import casadi + S = TypeVar("S", bound=pybamm.Symbol) + DOMAIN_LEVELS = ["primary", "secondary", "tertiary", "quaternary"] EMPTY_DOMAINS = {k: [] for k in DOMAIN_LEVELS} @@ -157,8 +159,8 @@ def is_matrix_minus_one(expr: Symbol): def simplify_if_constant( - symbol: Symbol, -) -> Symbol: + symbol: S, +) -> S: """ Utility function to simplify an expression tree if it evalutes to a constant scalar, vector or matrix @@ -216,7 +218,7 @@ class Symbol: def __init__( self, name: str, - children: Optional[list[Symbol]] = None, + children: Optional[Sequence[Symbol]] = None, domain: Optional[Union[Sequence[str], str]] = None, auxiliary_domains: Optional[dict[str, str]] = None, domains: Optional[dict] = None, @@ -295,8 +297,8 @@ def auxiliary_domains(self): "symbol.auxiliary_domains has been deprecated, use symbol.domains instead" ) - @domains.setter - def domains(self, domains): # type:ignore + @domains.setter # type:ignore[no-redef, attr-defined] + def domains(self, domains): try: if ( self._domains == domains @@ -596,7 +598,7 @@ def __rtruediv__(self, other: Symbol) -> Division: """return a :class:`Division` object.""" return pybamm.divide(other, self) - def __pow__(self, other: Union[Symbol, float]) -> pybamm.Power: + def __pow__(self, other: Union[Symbol, numbers.Number]) -> pybamm.Power: """return a :class:`Power` object.""" return pybamm.simplified_power(self, other) @@ -604,7 +606,7 @@ def __rpow__(self, other: Symbol) -> pybamm.Power: """return a :class:`Power` object.""" return pybamm.simplified_power(other, self) - def __lt__(self, other: Union[Symbol, float]) -> pybamm.NotEqualHeaviside: + def __lt__(self, other: Union[Symbol, numbers.Number]) -> pybamm.NotEqualHeaviside: """return a :class:`NotEqualHeaviside` object, or a smooth approximation.""" return pybamm.expression_tree.binary_operators._heaviside(self, other, False) diff --git a/pybamm/parameters/bpx.py b/pybamm/parameters/bpx.py index cb10a87ede..3fb1ddbbbc 100644 --- a/pybamm/parameters/bpx.py +++ b/pybamm/parameters/bpx.py @@ -61,7 +61,7 @@ class Domain: def _bpx_to_param_dict(bpx: BPX) -> dict: - pybamm_dict = {} + pybamm_dict: dict = {} pybamm_dict = _bpx_to_domain_param_dict( bpx.parameterisation.cell, pybamm_dict, cell ) From e41a9248e4c895d332fcb5d281a1e20c69ef3404 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Thu, 8 Jun 2023 20:00:34 +0000 Subject: [PATCH 05/32] (wip): Passing mypy with --allow-redefinition --- pybamm/expression_tree/averages.py | 14 ++-- pybamm/expression_tree/binary_operators.py | 64 +++++++-------- pybamm/expression_tree/concatenations.py | 4 +- pybamm/expression_tree/functions.py | 4 +- .../expression_tree/independent_variable.py | 8 +- pybamm/expression_tree/input_parameter.py | 4 +- pybamm/expression_tree/interpolant.py | 2 +- .../operations/convert_to_casadi.py | 3 +- .../operations/unpack_symbols.py | 8 +- pybamm/expression_tree/scalar.py | 2 +- pybamm/expression_tree/symbol.py | 78 ++++++++++++++++--- pybamm/expression_tree/unary_operators.py | 2 +- pybamm/expression_tree/variable.py | 6 +- pybamm/solvers/idaklu_solver.py | 1 + pybamm/solvers/jax_bdf_solver.py | 1 + pybamm/solvers/scikits_dae_solver.py | 1 + pybamm/solvers/scikits_ode_solver.py | 1 + 17 files changed, 139 insertions(+), 64 deletions(-) diff --git a/pybamm/expression_tree/averages.py b/pybamm/expression_tree/averages.py index 3d9510cda2..c209442adf 100644 --- a/pybamm/expression_tree/averages.py +++ b/pybamm/expression_tree/averages.py @@ -21,7 +21,7 @@ def __init__( child: pybamm.Symbol, name: str, integration_variable: Union[ - Sequence[pybamm.IndependentVariable], pybamm.IndependentVariable + list[pybamm.IndependentVariable], pybamm.IndependentVariable ], ) -> None: super().__init__(child, integration_variable) @@ -48,7 +48,7 @@ class YZAverage(_BaseAverage): def __init__(self, child: pybamm.Symbol) -> None: y = pybamm.standard_spatial_vars.y z = pybamm.standard_spatial_vars.z - integration_variable = [y, z] + integration_variable: list[pybamm.IndependentVariable] = [y, z] super().__init__(child, "yz-average", integration_variable) def _unary_new_copy(self, child: pybamm.Symbol): @@ -58,7 +58,9 @@ def _unary_new_copy(self, child: pybamm.Symbol): class ZAverage(_BaseAverage): def __init__(self, child: pybamm.Symbol) -> None: - integration_variable = [pybamm.standard_spatial_vars.z] + integration_variable: list[pybamm.IndependentVariable] = [ + pybamm.standard_spatial_vars.z + ] super().__init__(child, "z-average", integration_variable) def _unary_new_copy(self, child: pybamm.Symbol): @@ -68,7 +70,9 @@ def _unary_new_copy(self, child: pybamm.Symbol): class RAverage(_BaseAverage): def __init__(self, child: pybamm.Symbol) -> None: - integration_variable = [pybamm.SpatialVariable("r", child.domain)] + integration_variable: list[pybamm.IndependentVariable] = [ + pybamm.SpatialVariable("r", child.domain) + ] super().__init__(child, "r-average", integration_variable) def _unary_new_copy(self, child: pybamm.Symbol): @@ -79,7 +83,7 @@ def _unary_new_copy(self, child: pybamm.Symbol): class SizeAverage(_BaseAverage): def __init__(self, child: pybamm.Symbol, f_a_dist) -> None: R = pybamm.SpatialVariable("R", domains=child.domains, coord_sys="cartesian") - integration_variable = [R] + integration_variable: list[pybamm.IndependentVariable] = [R] super().__init__(child, "size-average", integration_variable) self.f_a_dist = f_a_dist diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index 10e29d671f..b33f18bba0 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -14,8 +14,8 @@ def _preprocess_binary( - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: Union[float, pybamm.Symbol], + right: Union[float, pybamm.Symbol], ) -> Tuple[pybamm.Symbol, pybamm.Symbol]: if isinstance(left, numbers.Number): left = pybamm.Scalar(left) @@ -68,8 +68,8 @@ class BinaryOperator(pybamm.Symbol): def __init__( self, name: str, - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: Union[float, pybamm.Symbol], + right: Union[float, pybamm.Symbol], ) -> None: left, right = _preprocess_binary(left, right) @@ -113,8 +113,8 @@ def create_copy(self): def _binary_new_copy( self, - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: Union[float, pybamm.Symbol], + right: Union[float, pybamm.Symbol], ): """ Default behaviour for new_copy. @@ -441,8 +441,8 @@ def _binary_evaluate(self, left, right): def _binary_new_copy( self, - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: Union[float, pybamm.Symbol], + right: Union[float, pybamm.Symbol], ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.inner(left, right) @@ -506,8 +506,8 @@ def _binary_evaluate(self, left, right): def _binary_new_copy( self, - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: Union[float, pybamm.Symbol], + right: Union[float, pybamm.Symbol], ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.Equality(left, right) @@ -660,8 +660,8 @@ def _binary_evaluate(self, left, right): def _binary_new_copy( self, - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: Union[float, pybamm.Symbol], + right: Union[float, pybamm.Symbol], ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.minimum(left, right) @@ -700,8 +700,8 @@ def _binary_evaluate(self, left, right): def _binary_new_copy( self, - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: Union[float, pybamm.Symbol], + right: Union[float, pybamm.Symbol], ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.maximum(left, right) @@ -712,8 +712,8 @@ def _sympy_operator(self, left, right): def _simplify_elementwise_binary_broadcasts( - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: Union[float, pybamm.Symbol], + right: Union[float, pybamm.Symbol], ) -> Tuple[pybamm.Symbol, pybamm.Symbol]: left, right = _preprocess_binary(left, right) @@ -788,8 +788,8 @@ def _simplified_binary_broadcast_concatenation( def simplified_power( - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: Union[float, pybamm.Symbol], + right: Union[float, pybamm.Symbol], ): left, right = _simplify_elementwise_binary_broadcasts(left, right) @@ -924,8 +924,8 @@ def add( def subtract( - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: Union[float, pybamm.Symbol], + right: Union[float, pybamm.Symbol], ): """ Note @@ -1009,8 +1009,8 @@ def subtract( def multiply( - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: Union[float, pybamm.Symbol], + right: Union[float, pybamm.Symbol], ): left, right = _simplify_elementwise_binary_broadcasts(left, right) @@ -1139,8 +1139,8 @@ def multiply( def divide( - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: Union[float, pybamm.Symbol], + right: Union[float, pybamm.Symbol], ): left, right = _simplify_elementwise_binary_broadcasts(left, right) @@ -1213,8 +1213,8 @@ def divide( def matmul( - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: Union[float, pybamm.Symbol], + right: Union[float, pybamm.Symbol], ): left, right = _preprocess_binary(left, right) if pybamm.is_matrix_zero(left) or pybamm.is_matrix_zero(right): @@ -1275,8 +1275,8 @@ def matmul( def minimum( - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: Union[float, pybamm.Symbol], + right: Union[float, pybamm.Symbol], ) -> pybamm.Symbol: """ Returns the smaller of two objects, possibly with a smoothing approximation. @@ -1299,8 +1299,8 @@ def minimum( def maximum( - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: Union[float, pybamm.Symbol], + right: Union[float, pybamm.Symbol], ): """ Returns the larger of two objects, possibly with a smoothing approximation. @@ -1323,8 +1323,8 @@ def maximum( def _heaviside( - left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + left: Union[float, pybamm.Symbol], + right: Union[float, pybamm.Symbol], equal, ): """return a :class:`EqualHeaviside` object, or a smooth approximation.""" diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 698fb061bf..6eb1ebf868 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -243,7 +243,9 @@ def __init__( children = list(children) # Allow the base class to sort the domains into the correct order - super().__init__(*children, name="domain_concatenation") + super().__init__( + *children, name="domain_concatenation" # type:ignore[arg-type] + ) # type:ignore[arg-type] if copy_this is None: # store mesh diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 124b4c590e..a40e0bf829 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -8,7 +8,7 @@ import numpy as np import sympy from scipy import special -from typing import Optional +from typing import Optional, Sequence import pybamm @@ -93,7 +93,7 @@ def diff(self, variable: pybamm.Symbol): return derivative - def _function_diff(self, children: pybamm.Symbol, idx: float): + def _function_diff(self, children: Sequence[pybamm.Symbol], idx: float): """ Derivative with respect to child number 'idx'. See :meth:`pybamm.Symbol._diff()`. diff --git a/pybamm/expression_tree/independent_variable.py b/pybamm/expression_tree/independent_variable.py index c57c153326..510e19a2b2 100644 --- a/pybamm/expression_tree/independent_variable.py +++ b/pybamm/expression_tree/independent_variable.py @@ -3,7 +3,7 @@ # import sympy import numpy as np -from typing import Union, Optional +from typing import Union, Optional, Any import pybamm @@ -34,7 +34,7 @@ class IndependentVariable(pybamm.Symbol): def __init__( self, name: str, - domain: Optional[list[str]] = None, + domain: Optional[Union[list[str], str]] = None, auxiliary_domains: Optional[dict] = None, domains: Optional[dict] = None, ) -> None: @@ -65,6 +65,7 @@ class Time(IndependentVariable): def __init__(self): super().__init__("time") + # making this super(pybamm.Symbol, self)__init__(name="time") works, but not sure why. def create_copy(self): """See :meth:`pybamm.Symbol.new_copy()`.""" @@ -114,6 +115,8 @@ class SpatialVariable(IndependentVariable): deprecated. """ + # coord_sys: Optional[Any] + def __init__( self, name: str, @@ -127,6 +130,7 @@ def __init__( name, domain=domain, auxiliary_domains=auxiliary_domains, domains=domains ) domain = self.domain + # using a dataclass, at this point the domain doesn't get set for some reason, during initialisation. if domain == []: raise ValueError("domain must be provided") diff --git a/pybamm/expression_tree/input_parameter.py b/pybamm/expression_tree/input_parameter.py index f45e9b44cd..f1d2795fdc 100644 --- a/pybamm/expression_tree/input_parameter.py +++ b/pybamm/expression_tree/input_parameter.py @@ -7,7 +7,7 @@ import scipy.sparse import pybamm -from typing import Union, Iterable, Optional +from typing import Union, Iterable, Optional, Sequence class InputParameter(pybamm.Symbol): @@ -31,7 +31,7 @@ class InputParameter(pybamm.Symbol): def __init__( self, name: str, - domain: Optional[Union[Iterable[str], str]] = None, + domain: Optional[Union[Sequence[str], str]] = None, expected_size: Optional[int] = None, ) -> None: # Expected size defaults to 1 if no domain else None (gets set later) diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index 6d09382b80..be14b888b8 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -103,7 +103,7 @@ def __init__( x1 = x[0] else: x1 = x - x = [x] + x = [x] # type:ignore[list-item] x2 = None if x1.shape[0] != y.shape[0]: raise ValueError( diff --git a/pybamm/expression_tree/operations/convert_to_casadi.py b/pybamm/expression_tree/operations/convert_to_casadi.py index 91202fc18e..3caf18bb4d 100644 --- a/pybamm/expression_tree/operations/convert_to_casadi.py +++ b/pybamm/expression_tree/operations/convert_to_casadi.py @@ -5,6 +5,7 @@ import casadi import numpy as np from scipy import special +from typing import Union class CasadiConverter(object): @@ -19,7 +20,7 @@ def convert( t: casadi.MX, y: casadi.MX, y_dot: casadi.MX, - inputs: dict, + inputs: Union[dict, None], ) -> casadi.MX: """ This function recurses down the tree, converting the PyBaMM expression tree to diff --git a/pybamm/expression_tree/operations/unpack_symbols.py b/pybamm/expression_tree/operations/unpack_symbols.py index 0e22081caa..5ebfbf4ef7 100644 --- a/pybamm/expression_tree/operations/unpack_symbols.py +++ b/pybamm/expression_tree/operations/unpack_symbols.py @@ -24,10 +24,10 @@ class SymbolUnpacker(object): def __init__( self, classes_to_find: Union[pybamm.Symbol, Sequence[pybamm.Symbol]], - unpacked_symbols: Optional[set] = None, + unpacked_symbols: Optional[dict] = None, ): self.classes_to_find = classes_to_find - self._unpacked_symbols: Union[set, dict] = unpacked_symbols or {} + self._unpacked_symbols: dict = unpacked_symbols or {} # type:ignore[assignment] def unpack_list_of_symbols( self, list_of_symbols: Sequence[pybamm.Symbol] @@ -52,7 +52,9 @@ def unpack_list_of_symbols( return all_instances - def unpack_symbol(self, symbol: Sequence[pybamm.Symbol]) -> list[pybamm.Symbol]: + def unpack_symbol( + self, symbol: Union[Sequence[pybamm.Symbol], pybamm.Symbol] + ) -> list[pybamm.Symbol]: """ This function recurses down the tree, unpacking the symbols and saving the ones that have a class in `self.classes_to_find`. diff --git a/pybamm/expression_tree/scalar.py b/pybamm/expression_tree/scalar.py index 6e69745743..efd726adbb 100644 --- a/pybamm/expression_tree/scalar.py +++ b/pybamm/expression_tree/scalar.py @@ -26,7 +26,7 @@ class Scalar(pybamm.Symbol): def __init__( self, - value: Union[pybamm.Scalar, numbers.Number, float], + value: Union[float, numbers.Number], name: Optional[str] = None, ) -> None: # set default name if not provided diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index e0fcf5fe36..80f0f5f4ed 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -3,6 +3,7 @@ # from __future__ import annotations import numbers +import inspect import anytree import numpy as np @@ -10,7 +11,18 @@ from anytree.exporter import DotExporter from scipy.sparse import csr_matrix, issparse from functools import lru_cache, cached_property -from typing import Union, TYPE_CHECKING, Optional, Iterable, Sequence, TypeVar +from typing import ( + Union, + TYPE_CHECKING, + Optional, + Iterable, + Sequence, + TypeVar, + no_type_check, +) + +# from pydantic import BaseModel +# from pydantic.dataclasses import dataclass import pybamm from pybamm.expression_tree.printing.print_name import prettify_print_name @@ -27,7 +39,28 @@ S = TypeVar("S", bound=pybamm.Symbol) DOMAIN_LEVELS = ["primary", "secondary", "tertiary", "quaternary"] -EMPTY_DOMAINS = {k: [] for k in DOMAIN_LEVELS} +EMPTY_DOMAINS: dict[str, list] = {k: [] for k in DOMAIN_LEVELS} + + +# class PatchedModel(BaseModel): +# @no_type_check +# def __setattr__(self, name, value): +# """ +# To be able to use properties with setters +# """ +# try: +# super().__setattr__(name, value) +# except ValueError as e: +# setters = inspect.getmembers( +# self.__class__, +# predicate=lambda x: isinstance(x, property) and x.fset is not None, +# ) +# for setter_name, func in setters: +# if setter_name == name: +# object.__setattr__(self, name, value) +# break +# else: +# raise e def domain_size(domain: Union[list[str], str]): @@ -173,20 +206,25 @@ def simplify_if_constant( or (isinstance(result, np.ndarray) and result.ndim == 0) or isinstance(result, np.bool_) ): - return pybamm.Scalar(result) + return pybamm.Scalar(result) # type:ignore[return-value, arg-type] elif isinstance(result, np.ndarray) or issparse(result): if result.ndim == 1 or result.shape[1] == 1: - return pybamm.Vector(result, domains=symbol.domains) + return pybamm.Vector( # type:ignore[return-value] + result, domains=symbol.domains + ) else: # Turn matrix of zeros into sparse matrix if isinstance(result, np.ndarray) and np.all(result == 0): result = csr_matrix(result) - return pybamm.Matrix(result, domains=symbol.domains) + return pybamm.Matrix( # type:ignore[return-value] + result, domains=symbol.domains + ) return symbol -class Symbol: +# @dataclass +class Symbol: # PatchedModel """ Base node class for the expression tree. @@ -215,6 +253,8 @@ class Symbol: deprecated. """ + # name: str + def __init__( self, name: str, @@ -223,6 +263,13 @@ def __init__( auxiliary_domains: Optional[dict[str, str]] = None, domains: Optional[dict] = None, ): + # super().__init__( + # name=name, + # children=children, + # domain=domain, + # auxiliary_domains=auxiliary_domains, + # domains=domains, + # ) super(Symbol, self).__init__() self.name = name @@ -234,9 +281,9 @@ def __init__( self._orphans = children # Set domains (and hence id) - self.domains = self.read_domain_or_domains(domain, auxiliary_domains, domains) + self.domains = self.read_domain_or_domains(domain, auxiliary_domains, domains) # type: ignore[misc] - self._saved_evaluates_on_edges = {} + self._saved_evaluates_on_edges: dict = {} self._print_name = None # Test shape on everything but nodes that contain the base Symbol class or @@ -249,6 +296,15 @@ def __init__( ): self.test_shape() + # # super().__init__(name, children, domain, auxiliary_domains, domains) + + # class Config: + # arbitrary_types_allowed = True + # # underscore_attrs_are_private = True + # keep_untouched = (cached_property,) + # fields = {"domain": {"exclude": True}, "auxiliary_domains": {"exclude": True}} + # # json_encoders = {"Symbol": lambda u: u.__dict__} + @property def children(self): """ @@ -598,7 +654,7 @@ def __rtruediv__(self, other: Symbol) -> Division: """return a :class:`Division` object.""" return pybamm.divide(other, self) - def __pow__(self, other: Union[Symbol, numbers.Number]) -> pybamm.Power: + def __pow__(self, other: Union[Symbol, float]) -> pybamm.Power: """return a :class:`Power` object.""" return pybamm.simplified_power(self, other) @@ -606,7 +662,7 @@ def __rpow__(self, other: Symbol) -> pybamm.Power: """return a :class:`Power` object.""" return pybamm.simplified_power(other, self) - def __lt__(self, other: Union[Symbol, numbers.Number]) -> pybamm.NotEqualHeaviside: + def __lt__(self, other: Union[Symbol, float]) -> pybamm.NotEqualHeaviside: """return a :class:`NotEqualHeaviside` object, or a smooth approximation.""" return pybamm.expression_tree.binary_operators._heaviside(self, other, False) @@ -707,7 +763,7 @@ def _diff(self, variable): def jac( self, variable: pybamm.Symbol, - known_jacs: Optional[dict[str, pybamm.Symbol]] = None, + known_jacs: Optional[dict[pybamm.Symbol, pybamm.Symbol]] = None, clear_domain=True, ): """ diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 5f1ed5bee7..c441645f10 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -493,7 +493,7 @@ def __init__( self, child, integration_variable: Union[ - Sequence[pybamm.IndependentVariable], pybamm.IndependentVariable + list[pybamm.IndependentVariable], pybamm.IndependentVariable ], ): if not isinstance(integration_variable, list): diff --git a/pybamm/expression_tree/variable.py b/pybamm/expression_tree/variable.py index 57a03142d0..ba373f8ec8 100644 --- a/pybamm/expression_tree/variable.py +++ b/pybamm/expression_tree/variable.py @@ -99,7 +99,7 @@ def bounds(self, values: Sequence[numbers.Number]): values = list(values) for idx, bound in enumerate(values): if isinstance(bound, numbers.Number): - values[idx] = pybamm.Scalar(bound) + values[idx] = pybamm.Scalar(bound) # type:ignore[call-overload] self._bounds = tuple(values) def set_id(self): @@ -234,7 +234,9 @@ def get_variable(self) -> pybamm.Variable: """ return Variable(self.name[:-1], domains=self.domains, scale=self.scale) - def diff(self, variable: pybamm.Variable) -> pybamm.Scalar: + def diff( + self, variable: pybamm.VariableDot # type:ignore[override] + ) -> pybamm.Scalar: if variable == self: return pybamm.Scalar(1) elif variable == pybamm.t: diff --git a/pybamm/solvers/idaklu_solver.py b/pybamm/solvers/idaklu_solver.py index 5ccff7ed14..f840c828c9 100644 --- a/pybamm/solvers/idaklu_solver.py +++ b/pybamm/solvers/idaklu_solver.py @@ -1,6 +1,7 @@ # # Solver class using sundials with the KLU sparse linear solver # +# mypy: ignore-errors import casadi import pybamm import numpy as np diff --git a/pybamm/solvers/jax_bdf_solver.py b/pybamm/solvers/jax_bdf_solver.py index b69744dd08..46419919c0 100644 --- a/pybamm/solvers/jax_bdf_solver.py +++ b/pybamm/solvers/jax_bdf_solver.py @@ -1,3 +1,4 @@ +# mypy: ignore-errors import collections import operator as op from functools import partial diff --git a/pybamm/solvers/scikits_dae_solver.py b/pybamm/solvers/scikits_dae_solver.py index 56b3ff42c3..813d2b4b95 100644 --- a/pybamm/solvers/scikits_dae_solver.py +++ b/pybamm/solvers/scikits_dae_solver.py @@ -1,6 +1,7 @@ # # Solver class using Scipy's adaptive time stepper # +# mypy: ignore-errors import casadi import pybamm diff --git a/pybamm/solvers/scikits_ode_solver.py b/pybamm/solvers/scikits_ode_solver.py index 66132f39bb..770fc7e02a 100644 --- a/pybamm/solvers/scikits_ode_solver.py +++ b/pybamm/solvers/scikits_ode_solver.py @@ -1,6 +1,7 @@ # # Solver class using Scipy's adaptive time stepper # +# mypy: ignore-errors import casadi import pybamm From 79aefb85ec05c636b8768ca0494ebe7fef4844ae Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Mon, 27 Nov 2023 16:04:18 +0000 Subject: [PATCH 06/32] Fixes, remove most type:ignores (not passing) --- .pre-commit-config.yaml | 2 +- pybamm/expression_tree/array.py | 28 +-- pybamm/expression_tree/averages.py | 10 +- pybamm/expression_tree/binary_operators.py | 233 +++++++++++------- pybamm/expression_tree/broadcasts.py | 63 ++--- pybamm/expression_tree/concatenations.py | 47 ++-- pybamm/expression_tree/functions.py | 60 +++-- .../expression_tree/independent_variable.py | 18 +- pybamm/expression_tree/input_parameter.py | 4 +- pybamm/expression_tree/interpolant.py | 8 +- pybamm/expression_tree/matrix.py | 4 +- pybamm/expression_tree/operations/jacobian.py | 9 +- .../operations/unpack_symbols.py | 2 +- pybamm/expression_tree/parameter.py | 16 +- pybamm/expression_tree/scalar.py | 12 +- pybamm/expression_tree/state_vector.py | 20 +- pybamm/expression_tree/symbol.py | 146 ++++------- pybamm/expression_tree/unary_operators.py | 4 +- pybamm/expression_tree/variable.py | 16 +- 19 files changed, 349 insertions(+), 353 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 288f139afa..89a96187b7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ repos: rev: "v0.0.288" hooks: - id: ruff - args: [--fix, --ignore=E741, --exclude=__init__.py] + args: [--fix, --ignore=E741, --exclude=__init__.py, --select=I002] - repo: https://github.com/nbQA-dev/nbQA rev: 1.7.0 diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index 2ce51f1a32..a1b38bdbcd 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -38,11 +38,11 @@ class Array(pybamm.Symbol): def __init__( self, - entries: Union[np.ndarray, list], + entries: Union[np.ndarray, list, csr_matrix], name: Optional[str] = None, domain: Union[list[str], str, None] = None, auxiliary_domains: Optional[dict[str, str]] = None, - domains: Optional[dict] = None, + domains: Optional[dict[str, list[str]]] = None, entries_string: Optional[str] = None, ) -> None: # if @@ -78,7 +78,7 @@ def entries_string(self): return self._entries_string @entries_string.setter - def entries_string(self, value): + def entries_string(self, value: Union[None, tuple]): # We must include the entries in the hash, since different arrays can be # indistinguishable by class, name and domain alone # Slightly different syntax for sparse and non-sparse matrices @@ -88,10 +88,10 @@ def entries_string(self, value): entries = self._entries if issparse(entries): dct = entries.__dict__ - self._entries_string = ["shape", str(dct["_shape"])] + entries_string = ["shape", str(dct["_shape"])] for key in ["data", "indices", "indptr"]: - self._entries_string += [key, dct[key].tobytes()] - self._entries_string = tuple(self._entries_string) + entries_string += [key, dct[key].tobytes()] + self._entries_string = tuple(entries_string) # self._entries_string = str(entries.__dict__) else: self._entries_string = (entries.tobytes(),) @@ -117,13 +117,7 @@ def create_copy(self): entries_string=self.entries_string, ) - def _base_evaluate( - self, - t: Optional[float] = None, - y: Optional[np.ndarray] = None, - y_dot: Optional[np.ndarray] = None, - inputs: Optional[dict] = None, - ): + def _base_evaluate(self, t, y, y_dot, inputs): """See :meth:`pybamm.Symbol._base_evaluate()`.""" return self._entries @@ -137,7 +131,7 @@ def to_equation(self) -> sympy.Array: return sympy.Array(entries_list) -def linspace(start: float, stop: float, num=50, **kwargs) -> pybamm.Array: +def linspace(start: float, stop: float, num: int = 50, **kwargs) -> pybamm.Array: """ Creates a linearly spaced array by calling `numpy.linspace` with keyword arguments 'kwargs'. For a list of 'kwargs' see the @@ -155,6 +149,6 @@ def meshgrid( see the `numpy meshgrid documentation `_ """ [X, Y] = np.meshgrid(x.entries, y.entries) - X = pybamm.Array(X) # type:ignore[assignment] - Y = pybamm.Array(Y) # type:ignore[assignment] - return X, Y # type:ignore[return-value] + X = pybamm.Array(X) + Y = pybamm.Array(Y) + return X, Y diff --git a/pybamm/expression_tree/averages.py b/pybamm/expression_tree/averages.py index c209442adf..b85ab7cfd0 100644 --- a/pybamm/expression_tree/averages.py +++ b/pybamm/expression_tree/averages.py @@ -2,7 +2,7 @@ # Classes and methods for averaging # from __future__ import annotations -from typing import Union, Callable, Sequence +from typing import Union, Callable, Optional import pybamm @@ -211,7 +211,7 @@ def z_average(symbol: pybamm.Symbol) -> pybamm.Symbol: return symbol # If symbol is a Broadcast, its average value is its child elif isinstance(symbol, pybamm.Broadcast): - return symbol.reduce_one_dimension() # type:ignore + return symbol.reduce_one_dimension() # Average of a sum is sum of averages elif isinstance(symbol, (pybamm.Addition, pybamm.Subtraction)): return _sum_of_averages(symbol, z_average) @@ -247,7 +247,7 @@ def yz_average(symbol: pybamm.Symbol) -> pybamm.Symbol: return symbol # If symbol is a Broadcast, its average value is its child elif isinstance(symbol, pybamm.Broadcast): - return symbol.reduce_one_dimension() # type:ignore + return symbol.reduce_one_dimension() # Average of a sum is sum of averages elif isinstance(symbol, (pybamm.Addition, pybamm.Subtraction)): return _sum_of_averages(symbol, yz_average) @@ -303,7 +303,9 @@ def r_average(symbol: pybamm.Symbol) -> pybamm.Symbol: return RAverage(symbol) -def size_average(symbol: pybamm.Symbol, f_a_dist=None) -> pybamm.Symbol: +def size_average( + symbol: pybamm.Symbol, f_a_dist: Optional[pybamm.Symbol] = None +) -> pybamm.Symbol: """Convenience function for averaging over particle size R using the area-weighted particle-size distribution. diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index b33f18bba0..2819f19738 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -8,14 +8,18 @@ import sympy from scipy.sparse import csr_matrix, issparse import functools -from typing import Union, Tuple, Optional import pybamm +from typing import Union, Tuple, Optional, Callable, overload +from typing_extensions import TypeVar + +# create type alias(s) +ChildValue = Union[float, np.ndarray, pybamm.Symbol] + def _preprocess_binary( - left: Union[float, pybamm.Symbol], - right: Union[float, pybamm.Symbol], + left: ChildValue, right: ChildValue ) -> Tuple[pybamm.Symbol, pybamm.Symbol]: if isinstance(left, numbers.Number): left = pybamm.Scalar(left) @@ -66,12 +70,9 @@ class BinaryOperator(pybamm.Symbol): """ def __init__( - self, - name: str, - left: Union[float, pybamm.Symbol], - right: Union[float, pybamm.Symbol], + self, name: str, left_child: ChildValue, right_child: ChildValue ) -> None: - left, right = _preprocess_binary(left, right) + left, right = _preprocess_binary(left_child, right_child) domains = self.get_children_domains([left, right]) super().__init__(name, children=[left, right], domains=domains) @@ -111,11 +112,7 @@ def create_copy(self): return out - def _binary_new_copy( - self, - left: Union[float, pybamm.Symbol], - right: Union[float, pybamm.Symbol], - ): + def _binary_new_copy(self, left: ChildValue, right: ChildValue): """ Default behaviour for new_copy. This copies the behaviour of `_binary_evaluate`, but since `left` and `right` @@ -128,7 +125,7 @@ def evaluate( t: Optional[float] = None, y: Optional[np.ndarray] = None, y_dot: Optional[np.ndarray] = None, - inputs: Optional[dict] = None, + inputs: Optional[Union[dict, str]] = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" left = self.left.evaluate(t, y, y_dot, inputs) @@ -181,7 +178,11 @@ class Power(BinaryOperator): A node in the expression tree representing a `**` power operator. """ - def __init__(self, left, right): + def __init__( + self, + left: ChildValue, + right: ChildValue, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("**", left, right) @@ -218,7 +219,7 @@ def _binary_evaluate( """See :meth:`pybamm.BinaryOperator._binary_evaluate()`.""" # don't raise RuntimeWarning for NaNs with np.errstate(invalid="ignore"): - return left**right # type:ignore[operator] + return left**right class Addition(BinaryOperator): @@ -226,7 +227,11 @@ class Addition(BinaryOperator): A node in the expression tree representing an addition operator. """ - def __init__(self, left: pybamm.Symbol, right: pybamm.Symbol): + def __init__( + self, + left: ChildValue, + right: ChildValue, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("+", left, right) @@ -252,7 +257,11 @@ class Subtraction(BinaryOperator): A node in the expression tree representing a subtraction operator. """ - def __init__(self, left: pybamm.Symbol, right: pybamm.Symbol): + def __init__( + self, + left: ChildValue, + right: ChildValue, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("-", left, right) @@ -279,7 +288,11 @@ class Multiplication(BinaryOperator): matrix multiplication (e.g. scipy.sparse.coo.coo_matrix) """ - def __init__(self, left, right): + def __init__( + self, + left: ChildValue, + right: ChildValue, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("*", left, right) @@ -316,7 +329,11 @@ class MatrixMultiplication(BinaryOperator): A node in the expression tree representing a matrix multiplication operator. """ - def __init__(self, left, right): + def __init__( + self, + left: ChildValue, + right: ChildValue, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("@", left, right) @@ -363,7 +380,11 @@ class Division(BinaryOperator): A node in the expression tree representing a division operator. """ - def __init__(self, left, right): + def __init__( + self, + left: ChildValue, + right: ChildValue, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("/", left, right) @@ -407,7 +428,11 @@ class Inner(BinaryOperator): by a particular discretisation. """ - def __init__(self, left, right): + def __init__( + self, + left: ChildValue, + right: ChildValue, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("inner product", left, right) @@ -441,8 +466,8 @@ def _binary_evaluate(self, left, right): def _binary_new_copy( self, - left: Union[float, pybamm.Symbol], - right: Union[float, pybamm.Symbol], + left: ChildValue, + right: ChildValue, ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.inner(left, right) @@ -452,9 +477,9 @@ def _evaluates_on_edges(self, dimension: str) -> bool: return False -def inner(left, right): +def inner(left_child, right_child): """Return inner product of two symbols.""" - left, right = _preprocess_binary(left, right) + left, right = _preprocess_binary(left_child, right_child) # simplify multiply by scalar zero, being careful about shape if pybamm.is_scalar_zero(left): return pybamm.zeros_like(right) @@ -480,7 +505,11 @@ class Equality(BinaryOperator): nodes. Returns 1 if the two nodes evaluate to the same thing and 0 otherwise. """ - def __init__(self, left, right): + def __init__( + self, + left: ChildValue, + right: ChildValue, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("==", left, right) @@ -506,8 +535,8 @@ def _binary_evaluate(self, left, right): def _binary_new_copy( self, - left: Union[float, pybamm.Symbol], - right: Union[float, pybamm.Symbol], + left: ChildValue, + right: ChildValue, ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.Equality(left, right) @@ -530,7 +559,12 @@ class _Heaviside(BinaryOperator): DISCONTINUITY event will automatically be added by the solver. """ - def __init__(self, name, left, right): + def __init__( + self, + name: str, + left: ChildValue, + right: ChildValue, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__(name, left, right) @@ -561,7 +595,11 @@ def _evaluate_for_shape(self): class EqualHeaviside(_Heaviside): """A heaviside function with equality (return 1 when left = right)""" - def __init__(self, left, right): + def __init__( + self, + left: ChildValue, + right: ChildValue, + ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("<=", left, right) @@ -579,7 +617,11 @@ def _binary_evaluate(self, left, right): class NotEqualHeaviside(_Heaviside): """A heaviside function without equality (return 0 when left = right)""" - def __init__(self, left, right): + def __init__( + self, + left: ChildValue, + right: ChildValue, + ): super().__init__("<", left, right) def __str__(self): @@ -596,7 +638,11 @@ def _binary_evaluate(self, left, right): class Modulo(BinaryOperator): """Calculates the remainder of an integer division.""" - def __init__(self, left, right): + def __init__( + self, + left: ChildValue, + right: ChildValue, + ): super().__init__("%", left, right) def _diff(self, variable: pybamm.Symbol): @@ -634,7 +680,11 @@ def _binary_evaluate(self, left, right): class Minimum(BinaryOperator): """Returns the smaller of two objects.""" - def __init__(self, left, right): + def __init__( + self, + left: ChildValue, + right: ChildValue, + ): super().__init__("minimum", left, right) def __str__(self): @@ -660,8 +710,8 @@ def _binary_evaluate(self, left, right): def _binary_new_copy( self, - left: Union[float, pybamm.Symbol], - right: Union[float, pybamm.Symbol], + left: ChildValue, + right: ChildValue, ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.minimum(left, right) @@ -674,7 +724,11 @@ def _sympy_operator(self, left, right): class Maximum(BinaryOperator): """Returns the greater of two objects.""" - def __init__(self, left, right): + def __init__( + self, + left: ChildValue, + right: ChildValue, + ): super().__init__("maximum", left, right) def __str__(self): @@ -700,8 +754,8 @@ def _binary_evaluate(self, left, right): def _binary_new_copy( self, - left: Union[float, pybamm.Symbol], - right: Union[float, pybamm.Symbol], + left: ChildValue, + right: ChildValue, ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.maximum(left, right) @@ -712,10 +766,10 @@ def _sympy_operator(self, left, right): def _simplify_elementwise_binary_broadcasts( - left: Union[float, pybamm.Symbol], - right: Union[float, pybamm.Symbol], + left_child: ChildValue, + right_child: ChildValue, ) -> Tuple[pybamm.Symbol, pybamm.Symbol]: - left, right = _preprocess_binary(left, right) + left, right = _preprocess_binary(left_child, right_child) def unpack_broadcast_recursive(symbol: pybamm.Symbol) -> pybamm.Symbol: if isinstance(symbol, pybamm.Broadcast): @@ -743,9 +797,9 @@ def unpack_broadcast_recursive(symbol: pybamm.Symbol) -> pybamm.Symbol: def _simplified_binary_broadcast_concatenation( - left: pybamm.Symbol, # Union[numbers.Number, pybamm.Symbol] - right: pybamm.Symbol, # Union[numbers.Number, pybamm.Symbol] - operator, + left: pybamm.Symbol, + right: pybamm.Symbol, + operator: Callable, ) -> Union[None, pybamm.Broadcast]: """ Check if there are concatenations or broadcasts that we can commute the operator @@ -788,8 +842,8 @@ def _simplified_binary_broadcast_concatenation( def simplified_power( - left: Union[float, pybamm.Symbol], - right: Union[float, pybamm.Symbol], + left: ChildValue, + right: ChildValue, ): left, right = _simplify_elementwise_binary_broadcasts(left, right) @@ -832,10 +886,7 @@ def simplified_power( return pybamm.simplify_if_constant(pybamm.Power(left, right)) -def add( - left: pybamm.Symbol, - right: pybamm.Symbol, -): +def add(left: ChildValue, right: ChildValue): """ Note ---- @@ -924,8 +975,8 @@ def add( def subtract( - left: Union[float, pybamm.Symbol], - right: Union[float, pybamm.Symbol], + left: ChildValue, + right: ChildValue, ): """ Note @@ -1009,8 +1060,8 @@ def subtract( def multiply( - left: Union[float, pybamm.Symbol], - right: Union[float, pybamm.Symbol], + left: ChildValue, + right: ChildValue, ): left, right = _simplify_elementwise_binary_broadcasts(left, right) @@ -1139,8 +1190,8 @@ def multiply( def divide( - left: Union[float, pybamm.Symbol], - right: Union[float, pybamm.Symbol], + left: ChildValue, + right: ChildValue, ): left, right = _simplify_elementwise_binary_broadcasts(left, right) @@ -1151,7 +1202,7 @@ def divide( # Move constant to always be on the left # For a division, this means (var / constant) becomes (1/constant * var) if right.is_constant() and not left.is_constant(): - return (1 / right) * left # type:ignore + return (1 / right) * left # Check for Concatenations and Broadcasts out = _simplified_binary_broadcast_concatenation(left, right, divide) @@ -1213,10 +1264,10 @@ def divide( def matmul( - left: Union[float, pybamm.Symbol], - right: Union[float, pybamm.Symbol], + left_child: ChildValue, + right_child: ChildValue, ): - left, right = _preprocess_binary(left, right) + left, right = _preprocess_binary(left_child, right_child) if pybamm.is_matrix_zero(left) or pybamm.is_matrix_zero(right): return pybamm.zeros_like(MatrixMultiplication(left, right)) @@ -1275,8 +1326,8 @@ def matmul( def minimum( - left: Union[float, pybamm.Symbol], - right: Union[float, pybamm.Symbol], + left: ChildValue, + right: ChildValue, ) -> pybamm.Symbol: """ Returns the smaller of two objects, possibly with a smoothing approximation. @@ -1284,23 +1335,23 @@ def minimum( """ # Check for Concatenations and Broadcasts left, right = _simplify_elementwise_binary_broadcasts(left, right) - out = _simplified_binary_broadcast_concatenation(left, right, minimum) - if out is not None: - return out + concat_out = _simplified_binary_broadcast_concatenation(left, right, minimum) + if concat_out is not None: + return concat_out k = pybamm.settings.min_smoothing # Return exact approximation if that is the setting or the outcome is a constant # (i.e. no need for smoothing) if k == "exact" or (left.is_constant() and right.is_constant()): - out = Minimum(left, right) # type:ignore + out = Minimum(left, right) else: - out = pybamm.softminus(left, right, k) # type:ignore - return pybamm.simplify_if_constant(out) # type:ignore + out = pybamm.softminus(left, right, k) + return pybamm.simplify_if_constant(out) def maximum( - left: Union[float, pybamm.Symbol], - right: Union[float, pybamm.Symbol], + left: ChildValue, + right: ChildValue, ): """ Returns the larger of two objects, possibly with a smoothing approximation. @@ -1308,25 +1359,21 @@ def maximum( """ # Check for Concatenations and Broadcasts left, right = _simplify_elementwise_binary_broadcasts(left, right) - out = _simplified_binary_broadcast_concatenation(left, right, maximum) - if out is not None: - return out + concat_out = _simplified_binary_broadcast_concatenation(left, right, maximum) + if concat_out is not None: + return concat_out k = pybamm.settings.max_smoothing # Return exact approximation if that is the setting or the outcome is a constant # (i.e. no need for smoothing) if k == "exact" or (left.is_constant() and right.is_constant()): - out = Maximum(left, right) # type:ignore + out = Maximum(left, right) else: - out = pybamm.softplus(left, right, k) # type:ignore - return pybamm.simplify_if_constant(out) # type:ignore + out = pybamm.softplus(left, right, k) + return pybamm.simplify_if_constant(out) -def _heaviside( - left: Union[float, pybamm.Symbol], - right: Union[float, pybamm.Symbol], - equal, -): +def _heaviside(left: ChildValue, right: ChildValue, equal): """return a :class:`EqualHeaviside` object, or a smooth approximation.""" # Check for Concatenations and Broadcasts left, right = _simplify_elementwise_binary_broadcasts(left, right) @@ -1357,12 +1404,12 @@ def _heaviside( # (i.e. no need for smoothing) if k == "exact" or (left.is_constant() and right.is_constant()): if equal is True: - out = pybamm.EqualHeaviside(left, right) # type:ignore + out = pybamm.EqualHeaviside(left, right) else: - out = pybamm.NotEqualHeaviside(left, right) # type:ignore + out = pybamm.NotEqualHeaviside(left, right) else: - out = pybamm.sigmoid(left, right, k) # type:ignore - return pybamm.simplify_if_constant(out) # type:ignore + out = pybamm.sigmoid(left, right, k) + return pybamm.simplify_if_constant(out) def softminus( @@ -1374,9 +1421,7 @@ def softminus( Softplus approximation to the minimum function. k is the smoothing parameter, set by `pybamm.settings.min_smoothing`. The recommended value is k=10. """ - return ( - pybamm.log(pybamm.exp(-k * left) + pybamm.exp(-k * right)) / -k # type:ignore - ) + return pybamm.log(pybamm.exp(-k * left) + pybamm.exp(-k * right)) / -k def softplus( @@ -1388,7 +1433,7 @@ def softplus( Softplus approximation to the maximum function. k is the smoothing parameter, set by `pybamm.settings.max_smoothing`. The recommended value is k=10. """ - return pybamm.log(pybamm.exp(k * left) + pybamm.exp(k * right)) / k # type:ignore + return pybamm.log(pybamm.exp(k * left) + pybamm.exp(k * right)) / k def sigmoid( @@ -1402,7 +1447,7 @@ def sigmoid( Note that the concept of deciding which side to pick when left=right does not apply for this smooth approximation. When left=right, the value is (left+right)/2. """ - return (1 + pybamm.tanh(k * (right - left))) / 2 # type:ignore + return (1 + pybamm.tanh(k * (right - left))) / 2 def source( @@ -1440,13 +1485,11 @@ def source( if isinstance(left, numbers.Number): left = pybamm.PrimaryBroadcast(left, "current collector") - if left.domain != ["current collector"] or right.domain != [ # type:ignore - "current collector" - ]: + if left.domain != ["current collector"] or right.domain != ["current collector"]: raise pybamm.DomainError( """'source' only implemented in the 'current collector' domain, but symbols have domains {} and {}""".format( - left.domain, right.domain # type:ignore + left.domain, right.domain ) ) if boundary: diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index 2f148849fd..41bff28d8b 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -5,7 +5,9 @@ import numpy as np from scipy.sparse import csr_matrix -from typing import Sequence, Optional, Union +from typing import Sequence, Optional, Union, Type, SupportsFloat + +NumberType = Type[SupportsFloat] import pybamm @@ -31,7 +33,10 @@ class Broadcast(pybamm.SpatialOperator): """ def __init__( - self, child: pybamm.Symbol, domains: Sequence[str], name: Optional[str] = None + self, + child: pybamm.Symbol, + domains: dict[str, list[str]], + name: Optional[str] = None, ): if name is None: name = "broadcast" @@ -76,7 +81,7 @@ class PrimaryBroadcast(Broadcast): def __init__( self, child: Union[numbers.Number, pybamm.Symbol], - broadcast_domain: Sequence[str], + broadcast_domain: Union[str, list[str]], name: Optional[str] = None, ): # Convert child to scalar if it is a number @@ -91,9 +96,7 @@ def __init__( self.broadcast_type = "primary to nodes" super().__init__(child, domains, name=name) - def check_and_set_domains( - self, child: pybamm.Symbol, broadcast_domain: Sequence[str] - ): + def check_and_set_domains(self, child: pybamm.Symbol, broadcast_domain: list[str]): """See :meth:`Broadcast.check_and_set_domains`""" # Can only do primary broadcast from current collector to electrode, # particle-size or particle or from electrode to particle-size or particle. @@ -171,8 +174,8 @@ class PrimaryBroadcastToEdges(PrimaryBroadcast): def __init__( self, - child: pybamm.Symbol, - broadcast_domain: Sequence[str], + child: Union[numbers.Number, pybamm.Symbol], + broadcast_domain: Union[str, list[str]], name: Optional[str] = None, ): name = name or "broadcast to edges" @@ -208,7 +211,7 @@ class SecondaryBroadcast(Broadcast): def __init__( self, child: pybamm.Symbol, - broadcast_domain: Sequence[str], + broadcast_domain: Union[str, list[str]], name: Optional[str] = None, ): # Convert domain to list if it's a string @@ -220,9 +223,7 @@ def __init__( self.broadcast_type = "secondary to nodes" super().__init__(child, domains, name=name) - def check_and_set_domains( - self, child: pybamm.Symbol, broadcast_domain: Sequence[str] - ): + def check_and_set_domains(self, child: pybamm.Symbol, broadcast_domain: list[str]): """See :meth:`Broadcast.check_and_set_domains`""" if child.domain == []: raise TypeError( @@ -308,7 +309,7 @@ class SecondaryBroadcastToEdges(SecondaryBroadcast): def __init__( self, child: pybamm.Symbol, - broadcast_domain: Sequence[str], + broadcast_domain: Union[list[str], str], name: Optional[str] = None, ): name = name or "broadcast to edges" @@ -344,7 +345,7 @@ class TertiaryBroadcast(Broadcast): def __init__( self, child: pybamm.Symbol, - broadcast_domain: Sequence[str], + broadcast_domain: Union[list[str], str], name: Optional[str] = None, ): # Convert domain to list if it's a string @@ -357,7 +358,7 @@ def __init__( super().__init__(child, domains, name=name) def check_and_set_domains( - self, child: pybamm.Symbol, broadcast_domain: Sequence[str] + self, child: pybamm.Symbol, broadcast_domain: Union[list[str], str] ): """See :meth:`Broadcast.check_and_set_domains`""" if child.domains["secondary"] == []: @@ -429,7 +430,7 @@ class TertiaryBroadcastToEdges(TertiaryBroadcast): def __init__( self, child: pybamm.Symbol, - broadcast_domain: Sequence[str], + broadcast_domain: Union[list[str], str], name: Optional[str] = None, ): name = name or "broadcast to edges" @@ -445,15 +446,15 @@ class FullBroadcast(Broadcast): def __init__( self, - child, - broadcast_domain=None, - auxiliary_domains=None, - broadcast_domains=None, - name=None, + child: Union[NumberType, pybamm.Symbol], + broadcast_domain: Optional[Union[list[str], str]] = None, + auxiliary_domains: Optional[Union[str, dict]] = None, + broadcast_domains: Optional[dict] = None, + name: Optional[str] = None, ): # Convert child to scalar if it is a number if isinstance(child, numbers.Number): - child = pybamm.Scalar(child) + child: pybamm.Scalar = pybamm.Scalar(child) if isinstance(auxiliary_domains, str): auxiliary_domains = {"secondary": auxiliary_domains} @@ -466,7 +467,7 @@ def __init__( self.broadcast_type = "full to nodes" super().__init__(child, domains, name=name) - def check_and_set_domains(self, child, broadcast_domains): + def check_and_set_domains(self, child: pybamm.Symbol, broadcast_domains: dict): """See :meth:`Broadcast.check_and_set_domains`""" if broadcast_domains["primary"] == []: raise pybamm.DomainError( @@ -517,11 +518,11 @@ class FullBroadcastToEdges(FullBroadcast): def __init__( self, - child, - broadcast_domain=None, - auxiliary_domains=None, - broadcast_domains=None, - name=None, + child: Union[NumberType, pybamm.Symbol], + broadcast_domain: Optional[Union[list[str], str]] = None, + auxiliary_domains: Optional[Union[str, dict]] = None, + broadcast_domains: Optional[dict] = None, + name: Optional[str] = None, ): name = name or "broadcast to edges" super().__init__( @@ -548,7 +549,7 @@ def reduce_one_dimension(self): ) -def full_like(symbols: Sequence[pybamm.Symbol], fill_value: float) -> pybamm.Symbol: +def full_like(symbols: tuple[pybamm.Symbol, ...], fill_value: float) -> pybamm.Symbol: """ Returns an array with the same shape and domains as the sum of the input symbols, with a constant value given by `fill_value`. @@ -573,9 +574,9 @@ def full_like(symbols: Sequence[pybamm.Symbol], fill_value: float) -> pybamm.Sym shape = sum_symbol.shape # use vector or matrix if shape[1] == 1: - array_type = pybamm.Vector + array_type: Type[pybamm.Vector] = pybamm.Vector else: - array_type = pybamm.Matrix # type:ignore[assignment] + array_type: Type[pybamm.Matrix] = pybamm.Matrix # type:ignore[no-redef] # return dense array, except for a matrix of zeros if shape[1] != 1 and fill_value == 0: entries = csr_matrix(shape) diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 6eb1ebf868..32b1e64a56 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -8,10 +8,14 @@ import numpy as np import sympy from scipy.sparse import issparse, vstack -from typing import Optional, Iterable, Sequence, TYPE_CHECKING +from typing import Optional, Sequence, Type, Union, TYPE_CHECKING +from typing_extensions import TypeGuard, TypeVar import pybamm +if TYPE_CHECKING: + S = TypeVar("S", bound=pybamm.Symbol) # type: ignore[no-redef] + class Concatenation(pybamm.Symbol): """ @@ -44,12 +48,12 @@ def __init__( if name is None: name = "concatenation" if check_domain: - domains = self.get_children_domains(children) # type:ignore[arg-type] + domains = self.get_children_domains(children) else: domains = {"primary": []} self.concatenation_function = concat_fun - super().__init__(name, children, domains=domains) # type:ignore[arg-type] + super().__init__(name, children, domains=domains) def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" @@ -98,7 +102,7 @@ def get_children_domains(self, children: Sequence[pybamm.Symbol]): return domains - def _concatenation_evaluate(self, children_eval): + def _concatenation_evaluate(self, children_eval: list[np.ndarray]): """See :meth:`Concatenation._concatenation_evaluate()`.""" if len(children_eval) == 0: return np.array([]) @@ -110,7 +114,7 @@ def evaluate( t: Optional[float] = None, y: Optional[np.ndarray] = None, y_dot: Optional[np.ndarray] = None, - inputs: Optional[dict] = None, + inputs: Optional[Union[dict, str]] = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" children = self.children @@ -183,12 +187,12 @@ class NumpyConcatenation(Concatenation): """ def __init__(self, *children: Sequence[pybamm.Symbol]): - children = list(children) # type:ignore[assignment] + children = list(children) # Turn objects that evaluate to scalars to objects that evaluate to vectors, # so that we can concatenate them for i, child in enumerate(children): - if child.evaluates_to_number(): # type:ignore - children[i] = child * pybamm.Vector([1]) # type:ignore + if child.evaluates_to_number(): + children[i] = child * pybamm.Vector([1]) super().__init__( *children, name="numpy_concatenation", @@ -243,9 +247,7 @@ def __init__( children = list(children) # Allow the base class to sort the domains into the correct order - super().__init__( - *children, name="domain_concatenation" # type:ignore[arg-type] - ) # type:ignore[arg-type] + super().__init__(*children, name="domain_concatenation") if copy_this is None: # store mesh @@ -271,7 +273,7 @@ def __init__( self._children_slices = copy.copy(copy_this._children_slices) self.secondary_dimensions_npts = copy_this.secondary_dimensions_npts - def _get_auxiliary_domain_repeats(self, auxiliary_domains): + def _get_auxiliary_domain_repeats(self, auxiliary_domains: dict) -> int: """Helper method to read the 'auxiliary_domain' meshes.""" mesh_pts = 1 for level, dom in auxiliary_domains.items(): @@ -283,7 +285,7 @@ def _get_auxiliary_domain_repeats(self, auxiliary_domains): def full_mesh(self): return self._full_mesh - def create_slices(self, node): + def create_slices(self, node: pybamm.Symbol) -> defaultdict: slices = defaultdict(list) start = 0 end = 0 @@ -300,7 +302,7 @@ def create_slices(self, node): start = end return slices - def _concatenation_evaluate(self, children_eval): + def _concatenation_evaluate(self, children_eval: list[np.ndarray]): """See :meth:`Concatenation._concatenation_evaluate()`.""" # preallocate vector vector = np.empty((self._size, 1)) @@ -329,7 +331,7 @@ def _concatenation_jac(self, children_jacs): jacs.append(pybamm.Index(child_jac, child_slice[i])) return SparseStack(*jacs) - def _concatenation_new_copy(self, children): + def _concatenation_new_copy(self, children: list[pybamm.Symbol]): """See :meth:`pybamm.Symbol.new_copy()`.""" new_symbol = simplified_domain_concatenation( children, self.full_mesh, copy_this=self @@ -482,7 +484,11 @@ def numpy_concatenation(*children): return simplified_numpy_concatenation(*children) -def simplified_domain_concatenation(children, mesh, copy_this=None): +def simplified_domain_concatenation( + children: list[pybamm.Symbol], + mesh: pybamm.Mesh, + copy_this: Optional[DomainConcatenation] = None, +): """Perform simplifications on a domain concatenation.""" # Create the DomainConcatenation to read domain and child domain concat = DomainConcatenation(children, mesh, copy_this=copy_this) @@ -510,7 +516,14 @@ def simplified_domain_concatenation(children, mesh, copy_this=None): return pybamm.simplify_if_constant(concat) -def domain_concatenation(children, mesh): +def domain_concatenation(children: list[pybamm.Symbol], mesh: pybamm.Mesh): """Helper function to create domain concatenations.""" # TODO: add option to turn off simplifications return simplified_domain_concatenation(children, mesh) + + +def all_children_are( + children: list[pybamm.Symbol], + class_type: Type[S], +) -> TypeGuard[list[S]]: + return all(isinstance(child, class_type) for child in children) diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index a40e0bf829..aec35fd736 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -4,11 +4,12 @@ from __future__ import annotations import numbers -import autograd # type: ignore +import autograd import numpy as np import sympy from scipy import special -from typing import Optional, Sequence +from typing import Optional, Sequence, Callable, Type, Union +from typing_extensions import TypeVar import pybamm @@ -34,11 +35,11 @@ class Function(pybamm.Symbol): def __init__( self, - function, + function: Callable, *children: pybamm.Symbol, name: Optional[str] = None, - derivative="autograd", - differentiated_function=None, + derivative: Optional[str] = "autograd", + differentiated_function: Optional[Callable] = None, ): # Turn numbers into scalars children = list(children) @@ -87,9 +88,11 @@ def diff(self, variable: pybamm.Symbol): # remove None entries partial_derivatives = [x for x in partial_derivatives if x is not None] - derivative = sum(partial_derivatives) # type:ignore[arg-type] + derivative: pybamm.Symbol = sum( + partial_derivatives + ) # type:ignore[assignment] if derivative == 0: - derivative = pybamm.Scalar(0) # type:ignore[assignment] + derivative = pybamm.Scalar(0) return derivative @@ -106,7 +109,7 @@ def _function_diff(self, children: Sequence[pybamm.Symbol], idx: float): differentiated_function=self.function, ) elif self.derivative == "derivative": - if len(children) > 1: # type:ignore[arg-type] + if len(children) > 1: raise ValueError( """ differentiation using '.derivative()' not implemented for functions @@ -148,7 +151,7 @@ def evaluate( t: Optional[float] = None, y: Optional[np.ndarray] = None, y_dot: Optional[np.ndarray] = None, - inputs: Optional[dict] = None, + inputs: Optional[Union[dict, str]] = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" evaluated_children = [ @@ -194,7 +197,7 @@ def _function_new_copy(self, children: list) -> Function: : :pybamm.Function A new copy of the function """ - return pybamm.simplify_if_constant( # type:ignore[return-value] + return pybamm.simplify_if_constant( pybamm.Function( self.function, *children, @@ -220,7 +223,7 @@ def to_equation(self): return self._sympy_operator(*eq_list) -def simplified_function(func_class, child): +def simplified_function(func_class: Type[SF], child: pybamm.Symbol): """ Simplifications implemented before applying the function. Currently only implemented for one-child functions. @@ -249,7 +252,7 @@ class SpecificFunction(Function): The child to apply the function to """ - def __init__(self, function, child: pybamm.Symbol): + def __init__(self, function: Callable, child: pybamm.Symbol): super().__init__(function, child) def _function_new_copy(self, children): @@ -263,6 +266,9 @@ def _sympy_operator(self, child): return sympy_function(child) +SF = TypeVar("SF", bound=SpecificFunction) + + class Arcsinh(SpecificFunction): """Arcsinh function.""" @@ -278,7 +284,7 @@ def _sympy_operator(self, child): return sympy.asinh(child) -def arcsinh(child): +def arcsinh(child: pybamm.Symbol): """Returns arcsinh function of child.""" return simplified_function(Arcsinh, child) @@ -298,7 +304,7 @@ def _sympy_operator(self, child): return sympy.atan(child) -def arctan(child): +def arctan(child: pybamm.Symbol): """Returns hyperbolic tan function of child.""" return simplified_function(Arctan, child) @@ -314,7 +320,7 @@ def _function_diff(self, children, idx): return -sin(children[0]) -def cos(child): +def cos(child: pybamm.Symbol): """Returns cosine function of child.""" return simplified_function(Cos, child) @@ -330,7 +336,7 @@ def _function_diff(self, children, idx): return sinh(children[0]) -def cosh(child): +def cosh(child: pybamm.Symbol): """Returns hyperbolic cosine function of child.""" return simplified_function(Cosh, child) @@ -346,12 +352,12 @@ def _function_diff(self, children, idx): return 2 / np.sqrt(np.pi) * exp(-children[0] ** 2) -def erf(child): +def erf(child: pybamm.Symbol): """Returns error function of child.""" return simplified_function(Erf, child) -def erfc(child): +def erfc(child: pybamm.Symbol): """Returns complementary error function of child.""" return 1 - simplified_function(Erf, child) @@ -367,7 +373,7 @@ def _function_diff(self, children, idx): return exp(children[0]) -def exp(child): +def exp(child: pybamm.Symbol): """Returns exponential function of child.""" return simplified_function(Exp, child) @@ -397,7 +403,7 @@ def log(child, base="e"): return log_child / np.log(base) -def log10(child): +def log10(child: pybamm.Symbol): """Returns logarithmic function of child, with base 10.""" return log(child, base=10) @@ -414,7 +420,7 @@ def _evaluate_for_shape(self): return np.nan * np.ones((1, 1)) -def max(child): +def max(child: pybamm.Symbol): """ Returns max function of child. Not to be confused with :meth:`pybamm.maximum`, which returns the larger of two objects. @@ -434,7 +440,7 @@ def _evaluate_for_shape(self): return np.nan * np.ones((1, 1)) -def min(child): +def min(child: pybamm.Symbol): """ Returns min function of child. Not to be confused with :meth:`pybamm.minimum`, which returns the smaller of two objects. @@ -442,7 +448,7 @@ def min(child): return pybamm.simplify_if_constant(Min(child)) -def sech(child): +def sech(child: pybamm.Symbol): """Returns hyperbolic sec function of child.""" return 1 / simplified_function(Cosh, child) @@ -458,7 +464,7 @@ def _function_diff(self, children, idx): return cos(children[0]) -def sin(child): +def sin(child: pybamm.Symbol): """Returns sine function of child.""" return simplified_function(Sin, child) @@ -474,7 +480,7 @@ def _function_diff(self, children, idx): return cosh(children[0]) -def sinh(child): +def sinh(child: pybamm.Symbol): """Returns hyperbolic sine function of child.""" return simplified_function(Sinh, child) @@ -495,7 +501,7 @@ def _function_diff(self, children, idx): return 1 / (2 * sqrt(children[0])) -def sqrt(child): +def sqrt(child: pybamm.Symbol): """Returns square root function of child.""" return simplified_function(Sqrt, child) @@ -511,6 +517,6 @@ def _function_diff(self, children, idx): return sech(children[0]) ** 2 -def tanh(child): +def tanh(child: pybamm.Symbol): """Returns hyperbolic tan function of child.""" return simplified_function(Tanh, child) diff --git a/pybamm/expression_tree/independent_variable.py b/pybamm/expression_tree/independent_variable.py index 510e19a2b2..f2a38617f5 100644 --- a/pybamm/expression_tree/independent_variable.py +++ b/pybamm/expression_tree/independent_variable.py @@ -65,7 +65,6 @@ class Time(IndependentVariable): def __init__(self): super().__init__("time") - # making this super(pybamm.Symbol, self)__init__(name="time") works, but not sure why. def create_copy(self): """See :meth:`pybamm.Symbol.new_copy()`.""" @@ -74,9 +73,9 @@ def create_copy(self): def _base_evaluate( self, t: Optional[float] = None, - y: Optional[np.ndarray] = None, - y_dot: Optional[np.ndarray] = None, - inputs: Optional[dict] = None, + y: Any = None, + y_dot: Any = None, + inputs: Any = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" if t is None: @@ -115,8 +114,6 @@ class SpatialVariable(IndependentVariable): deprecated. """ - # coord_sys: Optional[Any] - def __init__( self, name: str, @@ -130,28 +127,25 @@ def __init__( name, domain=domain, auxiliary_domains=auxiliary_domains, domains=domains ) domain = self.domain - # using a dataclass, at this point the domain doesn't get set for some reason, during initialisation. if domain == []: raise ValueError("domain must be provided") # Check symbol name vs domain name - if name == "r_n" and not all( - n in domain[0] for n in ["negative", "particle"] # type:ignore[index] - ): + if name == "r_n" and not all(n in domain[0] for n in ["negative", "particle"]): # catches "negative particle", "negative secondary particle", etc raise pybamm.DomainError( "domain must be negative particle if name is 'r_n'" ) elif name == "r_p" and not all( - n in domain[0] for n in ["positive", "particle"] # type:ignore[index] + n in domain[0] for n in ["positive", "particle"] ): # catches "positive particle", "positive secondary particle", etc raise pybamm.DomainError( "domain must be positive particle if name is 'r_p'" ) elif name in ["x", "y", "z", "x_n", "x_s", "x_p"] and any( - ["particle" in dom for dom in domain] # type:ignore[index, union-attr] + ["particle" in dom for dom in domain] ): raise pybamm.DomainError( "domain cannot be particle if name is '{}'".format(name) diff --git a/pybamm/expression_tree/input_parameter.py b/pybamm/expression_tree/input_parameter.py index f1d2795fdc..536c818ca7 100644 --- a/pybamm/expression_tree/input_parameter.py +++ b/pybamm/expression_tree/input_parameter.py @@ -31,7 +31,7 @@ class InputParameter(pybamm.Symbol): def __init__( self, name: str, - domain: Optional[Union[Sequence[str], str]] = None, + domain: Optional[Union[list[str], str]] = None, expected_size: Optional[int] = None, ) -> None: # Expected size defaults to 1 if no domain else None (gets set later) @@ -78,7 +78,7 @@ def _base_evaluate( t: Optional[float] = None, y: Optional[np.ndarray] = None, y_dot: Optional[np.ndarray] = None, - inputs: Optional[dict] = None, + inputs: Optional[Union[dict, str]] = None, ): # inputs should be a dictionary # convert 'None' to empty dictionary for more informative error diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index be14b888b8..3182100e6e 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -103,7 +103,7 @@ def __init__( x1 = x[0] else: x1 = x - x = [x] # type:ignore[list-item] + x = [x] x2 = None if x1.shape[0] != y.shape[0]: raise ValueError( @@ -115,7 +115,7 @@ def __init__( children = [children] # Either a single x is provided and there is one child # or x is a 2-tuple and there are two children - if len(x) != len(children): # type:ignore[arg-type] + if len(x) != len(children): raise ValueError("len(x) should equal len(children)") # if there is only one x, y can be 2-dimensional but the child must have # length 1 @@ -131,7 +131,7 @@ def __init__( if extrapolate is False: fill_value = np.nan elif extrapolate is True: - fill_value = "extrapolate" # type:ignore[assignment] + fill_value = "extrapolate" # ignore:assignment interpolating_function = interpolate.interp1d( x1, y.T, @@ -181,7 +181,7 @@ def __init__( ) else: interpolating_function = interpolate.RegularGridInterpolator( - (x1, x2, x3), # type:ignore[has-type] + (x1, x2, x3), y, method=interpolator, bounds_error=False, diff --git a/pybamm/expression_tree/matrix.py b/pybamm/expression_tree/matrix.py index 68225dd9a8..ab9c28c392 100644 --- a/pybamm/expression_tree/matrix.py +++ b/pybamm/expression_tree/matrix.py @@ -3,7 +3,7 @@ # import numpy as np from scipy.sparse import csr_matrix, issparse -from typing import Union, Optional +from typing import Union, Optional, Type import pybamm @@ -15,7 +15,7 @@ class Matrix(pybamm.Array): def __init__( self, - entries: Union[np.ndarray, list], + entries: Union[np.ndarray, list, csr_matrix], name: Optional[str] = None, domain: Optional[list[str]] = None, auxiliary_domains: Optional[dict[str, str]] = None, diff --git a/pybamm/expression_tree/operations/jacobian.py b/pybamm/expression_tree/operations/jacobian.py index 6eeca38838..57d603be94 100644 --- a/pybamm/expression_tree/operations/jacobian.py +++ b/pybamm/expression_tree/operations/jacobian.py @@ -76,17 +76,12 @@ def _jac(self, symbol: pybamm.Symbol, variable: pybamm.Symbol): elif isinstance(symbol, pybamm.Function): children_jacs = [None] * len(symbol.children) for i, child in enumerate(symbol.children): - children_jacs[i] = self.jac( # type:ignore[call-overload] - child, variable - ) + children_jacs[i] = self.jac(child, variable) # _function_jac defined in function class jac = symbol._function_jac(children_jacs) elif isinstance(symbol, pybamm.Concatenation): - children_jacs = [ - self.jac(child, variable) # type:ignore[misc] - for child in symbol.children - ] + children_jacs = [self.jac(child, variable) for child in symbol.children] if len(children_jacs) == 1: jac = children_jacs[0] else: diff --git a/pybamm/expression_tree/operations/unpack_symbols.py b/pybamm/expression_tree/operations/unpack_symbols.py index 5ebfbf4ef7..73b23e1ed6 100644 --- a/pybamm/expression_tree/operations/unpack_symbols.py +++ b/pybamm/expression_tree/operations/unpack_symbols.py @@ -27,7 +27,7 @@ def __init__( unpacked_symbols: Optional[dict] = None, ): self.classes_to_find = classes_to_find - self._unpacked_symbols: dict = unpacked_symbols or {} # type:ignore[assignment] + self._unpacked_symbols: dict = unpacked_symbols or {} def unpack_list_of_symbols( self, list_of_symbols: Sequence[pybamm.Symbol] diff --git a/pybamm/expression_tree/parameter.py b/pybamm/expression_tree/parameter.py index e4e5914263..3ab8bdd252 100644 --- a/pybamm/expression_tree/parameter.py +++ b/pybamm/expression_tree/parameter.py @@ -101,7 +101,7 @@ def __init__( domains = self.get_children_domains(children_list) super().__init__(name, children=children_list, domains=domains) - self.input_names = list(inputs.keys()) # type:ignore[misc] + self.input_names = list(inputs.keys()) # Use the inspect module to find the function's "short name" from the # Parameters module that called it @@ -109,12 +109,12 @@ def __init__( self.print_name = print_name else: frame = sys._getframe().f_back - print_name = frame.f_code.co_name # type:ignore[union-attr] + print_name = frame.f_code.co_name if print_name.startswith("_"): self.print_name = None else: try: - parent_param = frame.f_locals["self"] # type:ignore[union-attr] + parent_param = frame.f_locals["self"] except KeyError: parent_param = None if hasattr(parent_param, "domain") and parent_param.domain is not None: @@ -124,16 +124,16 @@ def __init__( print_name += f"_{d}" self.print_name = print_name - @property - def input_names(self): - return self._input_names - def print_input_names(self): if self._input_names: for inp in self._input_names: print(inp) - @input_names.setter # type:ignore[no-redef, attr-defined] + @property + def input_names(self): + return self._input_names + + @input_names.setter def input_names(self, inp=None): if inp: if inp.__class__ is list: diff --git a/pybamm/expression_tree/scalar.py b/pybamm/expression_tree/scalar.py index efd726adbb..4b351f5f08 100644 --- a/pybamm/expression_tree/scalar.py +++ b/pybamm/expression_tree/scalar.py @@ -5,7 +5,7 @@ import numbers import numpy as np import sympy -from typing import Optional, Literal, Union +from typing import Optional, Literal, Union, Any import pybamm @@ -26,7 +26,7 @@ class Scalar(pybamm.Symbol): def __init__( self, - value: Union[float, numbers.Number], + value: Union[float, numbers.Number, np.bool_], name: Optional[str] = None, ) -> None: # set default name if not provided @@ -61,10 +61,10 @@ def set_id(self): def _base_evaluate( self, - t: Optional[float] = None, - y: Optional[np.ndarray] = None, - y_dot: Optional[np.ndarray] = None, - inputs: Optional[dict] = None, + t: Any = None, + y: Any = None, + y_dot: Any = None, + inputs: Any = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" return self._value diff --git a/pybamm/expression_tree/state_vector.py b/pybamm/expression_tree/state_vector.py index 11eaa104c2..841b54a4ed 100644 --- a/pybamm/expression_tree/state_vector.py +++ b/pybamm/expression_tree/state_vector.py @@ -4,7 +4,7 @@ from __future__ import annotations import numpy as np from scipy.sparse import csr_matrix, vstack -from typing import Optional, Iterable, Union, Sequence +from typing import Optional, Union, Any import pybamm @@ -39,7 +39,7 @@ def __init__( *y_slices: slice, base_name="y", name: Optional[str] = None, - domain: Optional[Sequence[str]] = None, + domain: Optional[Union[list[str], str]] = None, auxiliary_domains: Optional[dict] = None, domains: Optional[dict[str, list[str]]] = None, evaluation_array: Optional[list] = None, @@ -226,7 +226,7 @@ def __init__( self, *y_slices: slice, name: Optional[str] = None, - domain: Optional[Sequence[str]] = None, + domain: Optional[Union[list[str], str]] = None, auxiliary_domains: Optional[dict] = None, domains: Optional[dict[str, list[str]]] = None, evaluation_array: Optional[list] = None, @@ -243,10 +243,10 @@ def __init__( def _base_evaluate( self, - t: Optional[float] = None, + t: Any = None, y: Optional[np.ndarray] = None, - y_dot: Optional[np.ndarray] = None, - inputs: Optional[dict] = None, + y_dot: Any = None, + inputs: Any = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" if y is None: @@ -310,7 +310,7 @@ def __init__( self, *y_slices: slice, name: Optional[str] = None, - domain: Optional[Sequence[str]] = None, + domain: Optional[Union[list[str], str]] = None, auxiliary_domains: Optional[dict] = None, domains: Optional[dict[str, list[str]]] = None, evaluation_array: Optional[list] = None, @@ -327,10 +327,10 @@ def __init__( def _base_evaluate( self, - t: Optional[float] = None, - y: Optional[np.ndarray] = None, + t: Any = None, + y: Any = None, y_dot: Optional[np.ndarray] = None, - inputs: Optional[dict] = None, + inputs: Any = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" if y_dot is None: diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 80f0f5f4ed..fc317a2cce 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -3,7 +3,6 @@ # from __future__ import annotations import numbers -import inspect import anytree import numpy as np @@ -17,12 +16,9 @@ Optional, Iterable, Sequence, - TypeVar, no_type_check, ) - -# from pydantic import BaseModel -# from pydantic.dataclasses import dataclass +from typing_extensions import TypeVar import pybamm from pybamm.expression_tree.printing.print_name import prettify_print_name @@ -42,27 +38,6 @@ EMPTY_DOMAINS: dict[str, list] = {k: [] for k in DOMAIN_LEVELS} -# class PatchedModel(BaseModel): -# @no_type_check -# def __setattr__(self, name, value): -# """ -# To be able to use properties with setters -# """ -# try: -# super().__setattr__(name, value) -# except ValueError as e: -# setters = inspect.getmembers( -# self.__class__, -# predicate=lambda x: isinstance(x, property) and x.fset is not None, -# ) -# for setter_name, func in setters: -# if setter_name == name: -# object.__setattr__(self, name, value) -# break -# else: -# raise e - - def domain_size(domain: Union[list[str], str]): """ Get the domain size. @@ -147,7 +122,7 @@ def is_scalar_minus_one(expr: Symbol): return is_scalar_x(expr, -1) -def is_matrix_x(expr: Symbol, x): +def is_matrix_x(expr: Symbol, x: int): """ Utility function to test if an expression evaluates to a constant matrix value """ @@ -191,9 +166,7 @@ def is_matrix_minus_one(expr: Symbol): return is_matrix_x(expr, -1) -def simplify_if_constant( - symbol: S, -) -> S: +def simplify_if_constant(symbol: Union[S, float]) -> S: """ Utility function to simplify an expression tree if it evalutes to a constant scalar, vector or matrix @@ -206,25 +179,20 @@ def simplify_if_constant( or (isinstance(result, np.ndarray) and result.ndim == 0) or isinstance(result, np.bool_) ): - return pybamm.Scalar(result) # type:ignore[return-value, arg-type] + return pybamm.Scalar(result) elif isinstance(result, np.ndarray) or issparse(result): if result.ndim == 1 or result.shape[1] == 1: - return pybamm.Vector( # type:ignore[return-value] - result, domains=symbol.domains - ) + return pybamm.Vector(result, domains=symbol.domains) else: # Turn matrix of zeros into sparse matrix if isinstance(result, np.ndarray) and np.all(result == 0): result = csr_matrix(result) - return pybamm.Matrix( # type:ignore[return-value] - result, domains=symbol.domains - ) + return pybamm.Matrix(result, domains=symbol.domains) return symbol -# @dataclass -class Symbol: # PatchedModel +class Symbol: """ Base node class for the expression tree. @@ -253,23 +221,14 @@ class Symbol: # PatchedModel deprecated. """ - # name: str - def __init__( self, name: str, children: Optional[Sequence[Symbol]] = None, - domain: Optional[Union[Sequence[str], str]] = None, + domain: Optional[Union[list[str], str]] = None, auxiliary_domains: Optional[dict[str, str]] = None, - domains: Optional[dict] = None, + domains: Optional[dict[str, list[str]]] = None, ): - # super().__init__( - # name=name, - # children=children, - # domain=domain, - # auxiliary_domains=auxiliary_domains, - # domains=domains, - # ) super(Symbol, self).__init__() self.name = name @@ -277,11 +236,11 @@ def __init__( children = [] self._children = children - # Keep a separate "oprhans" attribute for backwards compatibility + # Keep a separate "orphans" attribute for backwards compatibility self._orphans = children # Set domains (and hence id) - self.domains = self.read_domain_or_domains(domain, auxiliary_domains, domains) # type: ignore[misc] + self.domains = self.read_domain_or_domains(domain, auxiliary_domains, domains) self._saved_evaluates_on_edges: dict = {} self._print_name = None @@ -296,15 +255,6 @@ def __init__( ): self.test_shape() - # # super().__init__(name, children, domain, auxiliary_domains, domains) - - # class Config: - # arbitrary_types_allowed = True - # # underscore_attrs_are_private = True - # keep_untouched = (cached_property,) - # fields = {"domain": {"exclude": True}, "auxiliary_domains": {"exclude": True}} - # # json_encoders = {"Symbol": lambda u: u.__dict__} - @property def children(self): """ @@ -329,31 +279,7 @@ def name(self, value: str): def domains(self): return self._domains - @property - def domain(self): - """ - list of applicable domains. - - Returns - ------- - iterable of str - """ - return self._domains["primary"] - - @domain.setter - def domain(self, domain: Union[list[str], str]): - raise NotImplementedError( - "Cannot set domain directly, use domains={'primary': domain} instead" - ) - - @property - def auxiliary_domains(self): - """Returns auxiliary domains.""" - raise NotImplementedError( - "symbol.auxiliary_domains has been deprecated, use symbol.domains instead" - ) - - @domains.setter # type:ignore[no-redef, attr-defined] + @domains.setter def domains(self, domains): try: if ( @@ -398,6 +324,30 @@ def domains(self, domains): self._domains = domains self.set_id() + @property + def domain(self): + """ + list of applicable domains. + + Returns + ------- + iterable of str + """ + return self._domains["primary"] + + @domain.setter + def domain(self, domain): + raise NotImplementedError( + "Cannot set domain directly, use domains={'primary': domain} instead" + ) + + @property + def auxiliary_domains(self): + """Returns auxiliary domains.""" + raise NotImplementedError( + "symbol.auxiliary_domains has been deprecated, use symbol.domains instead" + ) + @property def secondary_domain(self): """Helper function to get the secondary domain of a symbol.""" @@ -448,7 +398,7 @@ def get_children_domains(self, children: Sequence[Symbol]): def read_domain_or_domains( self, - domain: Optional[Union[Sequence[str], str]], + domain: Optional[Union[list[str], str]], auxiliary_domains: Optional[dict[str, str]], domains: Optional[dict], ): @@ -494,9 +444,9 @@ def scale(self): def reference(self): return self._reference - def __eq__(self, other): + def __eq__(self, other: Union[Symbol, float]): try: - return self._id == other._id + return self._id == other._id # type:ignore except AttributeError: if isinstance(other, numbers.Number): return self._id == pybamm.Scalar(other)._id @@ -550,7 +500,7 @@ def visualise(self, filename: str): # raise error but only through logger so that test passes pybamm.logger.error("Please install graphviz>=2.42.2 to use dot exporter") - def relabel_tree(self, symbol, counter): + def relabel_tree(self, symbol: Symbol, counter: int): """ Finds all children of a symbol and assigns them a new id so that they can be visualised properly using the graphviz output @@ -638,11 +588,11 @@ def __rmul__(self, other: Symbol) -> Multiplication: """return a :class:`Multiplication` object.""" return pybamm.multiply(other, self) - def __matmul__(self, other: Symbol): + def __matmul__(self, other: Symbol) -> pybamm.MatrixMultiplication: """return a :class:`MatrixMultiplication` object.""" return pybamm.matmul(self, other) - def __rmatmul__(self, other: Symbol): + def __rmatmul__(self, other: Symbol) -> pybamm.MatrixMultiplication: """return a :class:`MatrixMultiplication` object.""" return pybamm.matmul(other, self) @@ -678,7 +628,7 @@ def __ge__(self, other: Symbol) -> pybamm.EqualHeaviside: """return a :class:`EqualHeaviside` object, or a smooth approximation.""" return pybamm.expression_tree.binary_operators._heaviside(other, self, True) - def __neg__(self) -> pybamm.Symbol: + def __neg__(self) -> pybamm.Negate: """return a :class:`Negate` object.""" if isinstance(self, pybamm.Negate): # Double negative is a positive @@ -791,7 +741,7 @@ def _base_evaluate( t: Optional[float] = None, y: Optional[np.ndarray] = None, y_dot: Optional[np.ndarray] = None, - inputs: Optional[dict] = None, + inputs: Optional[Union[dict, str]] = None, ): """ evaluate expression tree. @@ -823,8 +773,8 @@ def evaluate( t: Optional[float] = None, y: Optional[np.ndarray] = None, y_dot: Optional[np.ndarray] = None, - inputs: Optional[dict] = None, - ): + inputs: Optional[Union[dict, str]] = None, + ) -> Union[float, np.ndarray]: """Evaluate expression tree (wrapper to allow using dict of known values). Parameters @@ -876,7 +826,7 @@ def is_constant(self): # Default behaviour is False return False - def evaluate_ignoring_errors(self, t=0): + def evaluate_ignoring_errors(self, t: Union[None, float] = 0): # none """ Evaluates the expression. If a node exists in the tree that cannot be evaluated as a scalar or vector (e.g. Time, Parameter, Variable, StateVector), then None @@ -955,7 +905,7 @@ def _evaluates_on_edges(self, dimension): # Default behaviour: return False return False - def has_symbol_of_classes(self, symbol_classes): + def has_symbol_of_classes(self, symbol_classes: Union[Symbol, tuple[type[Symbol]]]): """ Returns True if equation has a term of the class(es) `symbol_class`. diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index c441645f10..cf6ad3f767 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -28,7 +28,7 @@ class UnaryOperator(pybamm.Symbol): child node """ - def __init__(self, name: str, child: pybamm.Symbol, domains=None): + def __init__(self, name: str, child: pybamm.Symbol, domains: Optional[dict] = None): if isinstance(child, numbers.Number): child = pybamm.Scalar(child) domains = domains or child.domains @@ -344,7 +344,7 @@ class with a :class:`Matrix` child node """ - def __init__(self, name: str, child: pybamm.Symbol, domains=None): + def __init__(self, name: str, child: pybamm.Symbol, domains: Optional[dict] = None): super().__init__(name, child, domains) diff --git a/pybamm/expression_tree/variable.py b/pybamm/expression_tree/variable.py index ba373f8ec8..34fbdbec60 100644 --- a/pybamm/expression_tree/variable.py +++ b/pybamm/expression_tree/variable.py @@ -6,7 +6,7 @@ import sympy import numbers import pybamm -from typing import Iterable, Union, Optional, Sequence +from typing import Union, Optional class VariableBase(pybamm.Symbol): @@ -51,10 +51,10 @@ class VariableBase(pybamm.Symbol): def __init__( self, name: str, - domain: Optional[Union[Sequence[str], str]] = None, + domain: Optional[Union[list[str], str]] = None, auxiliary_domains: Optional[dict] = None, domains: Optional[dict] = None, - bounds: Optional[tuple] = None, + bounds: Optional[tuple[pybamm.Symbol]] = None, print_name: Optional[str] = None, scale: Optional[Union[float, pybamm.Symbol]] = 1, reference: Optional[Union[float, pybamm.Symbol]] = 0, @@ -83,13 +83,13 @@ def bounds(self): return self._bounds @bounds.setter - def bounds(self, values: Sequence[numbers.Number]): + def bounds(self, values: tuple[numbers.Number, numbers.Number]): if values is None: values = (-np.inf, np.inf) else: if ( all(isinstance(b, numbers.Number) for b in values) - and values[0] >= values[1] # type:ignore + and values[0] >= values[1] ): raise ValueError( f"Invalid bounds {values}. " @@ -99,7 +99,7 @@ def bounds(self, values: Sequence[numbers.Number]): values = list(values) for idx, bound in enumerate(values): if isinstance(bound, numbers.Number): - values[idx] = pybamm.Scalar(bound) # type:ignore[call-overload] + values[idx] = pybamm.Scalar(bound) self._bounds = tuple(values) def set_id(self): @@ -234,9 +234,7 @@ def get_variable(self) -> pybamm.Variable: """ return Variable(self.name[:-1], domains=self.domains, scale=self.scale) - def diff( - self, variable: pybamm.VariableDot # type:ignore[override] - ) -> pybamm.Scalar: + def diff(self, variable: pybamm.Symbol) -> pybamm.Scalar: if variable == self: return pybamm.Scalar(1) elif variable == pybamm.t: From f2faa8a4683201cd9e2c09d8e2901d1e57e94b04 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Dec 2023 10:43:15 +0000 Subject: [PATCH 07/32] style: pre-commit fixes --- mypy.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy.ini b/mypy.ini index cb47be594a..13e0cad69c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -31,4 +31,4 @@ ignore_missing_imports=True ignore_missing_imports=True [mypy-absl.*] -ignore_missing_imports=True \ No newline at end of file +ignore_missing_imports=True From da3b072581a21ee96b2ecfb98a17ed92a9feab7b Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 1 Dec 2023 18:08:48 +0000 Subject: [PATCH 08/32] edit _from_json to fix subtype incompatibility mypy errors --- pybamm/expression_tree/array.py | 6 +- pybamm/expression_tree/concatenations.py | 29 +++++----- pybamm/expression_tree/functions.py | 45 ++++++++++----- .../expression_tree/independent_variable.py | 6 +- pybamm/expression_tree/input_parameter.py | 6 +- pybamm/expression_tree/interpolant.py | 15 +++-- pybamm/expression_tree/scalar.py | 6 +- pybamm/expression_tree/state_vector.py | 13 +++-- pybamm/expression_tree/symbol.py | 39 +++++++------ pybamm/expression_tree/unary_operators.py | 55 +++++++++++++++---- pybamm/models/event.py | 13 ++--- 11 files changed, 138 insertions(+), 95 deletions(-) diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index 651ea37f2d..a3001507d1 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -61,8 +61,6 @@ def __init__( @classmethod def _from_json(cls, snippet: dict): - instance = cls.__new__(cls) - if isinstance(snippet["entries"], dict): matrix = csr_matrix( ( @@ -75,14 +73,12 @@ def _from_json(cls, snippet: dict): else: matrix = snippet["entries"] - instance.__init__( + return cls( matrix, name=snippet["name"], domains=snippet["domains"], ) - return instance - @property def entries(self): return self._entries diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 5f9bc4dced..a567b3097d 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -56,13 +56,15 @@ def __init__( super().__init__(name, children, domains=domains) @classmethod - def _from_json(cls, *children, name, domains, concat_fun=None): + def _from_json(cls, snippet: dict): """Creates a new Concatenation instance from a json object""" instance = cls.__new__(cls) - instance.concatenation_function = concat_fun + instance.concatenation_function = snippet["concat_fun"] - super(Concatenation, instance).__init__(name, children, domains=domains) + super(Concatenation, instance).__init__( + snippet["name"], tuple(snippet["children"]), domains=snippet["domains"] + ) return instance @@ -215,12 +217,11 @@ def __init__(self, *children: Sequence[pybamm.Symbol]): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.Concatenation._from_json()`.""" - instance = super()._from_json( - *snippet["children"], - name="numpy_concatenation", - domains=snippet["domains"], - concat_fun=np.concatenate, - ) + + snippet["name"] = "numpy_concatenation" + snippet["concat_fun"] = np.concatenate + + instance = super()._from_json(snippet) return instance @@ -300,11 +301,11 @@ def __init__( @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.Concatenation._from_json()`.""" - instance = super()._from_json( - *snippet["children"], - name="domain_concatenation", - domains=snippet["domains"], - ) + + snippet["name"] = "domain_concatenation" + snippet["concat_fun"] = None + + instance = super()._from_json(snippet) def repack_defaultDict(slices): slices = defaultdict(list, slices) diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 279b81cafb..31acf31c94 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -268,7 +268,7 @@ def __init__(self, function: Callable, child: pybamm.Symbol): super().__init__(function, child) @classmethod - def _from_json(cls, function: Callable, snippet: dict): + def _from_json(cls, snippet: dict): """ Reconstructs a SpecificFunction instance during deserialisation of a JSON file. @@ -282,7 +282,9 @@ def _from_json(cls, function: Callable, snippet: dict): instance = cls.__new__(cls) - super(SpecificFunction, instance).__init__(function, snippet["children"][0]) + super(SpecificFunction, instance).__init__( + snippet["function"], snippet["children"][0] + ) return instance @@ -323,7 +325,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.arcsinh, snippet) + snippet["function"] = np.arcsinh + instance = super()._from_json(snippet) return instance def _function_diff(self, children, idx): @@ -350,7 +353,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.arctan, snippet) + snippet["function"] = np.arctan + instance = super()._from_json(snippet) return instance def _function_diff(self, children, idx): @@ -377,7 +381,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.cos, snippet) + snippet["function"] = np.cos + instance = super()._from_json(snippet) return instance def _function_diff(self, children, idx): @@ -399,7 +404,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.cosh, snippet) + snippet["function"] = np.cosh + instance = super()._from_json(snippet) return instance def _function_diff(self, children, idx): @@ -421,7 +427,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(special.erf, snippet) + snippet["function"] = special.erf + instance = super()._from_json(snippet) return instance def _function_diff(self, children, idx): @@ -448,7 +455,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.exp, snippet) + snippet["function"] = np.exp + instance = super()._from_json(snippet) return instance def _function_diff(self, children, idx): @@ -470,7 +478,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.log, snippet) + snippet["function"] = np.log + instance = super()._from_json(snippet) return instance def _function_evaluate(self, evaluated_children): @@ -506,7 +515,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.max, snippet) + snippet["function"] = np.max + instance = super()._from_json(snippet) return instance def _evaluate_for_shape(self): @@ -532,7 +542,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.min, snippet) + snippet["function"] = np.min + instance = super()._from_json(snippet) return instance def _evaluate_for_shape(self): @@ -563,7 +574,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.sin, snippet) + snippet["function"] = np.sin + instance = super()._from_json(snippet) return instance def _function_diff(self, children, idx): @@ -585,7 +597,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.sinh, snippet) + snippet["function"] = np.sinh + instance = super()._from_json(snippet) return instance def _function_diff(self, children, idx): @@ -607,7 +620,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.sqrt, snippet) + snippet["function"] = np.sqrt + instance = super()._from_json(snippet) return instance def _function_evaluate(self, evaluated_children): @@ -634,7 +648,8 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.SpecificFunction._from_json()`.""" - instance = super()._from_json(np.tanh, snippet) + snippet["function"] = np.tanh + instance = super()._from_json(snippet) return instance def _function_diff(self, children, idx): diff --git a/pybamm/expression_tree/independent_variable.py b/pybamm/expression_tree/independent_variable.py index 0c3f310bb4..6f8d9315a3 100644 --- a/pybamm/expression_tree/independent_variable.py +++ b/pybamm/expression_tree/independent_variable.py @@ -45,11 +45,7 @@ def __init__( @classmethod def _from_json(cls, snippet: dict): - instance = cls.__new__(cls) - - instance.__init__(snippet["name"], domains=snippet["domains"]) - - return instance + return cls(snippet["name"], domains=snippet["domains"]) def _evaluate_for_shape(self): """See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`""" diff --git a/pybamm/expression_tree/input_parameter.py b/pybamm/expression_tree/input_parameter.py index b4b723dee3..1fcddd7744 100644 --- a/pybamm/expression_tree/input_parameter.py +++ b/pybamm/expression_tree/input_parameter.py @@ -45,16 +45,12 @@ def __init__( @classmethod def _from_json(cls, snippet: dict): - instance = cls.__new__(cls) - - instance.__init__( + return cls( snippet["name"], domain=snippet["domain"], expected_size=snippet["expected_size"], ) - return instance - def create_copy(self) -> pybamm.InputParameter: """See :meth:`pybamm.Symbol.new_copy()`.""" new_input_parameter = InputParameter( diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index 5300bfbac8..9664462c35 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -131,7 +131,7 @@ def __init__( if extrapolate is False: fill_value = np.nan elif extrapolate is True: - fill_value = "extrapolate" # ignore:assignment + fill_value = "extrapolate" # ignore: assignment interpolating_function = interpolate.interp1d( x1, y.T, @@ -207,14 +207,13 @@ def __init__( @classmethod def _from_json(cls, snippet: dict): """Create an Interpolant object from JSON data""" - instance = cls.__new__(cls) if len(snippet["x"]) == 1: x = [np.array(x) for x in snippet["x"]] else: x = tuple(np.array(x) for x in snippet["x"]) - instance.__init__( + return cls( x, np.array(snippet["y"]), snippet["children"], @@ -223,8 +222,6 @@ def _from_json(cls, snippet: dict): extrapolate=snippet["extrapolate"], ) - return instance - @property def entries_string(self): return self._entries_string @@ -245,7 +242,13 @@ def entries_string(self, value): def set_id(self): """See :meth:`pybamm.Symbol.set_id()`.""" self._id = hash( - (self.__class__, self.name, self.entries_string, *tuple([child.id for child in self.children]), *tuple(self.domain)) + ( + self.__class__, + self.name, + self.entries_string, + *tuple([child.id for child in self.children]), + *tuple(self.domain), + ) ) def _function_new_copy(self, children): diff --git a/pybamm/expression_tree/scalar.py b/pybamm/expression_tree/scalar.py index f9e15c0384..d99b228f74 100644 --- a/pybamm/expression_tree/scalar.py +++ b/pybamm/expression_tree/scalar.py @@ -38,11 +38,7 @@ def __init__( @classmethod def _from_json(cls, snippet: dict): - instance = cls.__new__(cls) - - instance.__init__(snippet["value"], name=snippet["name"]) - - return instance + return cls(snippet["value"], name=snippet["name"]) def __str__(self): return str(self.value) diff --git a/pybamm/expression_tree/state_vector.py b/pybamm/expression_tree/state_vector.py index ccb78ef493..8cd23b9f52 100644 --- a/pybamm/expression_tree/state_vector.py +++ b/pybamm/expression_tree/state_vector.py @@ -77,19 +77,15 @@ def __init__( @classmethod def _from_json(cls, snippet: dict): - instance = cls.__new__(cls) - y_slices = [slice(s["start"], s["stop"], s["step"]) for s in snippet["y_slice"]] - instance.__init__( + return cls( *y_slices, name=snippet["name"], domains=snippet["domains"], evaluation_array=snippet["evaluation_array"], ) - return instance - @property def y_slices(self): return self._y_slices @@ -124,7 +120,12 @@ def set_evaluation_array(self, y_slices, evaluation_array): def set_id(self): """See :meth:`pybamm.Symbol.set_id()`""" self._id = hash( - (self.__class__, self.name, tuple(self.evaluation_array), *tuple(self.domain)) + ( + self.__class__, + self.name, + tuple(self.evaluation_array), + *tuple(self.domain), + ) ) def _jac_diff_vector(self, variable: pybamm.StateVectorBase): diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 13174c5ab3..ac179b6cdc 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -265,14 +265,10 @@ def _from_json(cls, snippet: dict): At minimum, should contain "name", "children" and "domains". """ - instance = cls.__new__(cls) - - instance.__init__( + return cls( snippet["name"], children=snippet["children"], domains=snippet["domains"] ) - return instance - @property def children(self): """ @@ -449,7 +445,12 @@ def set_id(self): need to hash once. """ self._id = hash( - (self.__class__, self.name, *tuple([child.id for child in self.children]), *tuple([(k, tuple(v)) for k, v in self.domains.items() if v != []])) + ( + self.__class__, + self.name, + *tuple([child.id for child in self.children]), + *tuple([(k, tuple(v)) for k, v in self.domains.items() if v != []]), + ) ) @property @@ -584,47 +585,51 @@ def __repr__(self): {k: v for k, v in self.domains.items() if v != []}, ) - def __add__(self, other: Symbol) -> Addition: + def __add__(self, other: Union[float, np.ndarray, Symbol]) -> Addition: """return an :class:`Addition` object.""" return pybamm.add(self, other) - def __radd__(self, other: Symbol) -> Addition: + def __radd__(self, other: Union[float, np.ndarray, Symbol]) -> Addition: """return an :class:`Addition` object.""" return pybamm.add(other, self) - def __sub__(self, other: Symbol) -> Subtraction: + def __sub__(self, other: Union[float, np.ndarray, Symbol]) -> Subtraction: """return a :class:`Subtraction` object.""" return pybamm.subtract(self, other) - def __rsub__(self, other: Symbol) -> Subtraction: + def __rsub__(self, other: Union[float, np.ndarray, Symbol]) -> Subtraction: """return a :class:`Subtraction` object.""" return pybamm.subtract(other, self) - def __mul__(self, other: Symbol) -> Multiplication: + def __mul__(self, other: Union[float, np.ndarray, Symbol]) -> Multiplication: """return a :class:`Multiplication` object.""" return pybamm.multiply(self, other) - def __rmul__(self, other: Symbol) -> Multiplication: + def __rmul__(self, other: Union[float, np.ndarray, Symbol]) -> Multiplication: """return a :class:`Multiplication` object.""" return pybamm.multiply(other, self) - def __matmul__(self, other: Symbol) -> pybamm.MatrixMultiplication: + def __matmul__( + self, other: Union[float, np.ndarray, Symbol] + ) -> pybamm.MatrixMultiplication: """return a :class:`MatrixMultiplication` object.""" return pybamm.matmul(self, other) - def __rmatmul__(self, other: Symbol) -> pybamm.MatrixMultiplication: + def __rmatmul__( + self, other: Union[float, np.ndarray, Symbol] + ) -> pybamm.MatrixMultiplication: """return a :class:`MatrixMultiplication` object.""" return pybamm.matmul(other, self) - def __truediv__(self, other: Symbol) -> Division: + def __truediv__(self, other: Union[float, np.ndarray, Symbol]) -> Division: """return a :class:`Division` object.""" return pybamm.divide(self, other) - def __rtruediv__(self, other: Symbol) -> Division: + def __rtruediv__(self, other: Union[float, np.ndarray, Symbol]) -> Division: """return a :class:`Division` object.""" return pybamm.divide(other, self) - def __pow__(self, other: Union[Symbol, float]) -> pybamm.Power: + def __pow__(self, other: Union[float, np.ndarray, Symbol]) -> pybamm.Power: """return a :class:`Power` object.""" return pybamm.simplified_power(self, other) diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 6706022765..9d45c67405 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -301,21 +301,18 @@ def __init__(self, child, index, name=None, check_size=True): @classmethod def _from_json(cls, snippet: dict): """See :meth:`pybamm.UnaryOperator._from_json()`.""" - instance = cls.__new__(cls) - index = slice( snippet["index"]["start"], snippet["index"]["stop"], snippet["index"]["step"], ) - instance.__init__( + return cls( snippet["children"][0], index, name=snippet["name"], check_size=snippet["check_size"], ) - return instance def _unary_jac(self, child_jac): """See :meth:`pybamm.UnaryOperator._unary_jac()`.""" @@ -333,7 +330,14 @@ def _unary_jac(self, child_jac): def set_id(self): """See :meth:`pybamm.Symbol.set_id()`""" self._id = hash( - (self.__class__, self.name, self.slice.start, self.slice.stop, self.children[0].id, *tuple(self.domain)) + ( + self.__class__, + self.name, + self.slice.start, + self.slice.stop, + self.children[0].id, + *tuple(self.domain), + ) ) def _unary_evaluate(self, child): @@ -473,7 +477,9 @@ def _unary_new_copy(self, child): def _sympy_operator(self, child): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" - sympy_Divergence = have_optional_dependency("sympy.vector.operators", "Divergence") + sympy_Divergence = have_optional_dependency( + "sympy.vector.operators", "Divergence" + ) return sympy_Divergence(child) @@ -630,7 +636,18 @@ def integration_variable(self): def set_id(self): """See :meth:`pybamm.Symbol.set_id()`""" self._id = hash( - (self.__class__, self.name, *tuple([integration_variable.id for integration_variable in self.integration_variable]), self.children[0].id, *tuple(self.domain)) + ( + self.__class__, + self.name, + *tuple( + [ + integration_variable.id + for integration_variable in self.integration_variable + ] + ), + self.children[0].id, + *tuple(self.domain), + ) ) def _unary_new_copy(self, child): @@ -771,7 +788,13 @@ def __init__(self, child, vector_type="row"): def set_id(self): """See :meth:`pybamm.Symbol.set_id()`""" self._id = hash( - (self.__class__, self.name, self.vector_type, self.children[0].id, *tuple(self.domain)) + ( + self.__class__, + self.name, + self.vector_type, + self.children[0].id, + *tuple(self.domain), + ) ) def _unary_new_copy(self, child): @@ -865,7 +888,13 @@ def __init__(self, child, side, domain): def set_id(self): """See :meth:`pybamm.Symbol.set_id()`""" self._id = hash( - (self.__class__, self.name, self.side, self.children[0].id, *tuple([(k, tuple(v)) for k, v in self.domains.items()])) + ( + self.__class__, + self.name, + self.side, + self.children[0].id, + *tuple([(k, tuple(v)) for k, v in self.domains.items()]), + ) ) def _evaluates_on_edges(self, dimension: str) -> bool: @@ -923,7 +952,13 @@ def __init__(self, name, child, side): def set_id(self): """See :meth:`pybamm.Symbol.set_id()`""" self._id = hash( - (self.__class__, self.name, self.side, self.children[0].id, *tuple([(k, tuple(v)) for k, v in self.domains.items()])) + ( + self.__class__, + self.name, + self.side, + self.children[0].id, + *tuple([(k, tuple(v)) for k, v in self.domains.items()]), + ) ) def _unary_new_copy(self, child): diff --git a/pybamm/models/event.py b/pybamm/models/event.py index 850c550ae6..6214a0fb04 100644 --- a/pybamm/models/event.py +++ b/pybamm/models/event.py @@ -1,7 +1,7 @@ from enum import Enum import numpy as np -from typing import Optional +from typing import Optional, TypeVar, Type class EventType(Enum): @@ -27,6 +27,9 @@ class EventType(Enum): SWITCH = 3 +E = TypeVar("E", bound="Event") + + class Event: """ @@ -50,7 +53,7 @@ def __init__(self, name, expression, event_type=EventType.TERMINATION): self._event_type = event_type @classmethod - def _from_json(cls, snippet: dict): + def _from_json(cls: Type[E], snippet: dict) -> E: """ Reconstructs an Event instance during deserialisation of a JSON file. @@ -61,16 +64,12 @@ def _from_json(cls, snippet: dict): Should contain "name", "expression" and "event_type". """ - instance = cls.__new__(cls) - - instance.__init__( + return cls( snippet["name"], snippet["expression"], event_type=EventType(snippet["event_type"][1]), ) - return instance - def evaluate( self, t: Optional[float] = None, From 1dfb1b45ad59ce7c39fa8c92a81d64e57a4472c3 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Thu, 14 Dec 2023 16:10:32 +0000 Subject: [PATCH 09/32] Misc type hinting fixes for mypy (30 remaining) --- mypy.ini | 6 +++ pybamm/citations.py | 1 + pybamm/expression_tree/array.py | 5 ++- pybamm/expression_tree/binary_operators.py | 2 +- pybamm/expression_tree/broadcasts.py | 4 ++ pybamm/expression_tree/concatenations.py | 21 +++++---- pybamm/expression_tree/functions.py | 32 +++++++------- .../expression_tree/independent_variable.py | 6 +-- pybamm/expression_tree/interpolant.py | 4 +- .../expression_tree/operations/serialise.py | 2 +- pybamm/expression_tree/parameter.py | 44 ++++++++++++------- pybamm/expression_tree/symbol.py | 10 ++--- pybamm/expression_tree/unary_operators.py | 6 +-- pybamm/hints.py | 4 ++ pybamm/models/base_model.py | 7 +-- .../full_battery_models/base_battery_model.py | 3 +- 16 files changed, 88 insertions(+), 69 deletions(-) create mode 100644 pybamm/hints.py diff --git a/mypy.ini b/mypy.ini index 13e0cad69c..8c11b00049 100644 --- a/mypy.ini +++ b/mypy.ini @@ -32,3 +32,9 @@ ignore_missing_imports=True [mypy-absl.*] ignore_missing_imports=True + +[mypy-pybamm.models.base_model.*] +disable_error_code = attr-defined + +[mypy-pybamm.models.full_battery_models.base_battery_model.*] +disable_error_code = attr-defined \ No newline at end of file diff --git a/pybamm/citations.py b/pybamm/citations.py index e73351a4c6..167961a932 100644 --- a/pybamm/citations.py +++ b/pybamm/citations.py @@ -61,6 +61,7 @@ def _reset(self): self.register("Sulzer2021") self.register("Harris2020") + @staticmethod def _caller_name(): """ Returns the qualified name of classes that call :meth:`register` internally. diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index a3001507d1..ae745a887c 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -4,11 +4,14 @@ from __future__ import annotations import numpy as np from scipy.sparse import csr_matrix, issparse -from typing import Union, Tuple, Optional, Any +from typing import Union, Tuple, Optional, Any, TYPE_CHECKING import pybamm from pybamm.util import have_optional_dependency +if TYPE_CHECKING: + import sympy + class Array(pybamm.Symbol): """ diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index 49df3300b2..d9b7c1c676 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -1482,7 +1482,7 @@ def sigmoid( def source( left: Union[numbers.Number, pybamm.Symbol], - right: Union[numbers.Number, pybamm.Symbol], + right: pybamm.Symbol, boundary=False, ): """ diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index 9206812567..335bef44d9 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -58,6 +58,10 @@ def _diff(self, variable): # Differentiate the child and broadcast the result in the same way return self._unary_new_copy(self.child.diff(variable)) + def reduce_one_dimension(self): + """Reduce the broadcast by one dimension.""" + raise NotImplementedError + def to_json(self): raise NotImplementedError( "pybamm.Broadcast: Serialisation is only implemented for discretised models" diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index a567b3097d..26b44c57ae 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -7,15 +7,13 @@ import numpy as np from scipy.sparse import issparse, vstack -from typing import Optional, Sequence, Type, Union, TYPE_CHECKING -from typing_extensions import TypeGuard, TypeVar +from typing import Optional, Sequence, Type, Union +from typing_extensions import TypeGuard +from pybamm.hints import S import pybamm from pybamm.util import have_optional_dependency -if TYPE_CHECKING: - S = TypeVar("S", bound=pybamm.Symbol) # type: ignore[no-redef] - class Concatenation(pybamm.Symbol): """ @@ -29,7 +27,7 @@ class Concatenation(pybamm.Symbol): def __init__( self, - *children: Sequence[pybamm.Symbol], + *children: pybamm.Symbol, name: Optional[str] = None, check_domain=True, concat_fun=None, @@ -200,7 +198,7 @@ class NumpyConcatenation(Concatenation): The equations to concatenate """ - def __init__(self, *children: Sequence[pybamm.Symbol]): + def __init__(self, *children: pybamm.Symbol): children = list(children) # Turn objects that evaluate to scalars to objects that evaluate to vectors, # so that we can concatenate them @@ -571,17 +569,18 @@ def simplified_domain_concatenation( # Simplify Concatenation of StateVectors to a single StateVector # The sum of the evalation arrays of the StateVectors must be exactly 1 if all(isinstance(child, pybamm.StateVector) for child in children): - longest_eval_array = len(children[-1]._evaluation_array) + sv_children: list[pybamm.StateVector] = children # type narrow for mypy + longest_eval_array = len(sv_children[-1]._evaluation_array) eval_arrays = {} - for child in children: + for child in sv_children: eval_arrays[child] = np.concatenate( [ child.evaluation_array, np.zeros(longest_eval_array - len(child.evaluation_array)), ] ) - first_start = children[0].y_slices[0].start - last_stop = children[-1].y_slices[-1].stop + first_start = sv_children[0].y_slices[0].start + last_stop = sv_children[-1].y_slices[-1].stop if all( sum(array for array in eval_arrays.values())[first_start:last_stop] == 1 ): diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 31acf31c94..3a5bac0d97 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -235,22 +235,6 @@ def _from_json(cls, snippet): ) -def simplified_function(func_class: Type[SF], child: pybamm.Symbol): - """ - Simplifications implemented before applying the function. - Currently only implemented for one-child functions. - """ - if isinstance(child, pybamm.Broadcast): - # Move the function inside the broadcast - # Apply recursively - func_child_not_broad = pybamm.simplify_if_constant( - simplified_function(func_class, child.orphans[0]) - ) - return child._unary_new_copy(func_child_not_broad) - else: - return pybamm.simplify_if_constant(func_class(child)) - - class SpecificFunction(Function): """ Parent class for the specific functions, which implement their own `diff` @@ -316,6 +300,22 @@ def to_json(self): SF = TypeVar("SF", bound=SpecificFunction) +def simplified_function(func_class: Type[SF], child: pybamm.Symbol): + """ + Simplifications implemented before applying the function. + Currently only implemented for one-child functions. + """ + if isinstance(child, pybamm.Broadcast): + # Move the function inside the broadcast + # Apply recursively + func_child_not_broad = pybamm.simplify_if_constant( + simplified_function(func_class, child.orphans[0]) + ) + return child._unary_new_copy(func_child_not_broad) + else: + return pybamm.simplify_if_constant(func_class(child)) + + class Arcsinh(SpecificFunction): """Arcsinh function.""" diff --git a/pybamm/expression_tree/independent_variable.py b/pybamm/expression_tree/independent_variable.py index 6f8d9315a3..270fcf61e3 100644 --- a/pybamm/expression_tree/independent_variable.py +++ b/pybamm/expression_tree/independent_variable.py @@ -74,11 +74,7 @@ def __init__(self): @classmethod def _from_json(cls, snippet: dict): - instance = cls.__new__(cls) - - instance.__init__() - - return instance + return cls() def create_copy(self): """See :meth:`pybamm.Symbol.new_copy()`.""" diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index 9664462c35..9e847c23fc 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -43,7 +43,7 @@ class Interpolant(pybamm.Function): def __init__( self, - x: Sequence[np.ndarray], + x: Union[np.ndarray, Sequence[np.ndarray]], y: np.ndarray, children: Union[Sequence[pybamm.Symbol], pybamm.Time], name: Optional[str] = None, @@ -103,7 +103,7 @@ def __init__( x1 = x[0] else: x1 = x - x = [x] + x: list[np.ndarray] = [x] x2 = None if x1.shape[0] != y.shape[0]: raise ValueError( diff --git a/pybamm/expression_tree/operations/serialise.py b/pybamm/expression_tree/operations/serialise.py index c7768217a3..4bd5e346df 100644 --- a/pybamm/expression_tree/operations/serialise.py +++ b/pybamm/expression_tree/operations/serialise.py @@ -114,7 +114,7 @@ def save_model( "pybamm_version": pybamm.__version__, "name": model.name, "options": model.options, - "bounds": [bound.tolist() for bound in model.bounds], + "bounds": [bound.tolist() for bound in model.bounds], # type: ignore[attr-defined] "concatenated_rhs": self._SymbolEncoder().default(model._concatenated_rhs), "concatenated_algebraic": self._SymbolEncoder().default( model._concatenated_algebraic diff --git a/pybamm/expression_tree/parameter.py b/pybamm/expression_tree/parameter.py index c8fa99cda8..5e21eb6c85 100644 --- a/pybamm/expression_tree/parameter.py +++ b/pybamm/expression_tree/parameter.py @@ -8,8 +8,8 @@ import numpy as np from typing import Optional, TYPE_CHECKING, Literal -# if TYPE_CHECKING: -# from pybamm import FunctionParameter +if TYPE_CHECKING: + import sympy import pybamm from pybamm.util import have_optional_dependency @@ -121,20 +121,24 @@ def __init__( self.print_name = print_name else: frame = sys._getframe().f_back - print_name = frame.f_code.co_name - if print_name.startswith("_"): - self.print_name = None - else: - try: - parent_param = frame.f_locals["self"] - except KeyError: - parent_param = None - if hasattr(parent_param, "domain") and parent_param.domain is not None: - # add "_n" or "_s" or "_p" if this comes from a Parameter class with - # a domain - d = parent_param.domain[0] - print_name += f"_{d}" - self.print_name = print_name + if frame is not None: + print_name = frame.f_code.co_name + if print_name.startswith("_"): + self.print_name = None + else: + try: + parent_param = frame.f_locals["self"] + except KeyError: + parent_param = None + if ( + hasattr(parent_param, "domain") + and parent_param.domain is not None + ): + # add "_n" or "_s" or "_p" if this comes from a Parameter class with + # a domain + d = parent_param.domain[0] + print_name += f"_{d}" + self.print_name = print_name def print_input_names(self): if self._input_names: @@ -168,7 +172,13 @@ def input_names(self, inp=None): def set_id(self): """See :meth:`pybamm.Symbol.set_id`""" self._id = hash( - (self.__class__, self.name, self.diff_variable, *tuple([child.id for child in self.children]), *tuple(self.domain)) + ( + self.__class__, + self.name, + self.diff_variable, + *tuple([child.id for child in self.children]), + *tuple(self.domain), + ) ) def diff(self, variable: pybamm.Symbol) -> pybamm.FunctionParameter: diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index ac179b6cdc..e0931e0886 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -15,7 +15,6 @@ Sequence, no_type_check, ) -from typing_extensions import TypeVar import pybamm from pybamm.util import have_optional_dependency @@ -29,8 +28,7 @@ Division, ) import casadi - - S = TypeVar("S", bound=pybamm.Symbol) + from hints import S DOMAIN_LEVELS = ["primary", "secondary", "tertiary", "quaternary"] EMPTY_DOMAINS: dict[str, list] = {k: [] for k in DOMAIN_LEVELS} @@ -164,7 +162,7 @@ def is_matrix_minus_one(expr: Symbol): return is_matrix_x(expr, -1) -def simplify_if_constant(symbol: Union[S, float]) -> S: +def simplify_if_constant(symbol: S) -> S: """ Utility function to simplify an expression tree if it evalutes to a constant scalar, vector or matrix @@ -930,7 +928,9 @@ def _evaluates_on_edges(self, dimension): # Default behaviour: return False return False - def has_symbol_of_classes(self, symbol_classes: Union[Symbol, tuple[type[Symbol]]]): + def has_symbol_of_classes( + self, symbol_classes: Union[type[Symbol], tuple[type[Symbol]]] + ): """ Returns True if equation has a term of the class(es) `symbol_class`. diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 9d45c67405..838a502c76 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -1018,11 +1018,7 @@ def __init__(self, children, initial_condition): @classmethod def _from_json(cls, snippet: dict): - instance = cls.__new__(cls) - - instance.__init__(snippet["children"][0], snippet["initial_condition"]) - - return instance + return cls(snippet["children"][0], snippet["initial_condition"]) def _unary_new_copy(self, child): return self.__class__(child, self.initial_condition) diff --git a/pybamm/hints.py b/pybamm/hints.py new file mode 100644 index 0000000000..696a6f4dee --- /dev/null +++ b/pybamm/hints.py @@ -0,0 +1,4 @@ +from typing_extensions import TypeVar +import pybamm + +S = TypeVar("S", bound=pybamm.Symbol) diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index c4ae414e7c..566716e011 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -13,6 +13,8 @@ from pybamm.expression_tree.operations.serialise import Serialise from pybamm.util import have_optional_dependency +from typing import Optional + class BaseModel: """ @@ -130,10 +132,9 @@ def deserialise(cls, properties: dict): """ Create a model instance from a serialised object. """ - instance = cls.__new__(cls) # append the model name with _saved to differentiate - instance.__init__(name=properties["name"] + "_saved") + instance = cls(name=properties["name"] + "_saved") instance.options = properties["options"] @@ -1218,7 +1219,7 @@ def save_model(self, filename=None, mesh=None, variables=None): Serialise().save_model(self, filename=filename, mesh=mesh, variables=variables) -def load_model(filename, battery_model: BaseModel = None): +def load_model(filename, battery_model: Optional[BaseModel] = None): """ Load in a saved model from a JSON file diff --git a/pybamm/models/full_battery_models/base_battery_model.py b/pybamm/models/full_battery_models/base_battery_model.py index b174ef581c..28138a315b 100644 --- a/pybamm/models/full_battery_models/base_battery_model.py +++ b/pybamm/models/full_battery_models/base_battery_model.py @@ -830,10 +830,9 @@ def deserialise(cls, properties: dict): """ Create a model instance from a serialised object. """ - instance = cls.__new__(cls) # append the model name with _saved to differentiate - instance.__init__( + instance = cls( options=properties["options"], name=properties["name"] + "_saved" ) From d9e3ae23c36485d91bec3fb79b8dadd9bc001a36 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Dec 2023 16:11:32 +0000 Subject: [PATCH 10/32] style: pre-commit fixes --- mypy.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy.ini b/mypy.ini index 8c11b00049..f0f32a1648 100644 --- a/mypy.ini +++ b/mypy.ini @@ -37,4 +37,4 @@ ignore_missing_imports=True disable_error_code = attr-defined [mypy-pybamm.models.full_battery_models.base_battery_model.*] -disable_error_code = attr-defined \ No newline at end of file +disable_error_code = attr-defined From 77f00a65866395dc69016f817c71c3f333c03248 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 15 Dec 2023 14:32:51 +0000 Subject: [PATCH 11/32] More misc edits to reduce mypy errors (19 remaining) --- pybamm/expression_tree/binary_operators.py | 12 ++++++------ pybamm/expression_tree/broadcasts.py | 2 +- pybamm/expression_tree/functions.py | 6 ++---- pybamm/expression_tree/interpolant.py | 4 ++-- pybamm/expression_tree/operations/evaluate_python.py | 2 +- pybamm/expression_tree/operations/jacobian.py | 2 +- pybamm/expression_tree/operations/serialise.py | 4 ++-- pybamm/expression_tree/symbol.py | 4 ++-- 8 files changed, 17 insertions(+), 19 deletions(-) diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index d9b7c1c676..d2c7f2d5c6 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -1407,11 +1407,11 @@ def _heaviside(left: ChildValue, right: ChildValue, equal): """return a :class:`EqualHeaviside` object, or a smooth approximation.""" # Check for Concatenations and Broadcasts left, right = _simplify_elementwise_binary_broadcasts(left, right) - out = _simplified_binary_broadcast_concatenation( + concat_out = _simplified_binary_broadcast_concatenation( left, right, functools.partial(_heaviside, equal=equal) ) - if out is not None: - return out + if concat_out is not None: + return concat_out if ( left.is_constant() @@ -1434,9 +1434,9 @@ def _heaviside(left: ChildValue, right: ChildValue, equal): # (i.e. no need for smoothing) if k == "exact" or (left.is_constant() and right.is_constant()): if equal is True: - out = pybamm.EqualHeaviside(left, right) + out: pybamm.EqualHeaviside = pybamm.EqualHeaviside(left, right) else: - out = pybamm.NotEqualHeaviside(left, right) + out: pybamm.NotEqualHeaviside = pybamm.NotEqualHeaviside(left, right) # type: ignore[no-redef] else: out = pybamm.sigmoid(left, right, k) return pybamm.simplify_if_constant(out) @@ -1446,7 +1446,7 @@ def softminus( left: pybamm.Symbol, right: pybamm.Symbol, k: float, -) -> pybamm.Symbol: +): """ Softplus approximation to the minimum function. k is the smoothing parameter, set by `pybamm.settings.min_smoothing`. The recommended value is k=10. diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index 335bef44d9..c99369eaff 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -469,7 +469,7 @@ def __init__( ): # Convert child to scalar if it is a number if isinstance(child, numbers.Number): - child: pybamm.Scalar = pybamm.Scalar(child) + child: pybamm.Scalar = pybamm.Scalar(child) # type: ignore[no-redef] if isinstance(auxiliary_domains, str): auxiliary_domains = {"secondary": auxiliary_domains} diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 3a5bac0d97..0945a634f6 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -87,11 +87,9 @@ def diff(self, variable: pybamm.Symbol): # remove None entries partial_derivatives = [x for x in partial_derivatives if x is not None] - derivative: pybamm.Symbol = sum( - partial_derivatives - ) # type:ignore[assignment] + derivative = sum(partial_derivatives) # type: ignore[arg-type] if derivative == 0: - derivative = pybamm.Scalar(0) + derivative = pybamm.Scalar(0) # type: ignore[assignment] return derivative diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index 9e847c23fc..163cbacb07 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -103,7 +103,7 @@ def __init__( x1 = x[0] else: x1 = x - x: list[np.ndarray] = [x] + x: list[np.ndarray] = [x] # type: ignore[no-redef] x2 = None if x1.shape[0] != y.shape[0]: raise ValueError( @@ -131,7 +131,7 @@ def __init__( if extrapolate is False: fill_value = np.nan elif extrapolate is True: - fill_value = "extrapolate" # ignore: assignment + fill_value = "extrapolate" # type: ignore[assignment] interpolating_function = interpolate.interp1d( x1, y.T, diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index 4bd223020a..dfcf297b32 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -62,7 +62,7 @@ def toarray(self): result = jax.numpy.zeros(self.shape, dtype=self.data.dtype) return result.at[self.row, self.col].add(self.data) - def dot_product(self, b: jaxlib.xla_extension.DeviceArray): + def dot_product(self, b): """ dot product of matrix with a dense column vector b diff --git a/pybamm/expression_tree/operations/jacobian.py b/pybamm/expression_tree/operations/jacobian.py index 57d603be94..14211c0b11 100644 --- a/pybamm/expression_tree/operations/jacobian.py +++ b/pybamm/expression_tree/operations/jacobian.py @@ -74,7 +74,7 @@ def _jac(self, symbol: pybamm.Symbol, variable: pybamm.Symbol): jac = symbol._unary_jac(child_jac) elif isinstance(symbol, pybamm.Function): - children_jacs = [None] * len(symbol.children) + children_jacs: list[None, pybamm.Symbol] = [None] * len(symbol.children) for i, child in enumerate(symbol.children): children_jacs[i] = self.jac(child, variable) # _function_jac defined in function class diff --git a/pybamm/expression_tree/operations/serialise.py b/pybamm/expression_tree/operations/serialise.py index 4bd5e346df..4e79685325 100644 --- a/pybamm/expression_tree/operations/serialise.py +++ b/pybamm/expression_tree/operations/serialise.py @@ -48,7 +48,7 @@ def default(self, node: dict): class _MeshEncoder(json.JSONEncoder): """Converts PyBaMM meshes into a JSON-serialisable format""" - def default(self, node: dict): + def default(self, node: pybamm.Mesh): node_dict = {"py/object": str(type(node))[8:-2], "py/id": id(node)} if isinstance(node, pybamm.Mesh): node_dict.update(node.to_json()) @@ -249,7 +249,7 @@ def _get_pybamm_class(self, snippet: dict): empty_class.__class__ = class_ except TypeError: # Mesh objects have a different layouts - empty_class = self._EmptyDict() + empty_class = self._EmptyDict() # type: ignore[assignment] empty_class.__class__ = class_ return empty_class diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index e0931e0886..af9e264f97 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -459,9 +459,9 @@ def scale(self): def reference(self): return self._reference - def __eq__(self, other: Union[Symbol, float]): + def __eq__(self, other): try: - return self._id == other._id # type:ignore + return self._id == other._id except AttributeError: if isinstance(other, numbers.Number): return self._id == pybamm.Scalar(other)._id From 69f1124acdf96a629a720238af9d749449ac678b Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 15 Dec 2023 14:40:14 +0000 Subject: [PATCH 12/32] Fix pre-commit issues --- .pre-commit-config.yaml | 7 ------- pybamm/install_odes.py | 3 +-- pybamm/parameters/bpx.py | 14 +++++++------- pybamm/parameters/parameter_values.py | 4 ++-- 4 files changed, 10 insertions(+), 18 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9d4e68762b..ccd6fd2c84 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,13 +10,6 @@ repos: args: [--fix, --show-fixes, --select=I002] types_or: [python, pyi, jupyter] - - repo: https://github.com/nbQA-dev/nbQA - rev: 1.7.0 - hooks: - - id: nbqa-ruff - additional_dependencies: [ruff==0.0.284] - args: ["--fix","--ignore=E501,E402"] - - repo: https://github.com/adamchainz/blacken-docs rev: "1.16.0" hooks: diff --git a/pybamm/install_odes.py b/pybamm/install_odes.py index d477cc01ad..dbc9e37f30 100644 --- a/pybamm/install_odes.py +++ b/pybamm/install_odes.py @@ -10,7 +10,7 @@ try: # wget module is required to download SUNDIALS or SuiteSparse. - import wget # type: ignore + import wget NO_WGET = False except ModuleNotFoundError: @@ -108,7 +108,6 @@ def update_LD_LIBRARY_PATH(install_dir): def main(arguments=None): - log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" logger = logging.getLogger("scikits.odes setup") diff --git a/pybamm/parameters/bpx.py b/pybamm/parameters/bpx.py index 0d6f3612b3..1b89aa4afe 100644 --- a/pybamm/parameters/bpx.py +++ b/pybamm/parameters/bpx.py @@ -1,4 +1,4 @@ -from bpx import BPX, Function, InterpolatedTable # type: ignore +from bpx import BPX, Function, InterpolatedTable import pybamm import math from dataclasses import dataclass @@ -261,12 +261,12 @@ def _positive_electrode_entropic_change(sto, c_s_max): "Maximum concentration in " + negative_electrode.pre_name.lower() + "[mol.m-3]" ] k_n_norm = pybamm_dict[ - negative_electrode.pre_name - + "reaction rate constant [mol.m-2.s-1]" + negative_electrode.pre_name + "reaction rate constant [mol.m-2.s-1]" ] Ea_k_n = pybamm_dict.get( negative_electrode.pre_name - + "reaction rate constant activation energy [J.mol-1]", 0.0 + + "reaction rate constant activation energy [J.mol-1]", + 0.0, ) # Note that in BPX j = 2*F*k_norm*sqrt((ce/ce0)*(c/c_max)*(1-c/c_max))*sinh(...), # and in PyBaMM j = 2*k*sqrt(ce*c*(c_max - c))*sinh(...) @@ -292,12 +292,12 @@ def _negative_electrode_exchange_current_density(c_e, c_s_surf, c_s_max, T): "Maximum concentration in " + positive_electrode.pre_name.lower() + "[mol.m-3]" ] k_p_norm = pybamm_dict[ - positive_electrode.pre_name - + "reaction rate constant [mol.m-2.s-1]" + positive_electrode.pre_name + "reaction rate constant [mol.m-2.s-1]" ] Ea_k_p = pybamm_dict.get( positive_electrode.pre_name - + "reaction rate constant activation energy [J.mol-1]", 0.0 + + "reaction rate constant activation energy [J.mol-1]", + 0.0, ) # Note that in BPX j = 2*F*k_norm*sqrt((ce/ce0)*(c/c_max)*(1-c/c_max))*sinh(...), # and in PyBaMM j = 2*k*sqrt(ce*c*(c_max - c))*sinh(...) diff --git a/pybamm/parameters/parameter_values.py b/pybamm/parameters/parameter_values.py index 9f739878a7..d5f12f362f 100644 --- a/pybamm/parameters/parameter_values.py +++ b/pybamm/parameters/parameter_values.py @@ -94,8 +94,8 @@ def create_from_bpx(filename, target_soc=1): if target_soc < 0 or target_soc > 1: raise ValueError("Target SOC should be between 0 and 1") - from bpx import parse_bpx_file, get_electrode_concentrations # type: ignore - from .bpx import _bpx_to_param_dict # type: ignore + from bpx import parse_bpx_file, get_electrode_concentrations + from .bpx import _bpx_to_param_dict # parse bpx bpx = parse_bpx_file(filename) From 44e5161c457bb791493e43d4b0921e0be1bb25cb Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 15 Dec 2023 15:17:00 +0000 Subject: [PATCH 13/32] Edit imports --- .pre-commit-config.yaml | 2 +- pybamm/expression_tree/array.py | 2 +- pybamm/expression_tree/binary_operators.py | 3 +-- pybamm/expression_tree/broadcasts.py | 7 ++++--- pybamm/expression_tree/independent_variable.py | 2 +- pybamm/expression_tree/input_parameter.py | 2 +- pybamm/expression_tree/matrix.py | 3 ++- pybamm/expression_tree/operations/evaluate_python.py | 1 + pybamm/expression_tree/operations/jacobian.py | 1 + pybamm/expression_tree/operations/unpack_symbols.py | 2 +- pybamm/expression_tree/symbol.py | 2 -- pybamm/expression_tree/unary_operators.py | 2 +- 12 files changed, 15 insertions(+), 14 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ccd6fd2c84..ed837e6fdb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ repos: rev: "v0.1.6" hooks: - id: ruff - args: [--fix, --show-fixes, --select=I002] + args: [--fix, --show-fixes] types_or: [python, pyi, jupyter] - repo: https://github.com/adamchainz/blacken-docs diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index ae745a887c..925f36846b 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -4,7 +4,7 @@ from __future__ import annotations import numpy as np from scipy.sparse import csr_matrix, issparse -from typing import Union, Tuple, Optional, Any, TYPE_CHECKING +from typing import Union, Tuple, Optional, TYPE_CHECKING import pybamm from pybamm.util import have_optional_dependency diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index d2c7f2d5c6..57c70cb9a6 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -11,8 +11,7 @@ import pybamm from pybamm.util import have_optional_dependency -from typing import Union, Tuple, Optional, Callable, overload -from typing_extensions import TypeVar +from typing import Union, Tuple, Optional, Callable # create type alias(s) ChildValue = Union[float, np.ndarray, pybamm.Symbol] diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index c99369eaff..e78c1d07d7 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -1,16 +1,17 @@ # # Unary operator classes and methods # +from __future__ import annotations import numbers import numpy as np from scipy.sparse import csr_matrix -from typing import Sequence, Optional, Union, Type, SupportsFloat - -NumberType = Type[SupportsFloat] +from typing import Optional, Union, Type, SupportsFloat import pybamm +NumberType = Type[SupportsFloat] + class Broadcast(pybamm.SpatialOperator): """ diff --git a/pybamm/expression_tree/independent_variable.py b/pybamm/expression_tree/independent_variable.py index 270fcf61e3..b898986c1c 100644 --- a/pybamm/expression_tree/independent_variable.py +++ b/pybamm/expression_tree/independent_variable.py @@ -1,8 +1,8 @@ # # IndependentVariable class # +from __future__ import annotations import sympy -import numpy as np from typing import Union, Optional, Any import pybamm diff --git a/pybamm/expression_tree/input_parameter.py b/pybamm/expression_tree/input_parameter.py index 1fcddd7744..b2affc1242 100644 --- a/pybamm/expression_tree/input_parameter.py +++ b/pybamm/expression_tree/input_parameter.py @@ -7,7 +7,7 @@ import scipy.sparse import pybamm -from typing import Union, Iterable, Optional, Sequence +from typing import Union, Optional class InputParameter(pybamm.Symbol): diff --git a/pybamm/expression_tree/matrix.py b/pybamm/expression_tree/matrix.py index ab9c28c392..a65b28fc39 100644 --- a/pybamm/expression_tree/matrix.py +++ b/pybamm/expression_tree/matrix.py @@ -1,9 +1,10 @@ # # Matrix class # +from __future__ import annotations import numpy as np from scipy.sparse import csr_matrix, issparse -from typing import Union, Optional, Type +from typing import Union, Optional import pybamm diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index dfcf297b32..c74345b52f 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -1,6 +1,7 @@ # # Write a symbol to python # +from __future__ import annotations import numbers from collections import OrderedDict from numpy.typing import ArrayLike diff --git a/pybamm/expression_tree/operations/jacobian.py b/pybamm/expression_tree/operations/jacobian.py index 14211c0b11..c536687ed2 100644 --- a/pybamm/expression_tree/operations/jacobian.py +++ b/pybamm/expression_tree/operations/jacobian.py @@ -1,6 +1,7 @@ # # Calculate the Jacobian of a symbol # +from __future__ import annotations from typing import Optional import pybamm diff --git a/pybamm/expression_tree/operations/unpack_symbols.py b/pybamm/expression_tree/operations/unpack_symbols.py index 73b23e1ed6..70c045f072 100644 --- a/pybamm/expression_tree/operations/unpack_symbols.py +++ b/pybamm/expression_tree/operations/unpack_symbols.py @@ -2,7 +2,7 @@ # Helper function to unpack a symbol # from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Iterable, Union, Sequence +from typing import TYPE_CHECKING, Optional, Union, Sequence if TYPE_CHECKING: import pybamm diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index af9e264f97..6dfe9825bf 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -11,9 +11,7 @@ Union, TYPE_CHECKING, Optional, - Iterable, Sequence, - no_type_check, ) import pybamm diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 838a502c76..aa9de2e242 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -3,7 +3,7 @@ # from __future__ import annotations import numbers -from typing import Optional, Union, Sequence +from typing import Optional, Union import numpy as np from scipy.sparse import csr_matrix, issparse From 811cf8e02a673ae797207f9c4d01c8b2d8d6c81d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Dec 2023 15:33:43 +0000 Subject: [PATCH 14/32] style: pre-commit fixes --- pybamm/expression_tree/operations/evaluate_python.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index c74345b52f..4e20d02b06 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -14,7 +14,6 @@ if pybamm.have_jax(): import jax - import jaxlib from jax.config import config platform = jax.lib.xla_bridge.get_backend().platform.casefold() From bb75a85420ced91cd15adf6da131b260072e8ffa Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 15 Dec 2023 15:45:39 +0000 Subject: [PATCH 15/32] Remove 'assert' and typing Tuples --- pybamm/expression_tree/array.py | 4 ++-- pybamm/expression_tree/binary_operators.py | 8 +++----- pybamm/expression_tree/operations/evaluate_python.py | 3 +-- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index 925f36846b..23ae7eba42 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -4,7 +4,7 @@ from __future__ import annotations import numpy as np from scipy.sparse import csr_matrix, issparse -from typing import Union, Tuple, Optional, TYPE_CHECKING +from typing import Union, Optional, TYPE_CHECKING import pybamm from pybamm.util import have_optional_dependency @@ -190,7 +190,7 @@ def linspace(start: float, stop: float, num: int = 50, **kwargs) -> pybamm.Array def meshgrid( x: pybamm.Array, y: pybamm.Array, **kwargs -) -> Tuple[pybamm.Array, pybamm.Array]: +) -> tuple[pybamm.Array, pybamm.Array]: """ Return coordinate matrices as from coordinate vectors by calling `numpy.meshgrid` with keyword arguments 'kwargs'. For a list of 'kwargs' diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index 57c70cb9a6..1c181d582e 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -11,7 +11,7 @@ import pybamm from pybamm.util import have_optional_dependency -from typing import Union, Tuple, Optional, Callable +from typing import Union, Optional, Callable # create type alias(s) ChildValue = Union[float, np.ndarray, pybamm.Symbol] @@ -19,7 +19,7 @@ def _preprocess_binary( left: ChildValue, right: ChildValue -) -> Tuple[pybamm.Symbol, pybamm.Symbol]: +) -> tuple[pybamm.Symbol, pybamm.Symbol]: if isinstance(left, numbers.Number): left = pybamm.Scalar(left) elif isinstance(left, np.ndarray): @@ -797,7 +797,7 @@ def _sympy_operator(self, left, right): def _simplify_elementwise_binary_broadcasts( left_child: ChildValue, right_child: ChildValue, -) -> Tuple[pybamm.Symbol, pybamm.Symbol]: +) -> tuple[pybamm.Symbol, pybamm.Symbol]: left, right = _preprocess_binary(left_child, right_child) def unpack_broadcast_recursive(symbol: pybamm.Symbol) -> pybamm.Symbol: @@ -1094,8 +1094,6 @@ def multiply( ): left, right = _simplify_elementwise_binary_broadcasts(left, right) - assert isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol) - # Move constant to always be on the left if right.is_constant() and not left.is_constant(): left, right = right, left diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index 4e20d02b06..3017f1bf25 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -5,7 +5,6 @@ import numbers from collections import OrderedDict from numpy.typing import ArrayLike -from typing import Tuple import numpy as np import scipy.sparse @@ -391,7 +390,7 @@ def find_symbols( def to_python( symbol: pybamm.Symbol, debug=False, output_jax=False -) -> Tuple[OrderedDict, str]: +) -> tuple[OrderedDict, str]: """ This function converts an expression tree into a dict of constant input values, and valid python code that acts like the tree's :func:`pybamm.Symbol.evaluate` function From 07f48d9eb1f23e763364559f96f32140a1de8b87 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 15 Dec 2023 16:44:02 +0000 Subject: [PATCH 16/32] Update typing syntax to match 3.10+ style, add __future__ imports --- pybamm/expression_tree/array.py | 16 ++-- pybamm/expression_tree/averages.py | 11 ++- pybamm/expression_tree/binary_operators.py | 30 +++---- pybamm/expression_tree/broadcasts.py | 54 ++++++------- pybamm/expression_tree/concatenations.py | 16 ++-- pybamm/expression_tree/functions.py | 16 ++-- .../expression_tree/independent_variable.py | 22 +++--- pybamm/expression_tree/input_parameter.py | 14 ++-- pybamm/expression_tree/interpolant.py | 14 ++-- pybamm/expression_tree/matrix.py | 13 ++-- .../operations/convert_to_casadi.py | 5 +- pybamm/expression_tree/operations/jacobian.py | 5 +- pybamm/expression_tree/operations/latexify.py | 8 +- .../expression_tree/operations/serialise.py | 10 +-- .../operations/unpack_symbols.py | 8 +- pybamm/expression_tree/parameter.py | 4 +- pybamm/expression_tree/scalar.py | 6 +- pybamm/expression_tree/state_vector.py | 40 +++++----- pybamm/expression_tree/symbol.py | 78 +++++++++---------- pybamm/expression_tree/unary_operators.py | 18 ++--- pybamm/expression_tree/variable.py | 23 +++--- pybamm/expression_tree/vector.py | 13 ++-- pybamm/models/base_model.py | 6 +- pybamm/models/event.py | 12 +-- pybamm/simulation.py | 5 +- 25 files changed, 217 insertions(+), 230 deletions(-) diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index 23ae7eba42..201ad57648 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -4,7 +4,7 @@ from __future__ import annotations import numpy as np from scipy.sparse import csr_matrix, issparse -from typing import Union, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING import pybamm from pybamm.util import have_optional_dependency @@ -41,12 +41,12 @@ class Array(pybamm.Symbol): def __init__( self, - entries: Union[np.ndarray, list, csr_matrix], - name: Optional[str] = None, - domain: Union[list[str], str, None] = None, - auxiliary_domains: Optional[dict[str, str]] = None, - domains: Optional[dict[str, list[str]]] = None, - entries_string: Optional[str] = None, + entries: np.ndarray | list | csr_matrix, + name: str | None = None, + domain: list[str] | str | None = None, + auxiliary_domains: dict[str, str] | None = None, + domains: dict[str, list[str]] | None = None, + entries_string: str | None = None, ) -> None: # if if isinstance(entries, list): @@ -101,7 +101,7 @@ def entries_string(self): return self._entries_string @entries_string.setter - def entries_string(self, value: Union[None, tuple]): + def entries_string(self, value: None | tuple): # We must include the entries in the hash, since different arrays can be # indistinguishable by class, name and domain alone # Slightly different syntax for sparse and non-sparse matrices diff --git a/pybamm/expression_tree/averages.py b/pybamm/expression_tree/averages.py index 53796e9642..7967cad51f 100644 --- a/pybamm/expression_tree/averages.py +++ b/pybamm/expression_tree/averages.py @@ -2,7 +2,7 @@ # Classes and methods for averaging # from __future__ import annotations -from typing import Union, Callable, Optional +from typing import Callable import pybamm @@ -20,9 +20,8 @@ def __init__( self, child: pybamm.Symbol, name: str, - integration_variable: Union[ - list[pybamm.IndependentVariable], pybamm.IndependentVariable - ], + integration_variable: list[pybamm.IndependentVariable] + | pybamm.IndependentVariable, ) -> None: super().__init__(child, integration_variable) self.name = name @@ -304,7 +303,7 @@ def r_average(symbol: pybamm.Symbol) -> pybamm.Symbol: def size_average( - symbol: pybamm.Symbol, f_a_dist: Optional[pybamm.Symbol] = None + symbol: pybamm.Symbol, f_a_dist: pybamm.Symbol | None = None ) -> pybamm.Symbol: """Convenience function for averaging over particle size R using the area-weighted particle-size distribution. @@ -359,7 +358,7 @@ def size_average( def _sum_of_averages( - symbol: Union[pybamm.Addition, pybamm.Subtraction], + symbol: pybamm.Addition | pybamm.Subtraction, average_function: Callable[[pybamm.Symbol], pybamm.Symbol], ): if isinstance(symbol, pybamm.Addition): diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index 1c181d582e..19b9ee809e 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -11,7 +11,7 @@ import pybamm from pybamm.util import have_optional_dependency -from typing import Union, Optional, Callable +from typing import Union, Callable # create type alias(s) ChildValue = Union[float, np.ndarray, pybamm.Symbol] @@ -138,10 +138,10 @@ def _binary_new_copy(self, left: ChildValue, right: ChildValue): def evaluate( self, - t: Optional[float] = None, - y: Optional[np.ndarray] = None, - y_dot: Optional[np.ndarray] = None, - inputs: Optional[Union[dict, str]] = None, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" left = self.left.evaluate(t, y, y_dot, inputs) @@ -239,8 +239,8 @@ def _binary_jac(self, left_jac, right_jac): def _binary_evaluate( self, - left: Union[float, np.ndarray, pybamm.Symbol], - right: Union[float, np.ndarray, pybamm.Symbol], + left: ChildValue, + right: ChildValue, ): """See :meth:`pybamm.BinaryOperator._binary_evaluate()`.""" # don't raise RuntimeWarning for NaNs @@ -265,15 +265,11 @@ def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" return self.left.diff(variable) + self.right.diff(variable) - def _binary_jac( - self, left_jac: Union[float, np.ndarray], right_jac: Union[float, np.ndarray] - ): + def _binary_jac(self, left_jac: float | np.ndarray, right_jac: float | np.ndarray): """See :meth:`pybamm.BinaryOperator._binary_jac()`.""" return left_jac + right_jac - def _binary_evaluate( - self, left: Union[float, np.ndarray], right: Union[float, np.ndarray] - ): + def _binary_evaluate(self, left: float | np.ndarray, right: float | np.ndarray): """See :meth:`pybamm.BinaryOperator._binary_evaluate()`.""" return left + right @@ -300,9 +296,7 @@ def _binary_jac(self, left_jac, right_jac): """See :meth:`pybamm.BinaryOperator._binary_jac()`.""" return left_jac - right_jac - def _binary_evaluate( - self, left: Union[float, np.ndarray], right: Union[float, np.ndarray] - ): + def _binary_evaluate(self, left: float | np.ndarray, right: float | np.ndarray): """See :meth:`pybamm.BinaryOperator._binary_evaluate()`.""" return left - right @@ -829,7 +823,7 @@ def _simplified_binary_broadcast_concatenation( left: pybamm.Symbol, right: pybamm.Symbol, operator: Callable, -) -> Union[None, pybamm.Broadcast]: +) -> pybamm.Broadcast | None: """ Check if there are concatenations or broadcasts that we can commute the operator with @@ -1478,7 +1472,7 @@ def sigmoid( def source( - left: Union[numbers.Number, pybamm.Symbol], + left: numbers.Number | pybamm.Symbol, right: pybamm.Symbol, boundary=False, ): diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index e78c1d07d7..7d66e1a587 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -6,7 +6,7 @@ import numpy as np from scipy.sparse import csr_matrix -from typing import Optional, Union, Type, SupportsFloat +from typing import Type, SupportsFloat import pybamm @@ -37,7 +37,7 @@ def __init__( self, child: pybamm.Symbol, domains: dict[str, list[str]], - name: Optional[str] = None, + name: str | None = None, ): if name is None: name = "broadcast" @@ -96,9 +96,9 @@ class PrimaryBroadcast(Broadcast): def __init__( self, - child: Union[numbers.Number, pybamm.Symbol], - broadcast_domain: Union[str, list[str]], - name: Optional[str] = None, + child: numbers.Number | pybamm.Symbol, + broadcast_domain: list[str] | str, + name: str | None = None, ): # Convert child to scalar if it is a number if isinstance(child, numbers.Number): @@ -190,9 +190,9 @@ class PrimaryBroadcastToEdges(PrimaryBroadcast): def __init__( self, - child: Union[numbers.Number, pybamm.Symbol], - broadcast_domain: Union[str, list[str]], - name: Optional[str] = None, + child: numbers.Number | pybamm.Symbol, + broadcast_domain: list[str] | str, + name: str | None = None, ): name = name or "broadcast to edges" super().__init__(child, broadcast_domain, name) @@ -227,8 +227,8 @@ class SecondaryBroadcast(Broadcast): def __init__( self, child: pybamm.Symbol, - broadcast_domain: Union[str, list[str]], - name: Optional[str] = None, + broadcast_domain: list[str] | str, + name: str | None = None, ): # Convert domain to list if it's a string if isinstance(broadcast_domain, str): @@ -325,8 +325,8 @@ class SecondaryBroadcastToEdges(SecondaryBroadcast): def __init__( self, child: pybamm.Symbol, - broadcast_domain: Union[list[str], str], - name: Optional[str] = None, + broadcast_domain: list[str] | str, + name: str | None = None, ): name = name or "broadcast to edges" super().__init__(child, broadcast_domain, name) @@ -361,8 +361,8 @@ class TertiaryBroadcast(Broadcast): def __init__( self, child: pybamm.Symbol, - broadcast_domain: Union[list[str], str], - name: Optional[str] = None, + broadcast_domain: list[str] | str, + name: str | None = None, ): # Convert domain to list if it's a string if isinstance(broadcast_domain, str): @@ -374,7 +374,7 @@ def __init__( super().__init__(child, domains, name=name) def check_and_set_domains( - self, child: pybamm.Symbol, broadcast_domain: Union[list[str], str] + self, child: pybamm.Symbol, broadcast_domain: list[str] | str ): """See :meth:`Broadcast.check_and_set_domains`""" if child.domains["secondary"] == []: @@ -446,8 +446,8 @@ class TertiaryBroadcastToEdges(TertiaryBroadcast): def __init__( self, child: pybamm.Symbol, - broadcast_domain: Union[list[str], str], - name: Optional[str] = None, + broadcast_domain: list[str] | str, + name: str | None = None, ): name = name or "broadcast to edges" super().__init__(child, broadcast_domain, name) @@ -462,11 +462,11 @@ class FullBroadcast(Broadcast): def __init__( self, - child: Union[NumberType, pybamm.Symbol], - broadcast_domain: Optional[Union[list[str], str]] = None, - auxiliary_domains: Optional[Union[str, dict]] = None, - broadcast_domains: Optional[dict] = None, - name: Optional[str] = None, + child: NumberType | pybamm.Symbol, + broadcast_domain: list[str] | str | None = None, + auxiliary_domains: str | dict | None = None, + broadcast_domains: dict | None = None, + name: str | None = None, ): # Convert child to scalar if it is a number if isinstance(child, numbers.Number): @@ -534,11 +534,11 @@ class FullBroadcastToEdges(FullBroadcast): def __init__( self, - child: Union[NumberType, pybamm.Symbol], - broadcast_domain: Optional[Union[list[str], str]] = None, - auxiliary_domains: Optional[Union[str, dict]] = None, - broadcast_domains: Optional[dict] = None, - name: Optional[str] = None, + child: NumberType | pybamm.Symbol, + broadcast_domain: list[str] | str | None = None, + auxiliary_domains: str | dict | None = None, + broadcast_domains: dict | None = None, + name: str | None = None, ): name = name or "broadcast to edges" super().__init__( diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 26b44c57ae..64f8d898f1 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -7,7 +7,7 @@ import numpy as np from scipy.sparse import issparse, vstack -from typing import Optional, Sequence, Type, Union +from typing import Sequence, Type from typing_extensions import TypeGuard from pybamm.hints import S @@ -28,7 +28,7 @@ class Concatenation(pybamm.Symbol): def __init__( self, *children: pybamm.Symbol, - name: Optional[str] = None, + name: str | None = None, check_domain=True, concat_fun=None, ): @@ -122,10 +122,10 @@ def _concatenation_evaluate(self, children_eval: list[np.ndarray]): def evaluate( self, - t: Optional[float] = None, - y: Optional[np.ndarray] = None, - y_dot: Optional[np.ndarray] = None, - inputs: Optional[Union[dict, str]] = None, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" children = self.children @@ -263,7 +263,7 @@ def __init__( self, children: Sequence[pybamm.Symbol], full_mesh: pybamm.Mesh, - copy_this: Optional[pybamm.DomainConcatenation] = None, + copy_this: pybamm.DomainConcatenation | None = None, ): # Convert any constant symbols in children to a Vector of the right size for # concatenation @@ -561,7 +561,7 @@ def numpy_concatenation(*children): def simplified_domain_concatenation( children: list[pybamm.Symbol], mesh: pybamm.Mesh, - copy_this: Optional[DomainConcatenation] = None, + copy_this: DomainConcatenation | None = None, ): """Perform simplifications on a domain concatenation.""" # Create the DomainConcatenation to read domain and child domain diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 0945a634f6..c4f69c9c9e 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -6,7 +6,7 @@ import numpy as np from scipy import special -from typing import Optional, Sequence, Callable, Type, Union +from typing import Sequence, Callable, Type from typing_extensions import TypeVar import pybamm @@ -36,9 +36,9 @@ def __init__( self, function: Callable, *children: pybamm.Symbol, - name: Optional[str] = None, - derivative: Optional[str] = "autograd", - differentiated_function: Optional[Callable] = None, + name: str | None = None, + derivative: str | None = "autograd", + differentiated_function: Callable | None = None, ): # Turn numbers into scalars children = list(children) @@ -146,10 +146,10 @@ def _function_jac(self, children_jacs): def evaluate( self, - t: Optional[float] = None, - y: Optional[np.ndarray] = None, - y_dot: Optional[np.ndarray] = None, - inputs: Optional[Union[dict, str]] = None, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" evaluated_children = [ diff --git a/pybamm/expression_tree/independent_variable.py b/pybamm/expression_tree/independent_variable.py index b898986c1c..6cc2dadc9e 100644 --- a/pybamm/expression_tree/independent_variable.py +++ b/pybamm/expression_tree/independent_variable.py @@ -3,7 +3,7 @@ # from __future__ import annotations import sympy -from typing import Union, Optional, Any +from typing import Any import pybamm from pybamm.util import have_optional_dependency @@ -35,9 +35,9 @@ class IndependentVariable(pybamm.Symbol): def __init__( self, name: str, - domain: Optional[Union[list[str], str]] = None, - auxiliary_domains: Optional[dict] = None, - domains: Optional[dict] = None, + domain: list[str] | str | None = None, + auxiliary_domains: dict | None = None, + domains: dict | None = None, ) -> None: super().__init__( name, domain=domain, auxiliary_domains=auxiliary_domains, domains=domains @@ -82,7 +82,7 @@ def create_copy(self): def _base_evaluate( self, - t: Optional[float] = None, + t: float | None = None, y: Any = None, y_dot: Any = None, inputs: Any = None, @@ -128,9 +128,9 @@ class SpatialVariable(IndependentVariable): def __init__( self, name: str, - domain: Optional[Union[list[str], str]] = None, - auxiliary_domains: Optional[dict] = None, - domains: Optional[dict] = None, + domain: list[str] | str | None = None, + auxiliary_domains: dict | None = None, + domains: dict | None = None, coord_sys=None, ) -> None: self.coord_sys = coord_sys @@ -191,9 +191,9 @@ class SpatialVariableEdge(SpatialVariable): def __init__( self, name: str, - domain: Union[list[str], str, None] = None, - auxiliary_domains: Optional[dict] = None, - domains: Optional[dict] = None, + domain: list[str] | str | None = None, + auxiliary_domains: dict | None = None, + domains: dict | None = None, coord_sys=None, ) -> None: super().__init__(name, domain, auxiliary_domains, domains, coord_sys) diff --git a/pybamm/expression_tree/input_parameter.py b/pybamm/expression_tree/input_parameter.py index b2affc1242..dc8f3b6d3c 100644 --- a/pybamm/expression_tree/input_parameter.py +++ b/pybamm/expression_tree/input_parameter.py @@ -7,8 +7,6 @@ import scipy.sparse import pybamm -from typing import Union, Optional - class InputParameter(pybamm.Symbol): """ @@ -31,8 +29,8 @@ class InputParameter(pybamm.Symbol): def __init__( self, name: str, - domain: Optional[Union[list[str], str]] = None, - expected_size: Optional[int] = None, + domain: list[str] | str | None = None, + expected_size: int | None = None, ) -> None: # Expected size defaults to 1 if no domain else None (gets set later) if expected_size is None: @@ -83,10 +81,10 @@ def _jac(self, variable: pybamm.StateVector) -> pybamm.Matrix: def _base_evaluate( self, - t: Optional[float] = None, - y: Optional[np.ndarray] = None, - y_dot: Optional[np.ndarray] = None, - inputs: Optional[Union[dict, str]] = None, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, ): # inputs should be a dictionary # convert 'None' to empty dictionary for more informative error diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index 163cbacb07..c8ba8f1baa 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -4,7 +4,7 @@ from __future__ import annotations import numpy as np from scipy import interpolate -from typing import Optional, Sequence, Union +from typing import Sequence import warnings import pybamm @@ -43,13 +43,13 @@ class Interpolant(pybamm.Function): def __init__( self, - x: Union[np.ndarray, Sequence[np.ndarray]], + x: np.ndarray | Sequence[np.ndarray], y: np.ndarray, - children: Union[Sequence[pybamm.Symbol], pybamm.Time], - name: Optional[str] = None, - interpolator: Optional[str] = "linear", - extrapolate: Optional[bool] = True, - entries_string: Optional[str] = None, + children: Sequence[pybamm.Symbol] | pybamm.Time, + name: str | None = None, + interpolator: str | None = "linear", + extrapolate: bool = True, + entries_string: str | None = None, ): # "cubic spline" has been renamed to "cubic" if interpolator == "cubic spline": diff --git a/pybamm/expression_tree/matrix.py b/pybamm/expression_tree/matrix.py index a65b28fc39..7abdc15934 100644 --- a/pybamm/expression_tree/matrix.py +++ b/pybamm/expression_tree/matrix.py @@ -4,7 +4,6 @@ from __future__ import annotations import numpy as np from scipy.sparse import csr_matrix, issparse -from typing import Union, Optional import pybamm @@ -16,12 +15,12 @@ class Matrix(pybamm.Array): def __init__( self, - entries: Union[np.ndarray, list, csr_matrix], - name: Optional[str] = None, - domain: Optional[list[str]] = None, - auxiliary_domains: Optional[dict[str, str]] = None, - domains: Optional[dict] = None, - entries_string: Optional[str] = None, + entries: np.ndarray | list | csr_matrix, + name: str | None = None, + domain: list[str] | None = None, + auxiliary_domains: dict[str, str] | None = None, + domains: dict | None = None, + entries_string: str | None = None, ) -> None: if isinstance(entries, list): entries = np.array(entries) diff --git a/pybamm/expression_tree/operations/convert_to_casadi.py b/pybamm/expression_tree/operations/convert_to_casadi.py index 3caf18bb4d..4adcef988e 100644 --- a/pybamm/expression_tree/operations/convert_to_casadi.py +++ b/pybamm/expression_tree/operations/convert_to_casadi.py @@ -1,11 +1,12 @@ # # Convert a PyBaMM expression tree to a CasADi expression tree # +from __future__ import annotations + import pybamm import casadi import numpy as np from scipy import special -from typing import Union class CasadiConverter(object): @@ -20,7 +21,7 @@ def convert( t: casadi.MX, y: casadi.MX, y_dot: casadi.MX, - inputs: Union[dict, None], + inputs: dict | None, ) -> casadi.MX: """ This function recurses down the tree, converting the PyBaMM expression tree to diff --git a/pybamm/expression_tree/operations/jacobian.py b/pybamm/expression_tree/operations/jacobian.py index c536687ed2..29521d8a14 100644 --- a/pybamm/expression_tree/operations/jacobian.py +++ b/pybamm/expression_tree/operations/jacobian.py @@ -2,7 +2,6 @@ # Calculate the Jacobian of a symbol # from __future__ import annotations -from typing import Optional import pybamm @@ -22,7 +21,7 @@ class Jacobian(object): def __init__( self, - known_jacs: Optional[dict[pybamm.Symbol, pybamm.Symbol]] = None, + known_jacs: dict[pybamm.Symbol, pybamm.Symbol] | None = None, clear_domain: bool = True, ): self._known_jacs = known_jacs or {} @@ -75,7 +74,7 @@ def _jac(self, symbol: pybamm.Symbol, variable: pybamm.Symbol): jac = symbol._unary_jac(child_jac) elif isinstance(symbol, pybamm.Function): - children_jacs: list[None, pybamm.Symbol] = [None] * len(symbol.children) + children_jacs: list[None | pybamm.Symbol] = [None] * len(symbol.children) for i, child in enumerate(symbol.children): children_jacs[i] = self.jac(child, variable) # _function_jac defined in function class diff --git a/pybamm/expression_tree/operations/latexify.py b/pybamm/expression_tree/operations/latexify.py index 2c9676087a..c16ab4b83d 100644 --- a/pybamm/expression_tree/operations/latexify.py +++ b/pybamm/expression_tree/operations/latexify.py @@ -1,12 +1,12 @@ # # Latexify class # +from __future__ import annotations + import copy import re import warnings -from typing import Optional - import pybamm from pybamm.expression_tree.printing.sympy_overrides import custom_print_func from pybamm.util import have_optional_dependency @@ -50,9 +50,7 @@ class Latexify: >>> model.latexify(newline=False)[1:5] """ - def __init__( - self, model, filename: Optional[str] = None, newline: Optional[bool] = True - ): + def __init__(self, model, filename: str | None = None, newline: bool = True): self.model = model self.filename = filename self.newline = newline diff --git a/pybamm/expression_tree/operations/serialise.py b/pybamm/expression_tree/operations/serialise.py index 4e79685325..8afadbdd0b 100644 --- a/pybamm/expression_tree/operations/serialise.py +++ b/pybamm/expression_tree/operations/serialise.py @@ -7,8 +7,6 @@ import numpy as np import re -from typing import Optional - class Serialise: """ @@ -80,9 +78,9 @@ class _EmptyDict(dict): def save_model( self, model: pybamm.BaseModel, - mesh: Optional[pybamm.Mesh] = None, - variables: Optional[pybamm.FuzzyDict] = None, - filename: Optional[str] = None, + mesh: pybamm.Mesh | None = None, + variables: pybamm.FuzzyDict | None = None, + filename: str | None = None, ): """Saves a discretised model to a JSON file. @@ -144,7 +142,7 @@ def save_model( json.dump(model_json, f) def load_model( - self, filename: str, battery_model: Optional[pybamm.BaseModel] = None + self, filename: str, battery_model: pybamm.BaseModel | None = None ) -> pybamm.BaseModel: """ Loads a discretised, ready to solve model into PyBaMM. diff --git a/pybamm/expression_tree/operations/unpack_symbols.py b/pybamm/expression_tree/operations/unpack_symbols.py index 70c045f072..c2fb10aad4 100644 --- a/pybamm/expression_tree/operations/unpack_symbols.py +++ b/pybamm/expression_tree/operations/unpack_symbols.py @@ -2,7 +2,7 @@ # Helper function to unpack a symbol # from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Union, Sequence +from typing import TYPE_CHECKING, Sequence if TYPE_CHECKING: import pybamm @@ -23,8 +23,8 @@ class SymbolUnpacker(object): def __init__( self, - classes_to_find: Union[pybamm.Symbol, Sequence[pybamm.Symbol]], - unpacked_symbols: Optional[dict] = None, + classes_to_find: Sequence[pybamm.Symbol] | pybamm.Symbol, + unpacked_symbols: dict | None = None, ): self.classes_to_find = classes_to_find self._unpacked_symbols: dict = unpacked_symbols or {} @@ -53,7 +53,7 @@ def unpack_list_of_symbols( return all_instances def unpack_symbol( - self, symbol: Union[Sequence[pybamm.Symbol], pybamm.Symbol] + self, symbol: Sequence[pybamm.Symbol] | pybamm.Symbol ) -> list[pybamm.Symbol]: """ This function recurses down the tree, unpacking the symbols and saving the ones diff --git a/pybamm/expression_tree/parameter.py b/pybamm/expression_tree/parameter.py index 5e21eb6c85..fb0380cea4 100644 --- a/pybamm/expression_tree/parameter.py +++ b/pybamm/expression_tree/parameter.py @@ -6,7 +6,7 @@ import sys import numpy as np -from typing import Optional, TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal if TYPE_CHECKING: import sympy @@ -98,7 +98,7 @@ def __init__( self, name: str, inputs: dict[str, pybamm.Symbol], - diff_variable: Optional[pybamm.Symbol] = None, + diff_variable: pybamm.Symbol | None = None, print_name="calculate", ) -> None: # assign diff variable diff --git a/pybamm/expression_tree/scalar.py b/pybamm/expression_tree/scalar.py index d99b228f74..0c9d349484 100644 --- a/pybamm/expression_tree/scalar.py +++ b/pybamm/expression_tree/scalar.py @@ -4,7 +4,7 @@ from __future__ import annotations import numbers import numpy as np -from typing import Optional, Literal, Union, Any +from typing import Literal, Any import pybamm from pybamm.util import have_optional_dependency @@ -26,8 +26,8 @@ class Scalar(pybamm.Symbol): def __init__( self, - value: Union[float, numbers.Number, np.bool_], - name: Optional[str] = None, + value: float | numbers.Number | np.bool_, + name: str | None = None, ) -> None: # set default name if not provided self.value = value diff --git a/pybamm/expression_tree/state_vector.py b/pybamm/expression_tree/state_vector.py index 8cd23b9f52..72ea09a776 100644 --- a/pybamm/expression_tree/state_vector.py +++ b/pybamm/expression_tree/state_vector.py @@ -4,7 +4,7 @@ from __future__ import annotations import numpy as np from scipy.sparse import csr_matrix, vstack -from typing import Optional, Union, Any +from typing import Any import pybamm @@ -38,11 +38,11 @@ def __init__( self, *y_slices: slice, base_name="y", - name: Optional[str] = None, - domain: Optional[Union[list[str], str]] = None, - auxiliary_domains: Optional[dict] = None, - domains: Optional[dict[str, list[str]]] = None, - evaluation_array: Optional[list] = None, + name: str | None = None, + domain: list[str] | str | None = None, + auxiliary_domains: dict | None = None, + domains: dict[str, list[str]] | None = None, + evaluation_array: list | None = None, ): for y_slice in y_slices: if not isinstance(y_slice, slice): @@ -262,11 +262,11 @@ class StateVector(StateVectorBase): def __init__( self, *y_slices: slice, - name: Optional[str] = None, - domain: Optional[Union[list[str], str]] = None, - auxiliary_domains: Optional[dict] = None, - domains: Optional[dict[str, list[str]]] = None, - evaluation_array: Optional[list] = None, + name: str | None = None, + domain: list[str] | str | None = None, + auxiliary_domains: dict | None = None, + domains: dict[str, list[str]] | None = None, + evaluation_array: list | None = None, ): super().__init__( *y_slices, @@ -281,7 +281,7 @@ def __init__( def _base_evaluate( self, t: Any = None, - y: Optional[np.ndarray] = None, + y: np.ndarray | None = None, y_dot: Any = None, inputs: Any = None, ): @@ -311,7 +311,7 @@ def diff(self, variable: pybamm.Symbol): else: return pybamm.Scalar(0) - def _jac(self, variable: Union[pybamm.StateVector, pybamm.StateVectorDot]): + def _jac(self, variable: pybamm.StateVector | pybamm.StateVectorDot): if isinstance(variable, pybamm.StateVector): return self._jac_same_vector(variable) elif isinstance(variable, pybamm.StateVectorDot): @@ -346,11 +346,11 @@ class StateVectorDot(StateVectorBase): def __init__( self, *y_slices: slice, - name: Optional[str] = None, - domain: Optional[Union[list[str], str]] = None, - auxiliary_domains: Optional[dict] = None, - domains: Optional[dict[str, list[str]]] = None, - evaluation_array: Optional[list] = None, + name: str | None = None, + domain: list[str] | str | None = None, + auxiliary_domains: dict | None = None, + domains: dict[str, list[str]] | None = None, + evaluation_array: list | None = None, ): super().__init__( *y_slices, @@ -366,7 +366,7 @@ def _base_evaluate( self, t: Any = None, y: Any = None, - y_dot: Optional[np.ndarray] = None, + y_dot: np.ndarray | None = None, inputs: Any = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" @@ -392,7 +392,7 @@ def diff(self, variable: pybamm.Symbol): else: return pybamm.Scalar(0) - def _jac(self, variable: Union[pybamm.StateVector, pybamm.StateVectorDot]): + def _jac(self, variable: pybamm.StateVector | pybamm.StateVectorDot): if isinstance(variable, pybamm.StateVectorDot): return self._jac_same_vector(variable) elif isinstance(variable, pybamm.StateVector): diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 6dfe9825bf..e242678bd1 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -8,9 +8,7 @@ from scipy.sparse import csr_matrix, issparse from functools import lru_cache, cached_property from typing import ( - Union, TYPE_CHECKING, - Optional, Sequence, ) @@ -32,7 +30,7 @@ EMPTY_DOMAINS: dict[str, list] = {k: [] for k in DOMAIN_LEVELS} -def domain_size(domain: Union[list[str], str]): +def domain_size(domain: list[str] | str): """ Get the domain size. @@ -218,10 +216,10 @@ class Symbol: def __init__( self, name: str, - children: Optional[Sequence[Symbol]] = None, - domain: Optional[Union[list[str], str]] = None, - auxiliary_domains: Optional[dict[str, str]] = None, - domains: Optional[dict[str, list[str]]] = None, + children: Sequence[Symbol] | None = None, + domain: list[str] | str | None = None, + auxiliary_domains: dict[str, str] | None = None, + domains: dict[str, list[str]] | None = None, ): super(Symbol, self).__init__() self.name = name @@ -408,9 +406,9 @@ def get_children_domains(self, children: Sequence[Symbol]): def read_domain_or_domains( self, - domain: Optional[Union[list[str], str]], - auxiliary_domains: Optional[dict[str, str]], - domains: Optional[dict], + domain: list[str] | str | None, + auxiliary_domains: dict[str, str] | None, + domains: dict | None, ): if domains is None: if isinstance(domain, str): @@ -581,51 +579,51 @@ def __repr__(self): {k: v for k, v in self.domains.items() if v != []}, ) - def __add__(self, other: Union[float, np.ndarray, Symbol]) -> Addition: + def __add__(self, other: Symbol | float | np.ndarray) -> Addition: """return an :class:`Addition` object.""" return pybamm.add(self, other) - def __radd__(self, other: Union[float, np.ndarray, Symbol]) -> Addition: + def __radd__(self, other: Symbol | float | np.ndarray) -> Addition: """return an :class:`Addition` object.""" return pybamm.add(other, self) - def __sub__(self, other: Union[float, np.ndarray, Symbol]) -> Subtraction: + def __sub__(self, other: Symbol | float | np.ndarray) -> Subtraction: """return a :class:`Subtraction` object.""" return pybamm.subtract(self, other) - def __rsub__(self, other: Union[float, np.ndarray, Symbol]) -> Subtraction: + def __rsub__(self, other: Symbol | float | np.ndarray) -> Subtraction: """return a :class:`Subtraction` object.""" return pybamm.subtract(other, self) - def __mul__(self, other: Union[float, np.ndarray, Symbol]) -> Multiplication: + def __mul__(self, other: Symbol | float | np.ndarray) -> Multiplication: """return a :class:`Multiplication` object.""" return pybamm.multiply(self, other) - def __rmul__(self, other: Union[float, np.ndarray, Symbol]) -> Multiplication: + def __rmul__(self, other: Symbol | float | np.ndarray) -> Multiplication: """return a :class:`Multiplication` object.""" return pybamm.multiply(other, self) def __matmul__( - self, other: Union[float, np.ndarray, Symbol] + self, other: Symbol | float | np.ndarray ) -> pybamm.MatrixMultiplication: """return a :class:`MatrixMultiplication` object.""" return pybamm.matmul(self, other) def __rmatmul__( - self, other: Union[float, np.ndarray, Symbol] + self, other: Symbol | float | np.ndarray ) -> pybamm.MatrixMultiplication: """return a :class:`MatrixMultiplication` object.""" return pybamm.matmul(other, self) - def __truediv__(self, other: Union[float, np.ndarray, Symbol]) -> Division: + def __truediv__(self, other: Symbol | float | np.ndarray) -> Division: """return a :class:`Division` object.""" return pybamm.divide(self, other) - def __rtruediv__(self, other: Union[float, np.ndarray, Symbol]) -> Division: + def __rtruediv__(self, other: Symbol | float | np.ndarray) -> Division: """return a :class:`Division` object.""" return pybamm.divide(other, self) - def __pow__(self, other: Union[float, np.ndarray, Symbol]) -> pybamm.Power: + def __pow__(self, other: Symbol | float | np.ndarray) -> pybamm.Power: """return a :class:`Power` object.""" return pybamm.simplified_power(self, other) @@ -633,7 +631,7 @@ def __rpow__(self, other: Symbol) -> pybamm.Power: """return a :class:`Power` object.""" return pybamm.simplified_power(other, self) - def __lt__(self, other: Union[Symbol, float]) -> pybamm.NotEqualHeaviside: + def __lt__(self, other: Symbol | float) -> pybamm.NotEqualHeaviside: """return a :class:`NotEqualHeaviside` object, or a smooth approximation.""" return pybamm.expression_tree.binary_operators._heaviside(self, other, False) @@ -734,7 +732,7 @@ def _diff(self, variable): def jac( self, variable: pybamm.Symbol, - known_jacs: Optional[dict[pybamm.Symbol, pybamm.Symbol]] = None, + known_jacs: dict[pybamm.Symbol, pybamm.Symbol] | None = None, clear_domain=True, ): """ @@ -759,10 +757,10 @@ def _jac(self, variable): def _base_evaluate( self, - t: Optional[float] = None, - y: Optional[np.ndarray] = None, - y_dot: Optional[np.ndarray] = None, - inputs: Optional[Union[dict, str]] = None, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, ): """ evaluate expression tree. @@ -791,11 +789,11 @@ def _base_evaluate( def evaluate( self, - t: Optional[float] = None, - y: Optional[np.ndarray] = None, - y_dot: Optional[np.ndarray] = None, - inputs: Optional[Union[dict, str]] = None, - ) -> Union[float, np.ndarray]: + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, + ) -> float | np.ndarray: """Evaluate expression tree (wrapper to allow using dict of known values). Parameters @@ -847,7 +845,7 @@ def is_constant(self): # Default behaviour is False return False - def evaluate_ignoring_errors(self, t: Union[None, float] = 0): # none + def evaluate_ignoring_errors(self, t: float | None = 0): # none """ Evaluates the expression. If a node exists in the tree that cannot be evaluated as a scalar or vector (e.g. Time, Parameter, Variable, StateVector), then None @@ -926,9 +924,7 @@ def _evaluates_on_edges(self, dimension): # Default behaviour: return False return False - def has_symbol_of_classes( - self, symbol_classes: Union[type[Symbol], tuple[type[Symbol]]] - ): + def has_symbol_of_classes(self, symbol_classes: tuple[type[Symbol]] | type[Symbol]): """ Returns True if equation has a term of the class(es) `symbol_class`. @@ -941,11 +937,11 @@ def has_symbol_of_classes( def to_casadi( self, - t: Optional[casadi.MX] = None, - y: Optional[casadi.MX] = None, - y_dot: Optional[casadi.MX] = None, - inputs: Optional[dict] = None, - casadi_symbols: Optional[Symbol] = None, + t: casadi.MX | None = None, + y: casadi.MX | None = None, + y_dot: casadi.MX | None = None, + inputs: dict | None = None, + casadi_symbols: Symbol | None = None, ): """ Convert the expression tree to a CasADi expression tree. diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index aa9de2e242..39f4f54ed4 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -3,7 +3,6 @@ # from __future__ import annotations import numbers -from typing import Optional, Union import numpy as np from scipy.sparse import csr_matrix, issparse @@ -26,7 +25,7 @@ class UnaryOperator(pybamm.Symbol): child node """ - def __init__(self, name: str, child: pybamm.Symbol, domains: Optional[dict] = None): + def __init__(self, name: str, child: pybamm.Symbol, domains: dict | None = None): if isinstance(child, numbers.Number): child = pybamm.Scalar(child) domains = domains or child.domains @@ -74,10 +73,10 @@ def _unary_evaluate(self, child): def evaluate( self, - t: Optional[float] = None, - y: Optional[np.ndarray] = None, - y_dot: Optional[np.ndarray] = None, - inputs: Optional[dict] = None, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | None = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" child = self.child.evaluate(t, y, y_dot, inputs) @@ -396,7 +395,7 @@ class with a :class:`Matrix` child node """ - def __init__(self, name: str, child: pybamm.Symbol, domains: Optional[dict] = None): + def __init__(self, name: str, child: pybamm.Symbol, domains: dict | None = None): super().__init__(name, child, domains) def to_json(self): @@ -561,9 +560,8 @@ class Integral(SpatialOperator): def __init__( self, child, - integration_variable: Union[ - list[pybamm.IndependentVariable], pybamm.IndependentVariable - ], + integration_variable: list[pybamm.IndependentVariable] + | pybamm.IndependentVariable, ): if not isinstance(integration_variable, list): integration_variable = [integration_variable] diff --git a/pybamm/expression_tree/variable.py b/pybamm/expression_tree/variable.py index e57c6f3af2..22992c3b63 100644 --- a/pybamm/expression_tree/variable.py +++ b/pybamm/expression_tree/variable.py @@ -6,7 +6,6 @@ import numbers import pybamm from pybamm.util import have_optional_dependency -from typing import Union, Optional class VariableBase(pybamm.Symbol): @@ -51,13 +50,13 @@ class VariableBase(pybamm.Symbol): def __init__( self, name: str, - domain: Optional[Union[list[str], str]] = None, - auxiliary_domains: Optional[dict] = None, - domains: Optional[dict] = None, - bounds: Optional[tuple[pybamm.Symbol]] = None, - print_name: Optional[str] = None, - scale: Optional[Union[float, pybamm.Symbol]] = 1, - reference: Optional[Union[float, pybamm.Symbol]] = 0, + domain: list[str] | str | None = None, + auxiliary_domains: dict | None = None, + domains: dict | None = None, + bounds: tuple[pybamm.Symbol] | None = None, + print_name: str | None = None, + scale: float | pybamm.Symbol | None = 1, + reference: float | pybamm.Symbol | None = 0, ): if isinstance(scale, numbers.Number): scale = pybamm.Scalar(scale) @@ -104,7 +103,13 @@ def bounds(self, values: tuple[numbers.Number, numbers.Number]): def set_id(self): self._id = hash( - (self.__class__, self.name, self.scale, self.reference, *tuple([(k, tuple(v)) for k, v in self.domains.items() if v != []])) + ( + self.__class__, + self.name, + self.scale, + self.reference, + *tuple([(k, tuple(v)) for k, v in self.domains.items() if v != []]), + ) ) def create_copy(self): diff --git a/pybamm/expression_tree/vector.py b/pybamm/expression_tree/vector.py index 4070b97f6f..3d2981e808 100644 --- a/pybamm/expression_tree/vector.py +++ b/pybamm/expression_tree/vector.py @@ -3,7 +3,6 @@ # from __future__ import annotations import numpy as np -from typing import Union, Optional import pybamm @@ -15,12 +14,12 @@ class Vector(pybamm.Array): def __init__( self, - entries: Union[np.ndarray, list, np.matrix], - name: Optional[str] = None, - domain: Optional[Union[list[str], str]] = None, - auxiliary_domains: Optional[dict[str, str]] = None, - domains: Optional[dict] = None, - entries_string: Optional[str] = None, + entries: np.ndarray | list | np.matrix, + name: str | None = None, + domain: list[str] | str | None = None, + auxiliary_domains: dict[str, str] | None = None, + domains: dict | None = None, + entries_string: str | None = None, ) -> None: if isinstance(entries, (list, np.matrix)): entries = np.array(entries) diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index 566716e011..3a92e69564 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -1,6 +1,8 @@ # # Base model class # +from __future__ import annotations + import numbers import warnings from collections import OrderedDict @@ -13,8 +15,6 @@ from pybamm.expression_tree.operations.serialise import Serialise from pybamm.util import have_optional_dependency -from typing import Optional - class BaseModel: """ @@ -1219,7 +1219,7 @@ def save_model(self, filename=None, mesh=None, variables=None): Serialise().save_model(self, filename=filename, mesh=mesh, variables=variables) -def load_model(filename, battery_model: Optional[BaseModel] = None): +def load_model(filename, battery_model: BaseModel | None = None): """ Load in a saved model from a JSON file diff --git a/pybamm/models/event.py b/pybamm/models/event.py index 6214a0fb04..ee4115157d 100644 --- a/pybamm/models/event.py +++ b/pybamm/models/event.py @@ -1,7 +1,9 @@ +from __future__ import annotations + from enum import Enum import numpy as np -from typing import Optional, TypeVar, Type +from typing import TypeVar, Type class EventType(Enum): @@ -72,10 +74,10 @@ def _from_json(cls: Type[E], snippet: dict) -> E: def evaluate( self, - t: Optional[float] = None, - y: Optional[np.ndarray] = None, - y_dot: Optional[np.ndarray] = None, - inputs: Optional[dict] = None, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | None = None, ): """ Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate` diff --git a/pybamm/simulation.py b/pybamm/simulation.py index bf418f068d..678462dfc5 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -1,6 +1,8 @@ # # Simulation class # +from __future__ import annotations + import pickle import pybamm import numpy as np @@ -10,7 +12,6 @@ from functools import lru_cache from datetime import timedelta from pybamm.util import have_optional_dependency -from typing import Optional from pybamm.expression_tree.operations.serialise import Serialise @@ -1195,7 +1196,7 @@ def save(self, filename): def save_model( self, - filename: Optional[str] = None, + filename: str | None = None, mesh: bool = False, variables: bool = False, ): From e9a77a8ab6ecc3abc29ab23301adc1e05bac1a41 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 12 Jan 2024 14:15:41 +0000 Subject: [PATCH 17/32] Mypy passes with `mypy pybamm` command --- mypy.ini | 36 ++----------------- pybamm/expression_tree/broadcasts.py | 10 +++--- pybamm/expression_tree/concatenations.py | 7 ++-- pybamm/expression_tree/functions.py | 4 +-- pybamm/expression_tree/interpolant.py | 8 ++--- pybamm/expression_tree/operations/jacobian.py | 2 +- .../expression_tree/operations/serialise.py | 6 ++-- pybamm/expression_tree/symbol.py | 8 ++++- pybamm/expression_tree/unary_operators.py | 2 +- 9 files changed, 30 insertions(+), 53 deletions(-) diff --git a/mypy.ini b/mypy.ini index f0f32a1648..5c97998a83 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,37 +1,7 @@ [mypy] - -[mypy-scipy.*] -ignore_missing_imports=True - -[mypy-casadi.*] -ignore_missing_imports=True - -[mypy-matplotlib.*] -ignore_missing_imports=True - -[mypy-pandas.*] -ignore_missing_imports=True - -[mypy-pybtex.*] -ignore_missing_imports=True - -[mypy-ipywidgets.*] -ignore_missing_imports=True - -[mypy-anytree.*] -ignore_missing_imports=True - -[mypy-pkg_resources.*] -ignore_missing_imports=True - -[mypy-tqdm.*] -ignore_missing_imports=True - -[mypy-skfem.*] -ignore_missing_imports=True - -[mypy-absl.*] -ignore_missing_imports=True +ignore_missing_imports = True +allow_redefinition = True +disable_error_code = call-overload, operator [mypy-pybamm.models.base_model.*] disable_error_code = attr-defined diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index 7d66e1a587..499b9a33b8 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -462,15 +462,17 @@ class FullBroadcast(Broadcast): def __init__( self, - child: NumberType | pybamm.Symbol, + child_input: NumberType | float | pybamm.Symbol, broadcast_domain: list[str] | str | None = None, auxiliary_domains: str | dict | None = None, broadcast_domains: dict | None = None, name: str | None = None, ): # Convert child to scalar if it is a number - if isinstance(child, numbers.Number): - child: pybamm.Scalar = pybamm.Scalar(child) # type: ignore[no-redef] + if isinstance(child_input, numbers.Number): + child: pybamm.Scalar = pybamm.Scalar(child_input) + else: + child: pybamm.Symbol = child_input # type: ignore[no-redef] if isinstance(auxiliary_domains, str): auxiliary_domains = {"secondary": auxiliary_domains} @@ -534,7 +536,7 @@ class FullBroadcastToEdges(FullBroadcast): def __init__( self, - child: NumberType | pybamm.Symbol, + child: NumberType | float | pybamm.Symbol, broadcast_domain: list[str] | str | None = None, auxiliary_domains: str | dict | None = None, broadcast_domains: dict | None = None, diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 64f8d898f1..ad2cdc5975 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -128,10 +128,7 @@ def evaluate( inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" - children = self.children - children_eval = [None] * len(children) - for idx, child in enumerate(children): - children_eval[idx] = child.evaluate(t, y, y_dot, inputs) + children_eval = [child.evaluate(t, y, y_dot, inputs) for child in self.children] return self._concatenation_evaluate(children_eval) def create_copy(self): @@ -569,7 +566,7 @@ def simplified_domain_concatenation( # Simplify Concatenation of StateVectors to a single StateVector # The sum of the evalation arrays of the StateVectors must be exactly 1 if all(isinstance(child, pybamm.StateVector) for child in children): - sv_children: list[pybamm.StateVector] = children # type narrow for mypy + sv_children: list[pybamm.StateVector] = children # type: ignore[assignment] longest_eval_array = len(sv_children[-1]._evaluation_array) eval_arrays = {} for child in sv_children: diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index c4f69c9c9e..f7683313cf 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -117,7 +117,7 @@ def _function_diff(self, children: Sequence[pybamm.Symbol], idx: float): else: # keep using "derivative" as derivative return pybamm.Function( - self.function.derivative(), + self.function.derivative(), # type: ignore[attr-defined] *children, derivative="derivative", differentiated_function=self.function, @@ -311,7 +311,7 @@ def simplified_function(func_class: Type[SF], child: pybamm.Symbol): ) return child._unary_new_copy(func_child_not_broad) else: - return pybamm.simplify_if_constant(func_class(child)) + return pybamm.simplify_if_constant(func_class(child)) # type: ignore[call-arg, arg-type] class Arcsinh(SpecificFunction): diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index c8ba8f1baa..e57797d2ef 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -208,13 +208,13 @@ def __init__( def _from_json(cls, snippet: dict): """Create an Interpolant object from JSON data""" + x1 = [] + if len(snippet["x"]) == 1: - x = [np.array(x) for x in snippet["x"]] - else: - x = tuple(np.array(x) for x in snippet["x"]) + x1 = [np.array(x) for x in snippet["x"]] return cls( - x, + x1 if x1 else tuple(np.array(x) for x in snippet["x"]), np.array(snippet["y"]), snippet["children"], name=snippet["name"], diff --git a/pybamm/expression_tree/operations/jacobian.py b/pybamm/expression_tree/operations/jacobian.py index 29521d8a14..daacd087c8 100644 --- a/pybamm/expression_tree/operations/jacobian.py +++ b/pybamm/expression_tree/operations/jacobian.py @@ -69,7 +69,7 @@ def _jac(self, symbol: pybamm.Symbol, variable: pybamm.Symbol): jac = symbol._binary_jac(left_jac, right_jac) elif isinstance(symbol, pybamm.UnaryOperator): - child_jac = self.jac(symbol.child, variable) + child_jac = self.jac(symbol.child, variable) # type: ignore[has-type] # _unary_jac defined in derived classes for specific rules jac = symbol._unary_jac(child_jac) diff --git a/pybamm/expression_tree/operations/serialise.py b/pybamm/expression_tree/operations/serialise.py index 8afadbdd0b..d8807bf8c9 100644 --- a/pybamm/expression_tree/operations/serialise.py +++ b/pybamm/expression_tree/operations/serialise.py @@ -51,10 +51,12 @@ def default(self, node: pybamm.Mesh): if isinstance(node, pybamm.Mesh): node_dict.update(node.to_json()) - node_dict["sub_meshes"] = {} + submeshes = {} for k, v in node.items(): if len(k) == 1 and "ghost cell" not in k[0]: - node_dict["sub_meshes"][k[0]] = self.default(v) + submeshes[k[0]] = self.default(v) + + node_dict["sub_meshes"] = submeshes return node_dict diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index e242678bd1..c7c4f1b9ee 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -171,6 +171,10 @@ def simplify_if_constant(symbol: S) -> S: or (isinstance(result, np.ndarray) and result.ndim == 0) or isinstance(result, np.bool_) ): + if isinstance(result, np.ndarray): + # type-narrow for Scalar + new_result = result[0] + return pybamm.Scalar(new_result) return pybamm.Scalar(result) elif isinstance(result, np.ndarray) or issparse(result): if result.ndim == 1 or result.shape[1] == 1: @@ -924,7 +928,9 @@ def _evaluates_on_edges(self, dimension): # Default behaviour: return False return False - def has_symbol_of_classes(self, symbol_classes: tuple[type[Symbol]] | type[Symbol]): + def has_symbol_of_classes( + self, symbol_classes: tuple[type[Symbol], ...] | type[Symbol] + ): """ Returns True if equation has a term of the class(es) `symbol_class`. diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 39f4f54ed4..88ca86fad0 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -76,7 +76,7 @@ def evaluate( t: float | None = None, y: np.ndarray | None = None, y_dot: np.ndarray | None = None, - inputs: dict | None = None, + inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" child = self.child.evaluate(t, y, y_dot, inputs) From 4d6fb32e499eb3de55f4a2d7295a7f72b8b8ad2e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Jan 2024 11:18:10 +0000 Subject: [PATCH 18/32] style: pre-commit fixes --- pybamm/expression_tree/broadcasts.py | 4 ++-- pybamm/expression_tree/concatenations.py | 4 ++-- pybamm/expression_tree/functions.py | 4 ++-- pybamm/models/event.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index 499b9a33b8..19943cbb17 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -592,9 +592,9 @@ def full_like(symbols: tuple[pybamm.Symbol, ...], fill_value: float) -> pybamm.S shape = sum_symbol.shape # use vector or matrix if shape[1] == 1: - array_type: Type[pybamm.Vector] = pybamm.Vector + array_type: type[pybamm.Vector] = pybamm.Vector else: - array_type: Type[pybamm.Matrix] = pybamm.Matrix # type:ignore[no-redef] + array_type: type[pybamm.Matrix] = pybamm.Matrix # type:ignore[no-redef] # return dense array, except for a matrix of zeros if shape[1] != 1 and fill_value == 0: entries = csr_matrix(shape) diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 7c93cdeee9..3b8eca9666 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -7,7 +7,7 @@ import numpy as np from scipy.sparse import issparse, vstack -from typing import Sequence, Type +from typing import Sequence from typing_extensions import TypeGuard from pybamm.hints import S @@ -596,6 +596,6 @@ def domain_concatenation(children: list[pybamm.Symbol], mesh: pybamm.Mesh): def all_children_are( children: list[pybamm.Symbol], - class_type: Type[S], + class_type: type[S], ) -> TypeGuard[list[S]]: return all(isinstance(child, class_type) for child in children) diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index dd47bfda24..1c4271c42d 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -6,7 +6,7 @@ import numpy as np from scipy import special -from typing import Sequence, Callable, Type +from typing import Sequence, Callable from typing_extensions import TypeVar import pybamm @@ -298,7 +298,7 @@ def to_json(self): SF = TypeVar("SF", bound=SpecificFunction) -def simplified_function(func_class: Type[SF], child: pybamm.Symbol): +def simplified_function(func_class: type[SF], child: pybamm.Symbol): """ Simplifications implemented before applying the function. Currently only implemented for one-child functions. diff --git a/pybamm/models/event.py b/pybamm/models/event.py index ee4115157d..8d2695160e 100644 --- a/pybamm/models/event.py +++ b/pybamm/models/event.py @@ -3,7 +3,7 @@ from enum import Enum import numpy as np -from typing import TypeVar, Type +from typing import TypeVar class EventType(Enum): @@ -55,7 +55,7 @@ def __init__(self, name, expression, event_type=EventType.TERMINATION): self._event_type = event_type @classmethod - def _from_json(cls: Type[E], snippet: dict) -> E: + def _from_json(cls: type[E], snippet: dict) -> E: """ Reconstructs an Event instance during deserialisation of a JSON file. From 64f3fc9757c70462dd4ea77540a53f5093a70360 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Tue, 16 Jan 2024 12:16:28 +0000 Subject: [PATCH 19/32] Mypy fixes after merge --- pybamm/experiment/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pybamm/experiment/experiment.py b/pybamm/experiment/experiment.py index b04281d78d..b66fe187be 100644 --- a/pybamm/experiment/experiment.py +++ b/pybamm/experiment/experiment.py @@ -73,7 +73,7 @@ def __init__( for cycle in operating_conditions: # Check types and convert to list if not isinstance(cycle, tuple): - cycle = (cycle,) + cycle = (cycle,) # type: ignore[assignment] operating_conditions_cycles.append(cycle) self.operating_conditions_cycles = operating_conditions_cycles From c9a0ff128c8d51d1b9a08f5029918d11849dfaf8 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Tue, 16 Jan 2024 17:04:43 +0000 Subject: [PATCH 20/32] Fix coverage issues --- pybamm/expression_tree/array.py | 2 +- pybamm/expression_tree/broadcasts.py | 2 +- pybamm/expression_tree/concatenations.py | 7 ---- .../operations/unpack_symbols.py | 2 +- pybamm/expression_tree/parameter.py | 2 +- pybamm/expression_tree/symbol.py | 33 +++++++------------ 6 files changed, 16 insertions(+), 32 deletions(-) diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index a5f1314e01..a433dec822 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -9,7 +9,7 @@ import pybamm from pybamm.util import have_optional_dependency -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover import sympy diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index 19943cbb17..8e22ecd487 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -59,7 +59,7 @@ def _diff(self, variable): # Differentiate the child and broadcast the result in the same way return self._unary_new_copy(self.child.diff(variable)) - def reduce_one_dimension(self): + def reduce_one_dimension(self): # pragma: no cover """Reduce the broadcast by one dimension.""" raise NotImplementedError diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 3b8eca9666..1a6d69d64d 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -592,10 +592,3 @@ def domain_concatenation(children: list[pybamm.Symbol], mesh: pybamm.Mesh): """Helper function to create domain concatenations.""" # TODO: add option to turn off simplifications return simplified_domain_concatenation(children, mesh) - - -def all_children_are( - children: list[pybamm.Symbol], - class_type: type[S], -) -> TypeGuard[list[S]]: - return all(isinstance(child, class_type) for child in children) diff --git a/pybamm/expression_tree/operations/unpack_symbols.py b/pybamm/expression_tree/operations/unpack_symbols.py index 56cff5e859..1933eada76 100644 --- a/pybamm/expression_tree/operations/unpack_symbols.py +++ b/pybamm/expression_tree/operations/unpack_symbols.py @@ -4,7 +4,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Sequence -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover import pybamm diff --git a/pybamm/expression_tree/parameter.py b/pybamm/expression_tree/parameter.py index fb0380cea4..911e0ca59f 100644 --- a/pybamm/expression_tree/parameter.py +++ b/pybamm/expression_tree/parameter.py @@ -8,7 +8,7 @@ import numpy as np from typing import TYPE_CHECKING, Literal -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover import sympy import pybamm diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index bbf67bf8d3..3320b361ad 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -7,22 +7,13 @@ import numpy as np from scipy.sparse import csr_matrix, issparse from functools import lru_cache, cached_property -from typing import ( - TYPE_CHECKING, - Sequence, -) +from typing import TYPE_CHECKING, Sequence, cast import pybamm from pybamm.util import have_optional_dependency from pybamm.expression_tree.printing.print_name import prettify_print_name -if TYPE_CHECKING: - from pybamm.expression_tree.binary_operators import ( - Addition, - Subtraction, - Multiplication, - Division, - ) +if TYPE_CHECKING: # pragma: no cover import casadi from hints import S @@ -171,9 +162,9 @@ def simplify_if_constant(symbol: S) -> S: or (isinstance(result, np.ndarray) and result.ndim == 0) or isinstance(result, np.bool_) ): - if isinstance(result, np.ndarray): + if isinstance(result, np.ndarray): # pragma: no cover # type-narrow for Scalar - new_result = result[0] + new_result = cast(float, result) return pybamm.Scalar(new_result) return pybamm.Scalar(result) elif isinstance(result, np.ndarray) or issparse(result): @@ -582,27 +573,27 @@ def __repr__(self): {k: v for k, v in self.domains.items() if v != []}, ) - def __add__(self, other: Symbol | float | np.ndarray) -> Addition: + def __add__(self, other: Symbol | float | np.ndarray) -> pybamm.Addition: """return an :class:`Addition` object.""" return pybamm.add(self, other) - def __radd__(self, other: Symbol | float | np.ndarray) -> Addition: + def __radd__(self, other: Symbol | float | np.ndarray) -> pybamm.Addition: """return an :class:`Addition` object.""" return pybamm.add(other, self) - def __sub__(self, other: Symbol | float | np.ndarray) -> Subtraction: + def __sub__(self, other: Symbol | float | np.ndarray) -> pybamm.Subtraction: """return a :class:`Subtraction` object.""" return pybamm.subtract(self, other) - def __rsub__(self, other: Symbol | float | np.ndarray) -> Subtraction: + def __rsub__(self, other: Symbol | float | np.ndarray) -> pybamm.Subtraction: """return a :class:`Subtraction` object.""" return pybamm.subtract(other, self) - def __mul__(self, other: Symbol | float | np.ndarray) -> Multiplication: + def __mul__(self, other: Symbol | float | np.ndarray) -> pybamm.Multiplication: """return a :class:`Multiplication` object.""" return pybamm.multiply(self, other) - def __rmul__(self, other: Symbol | float | np.ndarray) -> Multiplication: + def __rmul__(self, other: Symbol | float | np.ndarray) -> pybamm.Multiplication: """return a :class:`Multiplication` object.""" return pybamm.multiply(other, self) @@ -618,11 +609,11 @@ def __rmatmul__( """return a :class:`MatrixMultiplication` object.""" return pybamm.matmul(other, self) - def __truediv__(self, other: Symbol | float | np.ndarray) -> Division: + def __truediv__(self, other: Symbol | float | np.ndarray) -> pybamm.Division: """return a :class:`Division` object.""" return pybamm.divide(self, other) - def __rtruediv__(self, other: Symbol | float | np.ndarray) -> Division: + def __rtruediv__(self, other: Symbol | float | np.ndarray) -> pybamm.Division: """return a :class:`Division` object.""" return pybamm.divide(other, self) From df6501fc1b162fe6f9ee222e3c3c2165c451f4a8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Jan 2024 17:05:25 +0000 Subject: [PATCH 21/32] style: pre-commit fixes --- pybamm/expression_tree/concatenations.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 1a6d69d64d..2d2ee055b3 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -8,8 +8,6 @@ import numpy as np from scipy.sparse import issparse, vstack from typing import Sequence -from typing_extensions import TypeGuard -from pybamm.hints import S import pybamm from pybamm.util import have_optional_dependency From ed2288b2b371c3e1778b92cbfc334feeae35698f Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Thu, 18 Jan 2024 15:38:04 +0000 Subject: [PATCH 22/32] Remove unnecessary 'hints.py' file --- pybamm/expression_tree/symbol.py | 3 +-- pybamm/hints.py | 4 ---- 2 files changed, 1 insertion(+), 6 deletions(-) delete mode 100644 pybamm/hints.py diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 3320b361ad..a9efed12e4 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -15,7 +15,6 @@ if TYPE_CHECKING: # pragma: no cover import casadi - from hints import S DOMAIN_LEVELS = ["primary", "secondary", "tertiary", "quaternary"] EMPTY_DOMAINS: dict[str, list] = {k: [] for k in DOMAIN_LEVELS} @@ -149,7 +148,7 @@ def is_matrix_minus_one(expr: Symbol): return is_matrix_x(expr, -1) -def simplify_if_constant(symbol: S) -> S: +def simplify_if_constant(symbol: pybamm.Symbol): """ Utility function to simplify an expression tree if it evalutes to a constant scalar, vector or matrix diff --git a/pybamm/hints.py b/pybamm/hints.py deleted file mode 100644 index 696a6f4dee..0000000000 --- a/pybamm/hints.py +++ /dev/null @@ -1,4 +0,0 @@ -from typing_extensions import TypeVar -import pybamm - -S = TypeVar("S", bound=pybamm.Symbol) From 5333d422f9bd882442c0f681cd7b48de0a9cf999 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Thu, 18 Jan 2024 16:23:24 +0000 Subject: [PATCH 23/32] Further specify types for domain/auxiliary_domain/domains --- pybamm/expression_tree/array.py | 4 ++-- pybamm/expression_tree/broadcasts.py | 10 +++++----- pybamm/expression_tree/concatenations.py | 2 +- pybamm/expression_tree/independent_variable.py | 12 ++++++------ pybamm/expression_tree/matrix.py | 4 ++-- pybamm/expression_tree/state_vector.py | 18 +++++++++--------- pybamm/expression_tree/symbol.py | 6 +++--- pybamm/expression_tree/unary_operators.py | 14 ++++++++++++-- pybamm/expression_tree/variable.py | 4 ++-- pybamm/expression_tree/vector.py | 8 +++++--- 10 files changed, 47 insertions(+), 35 deletions(-) diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index a433dec822..b508e32cc6 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -41,11 +41,11 @@ class Array(pybamm.Symbol): def __init__( self, - entries: np.ndarray | list | csr_matrix, + entries: np.ndarray | list[float] | csr_matrix, name: str | None = None, domain: list[str] | str | None = None, auxiliary_domains: dict[str, str] | None = None, - domains: dict[str, list[str]] | None = None, + domains: dict[str, list[str] | str] | None = None, entries_string: str | None = None, ) -> None: # if diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index 8e22ecd487..ced3ac8cc9 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -36,7 +36,7 @@ class Broadcast(pybamm.SpatialOperator): def __init__( self, child: pybamm.Symbol, - domains: dict[str, list[str]], + domains: dict[str, list[str] | str], name: str | None = None, ): if name is None: @@ -464,8 +464,8 @@ def __init__( self, child_input: NumberType | float | pybamm.Symbol, broadcast_domain: list[str] | str | None = None, - auxiliary_domains: str | dict | None = None, - broadcast_domains: dict | None = None, + auxiliary_domains: dict[str, str] | None = None, + broadcast_domains: dict[str, list[str] | str] | None = None, name: str | None = None, ): # Convert child to scalar if it is a number @@ -538,8 +538,8 @@ def __init__( self, child: NumberType | float | pybamm.Symbol, broadcast_domain: list[str] | str | None = None, - auxiliary_domains: str | dict | None = None, - broadcast_domains: dict | None = None, + auxiliary_domains: dict[str, str] | None = None, + broadcast_domains: dict[str, list[str] | str] | None = None, name: str | None = None, ): name = name or "broadcast to edges" diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 2d2ee055b3..06718c311d 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -422,7 +422,7 @@ class SparseStack(Concatenation): The equations to concatenate """ - def __init__(self, *children): #: Iterable[pybamm.Concatenation] + def __init__(self, *children): children = list(children) if not any(issparse(child.evaluate_for_shape()) for child in children): concatenation_function = np.vstack diff --git a/pybamm/expression_tree/independent_variable.py b/pybamm/expression_tree/independent_variable.py index 3599e2656d..a33b4cd3e4 100644 --- a/pybamm/expression_tree/independent_variable.py +++ b/pybamm/expression_tree/independent_variable.py @@ -36,8 +36,8 @@ def __init__( self, name: str, domain: list[str] | str | None = None, - auxiliary_domains: dict | None = None, - domains: dict | None = None, + auxiliary_domains: dict[str, str] | None = None, + domains: dict[str, list[str] | str] | None = None, ) -> None: super().__init__( name, domain=domain, auxiliary_domains=auxiliary_domains, domains=domains @@ -129,8 +129,8 @@ def __init__( self, name: str, domain: list[str] | str | None = None, - auxiliary_domains: dict | None = None, - domains: dict | None = None, + auxiliary_domains: dict[str, str] | None = None, + domains: dict[str, list[str] | str] | None = None, coord_sys=None, ) -> None: self.coord_sys = coord_sys @@ -190,8 +190,8 @@ def __init__( self, name: str, domain: list[str] | str | None = None, - auxiliary_domains: dict | None = None, - domains: dict | None = None, + auxiliary_domains: dict[str, str] | None = None, + domains: dict[str, list[str] | str] | None = None, coord_sys=None, ) -> None: super().__init__(name, domain, auxiliary_domains, domains, coord_sys) diff --git a/pybamm/expression_tree/matrix.py b/pybamm/expression_tree/matrix.py index c19c08c3d9..e4bd62d1ba 100644 --- a/pybamm/expression_tree/matrix.py +++ b/pybamm/expression_tree/matrix.py @@ -15,11 +15,11 @@ class Matrix(pybamm.Array): def __init__( self, - entries: np.ndarray | list | csr_matrix, + entries: np.ndarray | list[float] | csr_matrix, name: str | None = None, domain: list[str] | None = None, auxiliary_domains: dict[str, str] | None = None, - domains: dict | None = None, + domains: dict[str, list[str] | str] | None = None, entries_string: str | None = None, ) -> None: if isinstance(entries, list): diff --git a/pybamm/expression_tree/state_vector.py b/pybamm/expression_tree/state_vector.py index cc35b7ccbc..0861c6a56c 100644 --- a/pybamm/expression_tree/state_vector.py +++ b/pybamm/expression_tree/state_vector.py @@ -40,9 +40,9 @@ def __init__( base_name="y", name: str | None = None, domain: list[str] | str | None = None, - auxiliary_domains: dict | None = None, - domains: dict[str, list[str]] | None = None, - evaluation_array: list | None = None, + auxiliary_domains: dict[str, str] | None = None, + domains: dict[str, list[str] | str] | None = None, + evaluation_array: list[bool] | None = None, ): for y_slice in y_slices: if not isinstance(y_slice, slice): @@ -260,9 +260,9 @@ def __init__( *y_slices: slice, name: str | None = None, domain: list[str] | str | None = None, - auxiliary_domains: dict | None = None, - domains: dict[str, list[str]] | None = None, - evaluation_array: list | None = None, + auxiliary_domains: dict[str, str] | None = None, + domains: dict[str, list[str] | str] | None = None, + evaluation_array: list[bool] | None = None, ): super().__init__( *y_slices, @@ -344,9 +344,9 @@ def __init__( *y_slices: slice, name: str | None = None, domain: list[str] | str | None = None, - auxiliary_domains: dict | None = None, - domains: dict[str, list[str]] | None = None, - evaluation_array: list | None = None, + auxiliary_domains: dict[str, str] | None = None, + domains: dict[str, list[str] | str] | None = None, + evaluation_array: list[bool] | None = None, ): super().__init__( *y_slices, diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index a9efed12e4..ca1b32073d 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -56,7 +56,7 @@ def create_object_of_size(size: int, typ="vector"): return np.nan * np.ones((size, size)) -def evaluate_for_shape_using_domain(domains: dict, typ="vector"): +def evaluate_for_shape_using_domain(domains: dict[str, list[str] | str], typ="vector"): """ Return a vector of the appropriate shape, based on the domains. Domain 'sizes' can clash, but are unlikely to, and won't cause failures if they do. @@ -213,7 +213,7 @@ def __init__( children: Sequence[Symbol] | None = None, domain: list[str] | str | None = None, auxiliary_domains: dict[str, str] | None = None, - domains: dict[str, list[str]] | None = None, + domains: dict[str, list[str] | str] | None = None, ): super().__init__() self.name = name @@ -402,7 +402,7 @@ def read_domain_or_domains( self, domain: list[str] | str | None, auxiliary_domains: dict[str, str] | None, - domains: dict | None, + domains: dict[str, list[str] | str] | None, ): if domains is None: if isinstance(domain, str): diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 9cb1bf648e..e95cf8e00c 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -25,7 +25,12 @@ class UnaryOperator(pybamm.Symbol): child node """ - def __init__(self, name: str, child: pybamm.Symbol, domains: dict | None = None): + def __init__( + self, + name: str, + child: pybamm.Symbol, + domains: dict[str, list[str] | str] | None = None, + ): if isinstance(child, numbers.Number): child = pybamm.Scalar(child) domains = domains or child.domains @@ -395,7 +400,12 @@ class with a :class:`Matrix` child node """ - def __init__(self, name: str, child: pybamm.Symbol, domains: dict | None = None): + def __init__( + self, + name: str, + child: pybamm.Symbol, + domains: dict[str, list[str] | str] | None = None, + ): super().__init__(name, child, domains) def to_json(self): diff --git a/pybamm/expression_tree/variable.py b/pybamm/expression_tree/variable.py index 22992c3b63..98103d6d25 100644 --- a/pybamm/expression_tree/variable.py +++ b/pybamm/expression_tree/variable.py @@ -51,8 +51,8 @@ def __init__( self, name: str, domain: list[str] | str | None = None, - auxiliary_domains: dict | None = None, - domains: dict | None = None, + auxiliary_domains: dict[str, str] | None = None, + domains: dict[str, list[str] | str] | None = None, bounds: tuple[pybamm.Symbol] | None = None, print_name: str | None = None, scale: float | pybamm.Symbol | None = 1, diff --git a/pybamm/expression_tree/vector.py b/pybamm/expression_tree/vector.py index ed0f860290..5bf37d1ff1 100644 --- a/pybamm/expression_tree/vector.py +++ b/pybamm/expression_tree/vector.py @@ -14,11 +14,11 @@ class Vector(pybamm.Array): def __init__( self, - entries: np.ndarray | list | np.matrix, + entries: np.ndarray | list[float] | np.matrix, name: str | None = None, domain: list[str] | str | None = None, auxiliary_domains: dict[str, str] | None = None, - domains: dict | None = None, + domains: dict[str, list[str] | str] | None = None, entries_string: str | None = None, ) -> None: if isinstance(entries, (list, np.matrix)): @@ -30,7 +30,9 @@ def __init__( raise ValueError( """ Entries must have 1 dimension or be column vector, not have shape {} - """.format(entries.shape) + """.format( + entries.shape + ) ) if name is None: name = f"Column vector of length {entries.shape[0]!s}" From f7ff456df3ee2de904e86070d44de65ebd8c75e4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Jan 2024 17:29:29 +0000 Subject: [PATCH 24/32] style: pre-commit fixes --- pybamm/expression_tree/vector.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pybamm/expression_tree/vector.py b/pybamm/expression_tree/vector.py index 5bf37d1ff1..0b67d9b389 100644 --- a/pybamm/expression_tree/vector.py +++ b/pybamm/expression_tree/vector.py @@ -30,9 +30,7 @@ def __init__( raise ValueError( """ Entries must have 1 dimension or be column vector, not have shape {} - """.format( - entries.shape - ) + """.format(entries.shape) ) if name is None: name = f"Column vector of length {entries.shape[0]!s}" From 9cef34bf380c84c85dab1b869c14e492857dd3aa Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 19 Jan 2024 11:17:04 +0000 Subject: [PATCH 25/32] Stop ignoring UP007 in ruff, as per #3579 --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6bd016bb56..ed879648f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -218,7 +218,6 @@ ignore = [ "RET506", # Unnecessary `elif` "B018", # Found useless expression "RUF002", # Docstring contains ambiguous - "UP007", # For pyupgrade ] [tool.ruff.lint.per-file-ignores] From 21e410796aa65d3308eaffca78343543239b2d3e Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Mon, 22 Jan 2024 10:54:35 +0000 Subject: [PATCH 26/32] Move mypy.ini to pyproject.toml --- mypy.ini | 10 ---------- pyproject.toml | 13 +++++++++++++ 2 files changed, 13 insertions(+), 10 deletions(-) delete mode 100644 mypy.ini diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index 5c97998a83..0000000000 --- a/mypy.ini +++ /dev/null @@ -1,10 +0,0 @@ -[mypy] -ignore_missing_imports = True -allow_redefinition = True -disable_error_code = call-overload, operator - -[mypy-pybamm.models.base_model.*] -disable_error_code = attr-defined - -[mypy-pybamm.models.full_battery_models.base_battery_model.*] -disable_error_code = attr-defined diff --git a/pyproject.toml b/pyproject.toml index ed879648f1..a9ac0c7f5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -218,6 +218,7 @@ ignore = [ "RET506", # Unnecessary `elif` "B018", # Found useless expression "RUF002", # Docstring contains ambiguous + "UP007", # For pyupgrade ] [tool.ruff.lint.per-file-ignores] @@ -268,3 +269,15 @@ concurrency = ["multiprocessing"] ignore = [ "PP003" # list wheel as a build-dep ] + +[tool.mypy] +ignore_missing_imports = true +allow_redefinition = true +disable_error_code = ["call-overload", "operator"] + +[[tool.mypy.overrides]] +module = [ + "pybamm.models.base_model.*", + "pybamm.models.full_battery_models.base_battery_model.*" + ] +disable_error_code = "attr-defined" From d1ae8198f2fe30d44adc27a5bef072f12d37c439 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Thu, 25 Jan 2024 16:45:56 +0000 Subject: [PATCH 27/32] Fix some ignored type errors --- pybamm/experiment/experiment.py | 4 ++-- pybamm/expression_tree/broadcasts.py | 12 ++++++------ pybamm/expression_tree/functions.py | 6 +++--- pybamm/expression_tree/interpolant.py | 6 +++--- pybamm/expression_tree/operations/evaluate_python.py | 4 +--- pybamm/expression_tree/operations/serialise.py | 9 ++++++--- 6 files changed, 21 insertions(+), 20 deletions(-) diff --git a/pybamm/experiment/experiment.py b/pybamm/experiment/experiment.py index b66fe187be..c2d43ec25b 100644 --- a/pybamm/experiment/experiment.py +++ b/pybamm/experiment/experiment.py @@ -43,7 +43,7 @@ class Experiment: def __init__( self, - operating_conditions: list[str], + operating_conditions: list[str | tuple[str]], period: str = "1 minute", temperature: float | None = None, termination: list[str] | None = None, @@ -73,7 +73,7 @@ def __init__( for cycle in operating_conditions: # Check types and convert to list if not isinstance(cycle, tuple): - cycle = (cycle,) # type: ignore[assignment] + cycle = (cycle,) operating_conditions_cycles.append(cycle) self.operating_conditions_cycles = operating_conditions_cycles diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index ced3ac8cc9..02bfd38a21 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -590,18 +590,18 @@ def full_like(symbols: tuple[pybamm.Symbol, ...], fill_value: float) -> pybamm.S return pybamm.Scalar(fill_value) try: shape = sum_symbol.shape - # use vector or matrix - if shape[1] == 1: - array_type: type[pybamm.Vector] = pybamm.Vector - else: - array_type: type[pybamm.Matrix] = pybamm.Matrix # type:ignore[no-redef] + # return dense array, except for a matrix of zeros if shape[1] != 1 and fill_value == 0: entries = csr_matrix(shape) else: entries = fill_value * np.ones(shape) - return array_type(entries, domains=sum_symbol.domains) + # use vector or matrix + if shape[1] == 1: + return pybamm.Vector(entries, domains=sum_symbol.domains) + else: + return pybamm.Matrix(entries, domains=sum_symbol.domains) except NotImplementedError: if ( diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 1c4271c42d..3a73b05f9c 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -75,7 +75,7 @@ def diff(self, variable: pybamm.Symbol): return pybamm.Scalar(1) else: children = self.orphans - partial_derivatives = [None] * len(children) + partial_derivatives: list[None | pybamm.Symbol] = [None] * len(children) for i, child in enumerate(self.children): # if variable appears in the function, differentiate # function, and apply chain rule @@ -87,9 +87,9 @@ def diff(self, variable: pybamm.Symbol): # remove None entries partial_derivatives = [x for x in partial_derivatives if x is not None] - derivative = sum(partial_derivatives) # type: ignore[arg-type] + derivative = sum(partial_derivatives) if derivative == 0: - derivative = pybamm.Scalar(0) # type: ignore[assignment] + return pybamm.Scalar(0) return derivative diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index ba3f183a30..665eb0e63f 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -129,14 +129,14 @@ def __init__( self.dimension = 1 if interpolator == "linear": if extrapolate is False: - fill_value = np.nan + fill_value_1: float | str = np.nan elif extrapolate is True: - fill_value = "extrapolate" # type: ignore[assignment] + fill_value_1 = "extrapolate" interpolating_function = interpolate.interp1d( x1, y.T, bounds_error=False, - fill_value=fill_value, + fill_value=fill_value_1, ) elif interpolator == "cubic": interpolating_function = interpolate.CubicSpline( diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index 5c88961069..3582db13aa 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -70,9 +70,7 @@ def dot_product(self, b): """ # assume b is a column vector result = jax.numpy.zeros((self.shape[0], 1), dtype=b.dtype) - return result.at[self.row].add( - self.data.reshape(-1, 1) * b[self.col] # type:ignore[index] - ) + return result.at[self.row].add(self.data.reshape(-1, 1) * b[self.col]) def scalar_multiply(self, b: float): """ diff --git a/pybamm/expression_tree/operations/serialise.py b/pybamm/expression_tree/operations/serialise.py index 33a2d4ba21..0507b3304e 100644 --- a/pybamm/expression_tree/operations/serialise.py +++ b/pybamm/expression_tree/operations/serialise.py @@ -247,12 +247,15 @@ def _get_pybamm_class(self, snippet: dict): try: empty_class = self._Empty() empty_class.__class__ = class_ + + return empty_class + except TypeError: # Mesh objects have a different layouts - empty_class = self._EmptyDict() # type: ignore[assignment] - empty_class.__class__ = class_ + empty_dict_class = self._EmptyDict() + empty_dict_class.__class__ = class_ - return empty_class + return empty_dict_class def _deconstruct_pybamm_dicts(self, dct: dict): """ From d80c981e9b37e4aa623a0ce90b1d32832fdee56d Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Thu, 8 Feb 2024 14:05:12 +0000 Subject: [PATCH 28/32] Move common type definitions to type_definitions.py --- pybamm/expression_tree/array.py | 15 +- pybamm/expression_tree/binary_operators.py | 128 +++++++++--------- pybamm/expression_tree/broadcasts.py | 21 ++- .../expression_tree/independent_variable.py | 27 ++-- pybamm/expression_tree/input_parameter.py | 4 +- pybamm/expression_tree/matrix.py | 7 +- pybamm/expression_tree/scalar.py | 10 +- pybamm/expression_tree/state_vector.py | 32 ++--- pybamm/expression_tree/symbol.py | 47 ++++--- pybamm/expression_tree/unary_operators.py | 5 +- pybamm/expression_tree/variable.py | 7 +- pybamm/expression_tree/vector.py | 7 +- pybamm/type_definitions.py | 14 ++ 13 files changed, 177 insertions(+), 147 deletions(-) create mode 100644 pybamm/type_definitions.py diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index b508e32cc6..0bb6168a7c 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -8,6 +8,7 @@ import pybamm from pybamm.util import have_optional_dependency +from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType if TYPE_CHECKING: # pragma: no cover import sympy @@ -43,9 +44,9 @@ def __init__( self, entries: np.ndarray | list[float] | csr_matrix, name: str | None = None, - domain: list[str] | str | None = None, - auxiliary_domains: dict[str, str] | None = None, - domains: dict[str, list[str] | str] | None = None, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, entries_string: str | None = None, ) -> None: # if @@ -140,7 +141,13 @@ def create_copy(self): entries_string=self.entries_string, ) - def _base_evaluate(self, t, y, y_dot, inputs): + def _base_evaluate( + self, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, + ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" return self._entries diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index af00e1eeb0..c27f9267d6 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -11,14 +11,14 @@ import pybamm from pybamm.util import have_optional_dependency -from typing import Union, Callable +from typing import Callable # create type alias(s) -ChildValue = Union[float, np.ndarray, pybamm.Symbol] +from pybamm.type_definitions import ChildSymbol, ChildValue def _preprocess_binary( - left: ChildValue, right: ChildValue + left: ChildSymbol, right: ChildSymbol ) -> tuple[pybamm.Symbol, pybamm.Symbol]: if isinstance(left, numbers.Number): left = pybamm.Scalar(left) @@ -69,7 +69,7 @@ class BinaryOperator(pybamm.Symbol): """ def __init__( - self, name: str, left_child: ChildValue, right_child: ChildValue + self, name: str, left_child: ChildSymbol, right_child: ChildSymbol ) -> None: left, right = _preprocess_binary(left_child, right_child) @@ -128,7 +128,7 @@ def create_copy(self): return out - def _binary_new_copy(self, left: ChildValue, right: ChildValue): + def _binary_new_copy(self, left: ChildSymbol, right: ChildSymbol): """ Default behaviour for new_copy. This copies the behaviour of `_binary_evaluate`, but since `left` and `right` @@ -206,8 +206,8 @@ class Power(BinaryOperator): def __init__( self, - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("**", left, right) @@ -239,8 +239,8 @@ def _binary_jac(self, left_jac, right_jac): def _binary_evaluate( self, - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): """See :meth:`pybamm.BinaryOperator._binary_evaluate()`.""" # don't raise RuntimeWarning for NaNs @@ -255,8 +255,8 @@ class Addition(BinaryOperator): def __init__( self, - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("+", left, right) @@ -265,11 +265,11 @@ def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" return self.left.diff(variable) + self.right.diff(variable) - def _binary_jac(self, left_jac: float | np.ndarray, right_jac: float | np.ndarray): + def _binary_jac(self, left_jac: ChildValue, right_jac: ChildValue): """See :meth:`pybamm.BinaryOperator._binary_jac()`.""" return left_jac + right_jac - def _binary_evaluate(self, left: float | np.ndarray, right: float | np.ndarray): + def _binary_evaluate(self, left: ChildValue, right: ChildValue): """See :meth:`pybamm.BinaryOperator._binary_evaluate()`.""" return left + right @@ -281,8 +281,8 @@ class Subtraction(BinaryOperator): def __init__( self, - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" @@ -296,7 +296,7 @@ def _binary_jac(self, left_jac, right_jac): """See :meth:`pybamm.BinaryOperator._binary_jac()`.""" return left_jac - right_jac - def _binary_evaluate(self, left: float | np.ndarray, right: float | np.ndarray): + def _binary_evaluate(self, left: ChildValue, right: ChildValue): """See :meth:`pybamm.BinaryOperator._binary_evaluate()`.""" return left - right @@ -310,8 +310,8 @@ class Multiplication(BinaryOperator): def __init__( self, - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" @@ -351,8 +351,8 @@ class MatrixMultiplication(BinaryOperator): def __init__( self, - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("@", left, right) @@ -401,8 +401,8 @@ class Division(BinaryOperator): def __init__( self, - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("/", left, right) @@ -449,8 +449,8 @@ class Inner(BinaryOperator): def __init__( self, - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("inner product", left, right) @@ -485,8 +485,8 @@ def _binary_evaluate(self, left, right): def _binary_new_copy( self, - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.inner(left, right) @@ -526,8 +526,8 @@ class Equality(BinaryOperator): def __init__( self, - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("==", left, right) @@ -554,8 +554,8 @@ def _binary_evaluate(self, left, right): def _binary_new_copy( self, - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.Equality(left, right) @@ -581,8 +581,8 @@ class _Heaviside(BinaryOperator): def __init__( self, name: str, - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__(name, left, right) @@ -616,8 +616,8 @@ class EqualHeaviside(_Heaviside): def __init__( self, - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("<=", left, right) @@ -638,8 +638,8 @@ class NotEqualHeaviside(_Heaviside): def __init__( self, - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): super().__init__("<", left, right) @@ -659,8 +659,8 @@ class Modulo(BinaryOperator): def __init__( self, - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): super().__init__("%", left, right) @@ -701,8 +701,8 @@ class Minimum(BinaryOperator): def __init__( self, - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): super().__init__("minimum", left, right) @@ -729,8 +729,8 @@ def _binary_evaluate(self, left, right): def _binary_new_copy( self, - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.minimum(left, right) @@ -746,8 +746,8 @@ class Maximum(BinaryOperator): def __init__( self, - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): super().__init__("maximum", left, right) @@ -774,8 +774,8 @@ def _binary_evaluate(self, left, right): def _binary_new_copy( self, - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): """See :meth:`pybamm.BinaryOperator._binary_new_copy()`.""" return pybamm.maximum(left, right) @@ -787,8 +787,8 @@ def _sympy_operator(self, left, right): def _simplify_elementwise_binary_broadcasts( - left_child: ChildValue, - right_child: ChildValue, + left_child: ChildSymbol, + right_child: ChildSymbol, ) -> tuple[pybamm.Symbol, pybamm.Symbol]: left, right = _preprocess_binary(left_child, right_child) @@ -863,8 +863,8 @@ def _simplified_binary_broadcast_concatenation( def simplified_power( - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): left, right = _simplify_elementwise_binary_broadcasts(left, right) @@ -907,7 +907,7 @@ def simplified_power( return pybamm.simplify_if_constant(pybamm.Power(left, right)) -def add(left: ChildValue, right: ChildValue): +def add(left: ChildSymbol, right: ChildSymbol): """ Note ---- @@ -996,8 +996,8 @@ def add(left: ChildValue, right: ChildValue): def subtract( - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): """ Note @@ -1081,8 +1081,8 @@ def subtract( def multiply( - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): left, right = _simplify_elementwise_binary_broadcasts(left, right) @@ -1209,8 +1209,8 @@ def multiply( def divide( - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): left, right = _simplify_elementwise_binary_broadcasts(left, right) @@ -1283,8 +1283,8 @@ def divide( def matmul( - left_child: ChildValue, - right_child: ChildValue, + left_child: ChildSymbol, + right_child: ChildSymbol, ): left, right = _preprocess_binary(left_child, right_child) if pybamm.is_matrix_zero(left) or pybamm.is_matrix_zero(right): @@ -1345,8 +1345,8 @@ def matmul( def minimum( - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ) -> pybamm.Symbol: """ Returns the smaller of two objects, possibly with a smoothing approximation. @@ -1372,8 +1372,8 @@ def minimum( def maximum( - left: ChildValue, - right: ChildValue, + left: ChildSymbol, + right: ChildSymbol, ): """ Returns the larger of two objects, possibly with a smoothing approximation. @@ -1398,7 +1398,7 @@ def maximum( return pybamm.simplify_if_constant(out) -def _heaviside(left: ChildValue, right: ChildValue, equal): +def _heaviside(left: ChildSymbol, right: ChildSymbol, equal): """return a :class:`EqualHeaviside` object, or a smooth approximation.""" # Check for Concatenations and Broadcasts left, right = _simplify_elementwise_binary_broadcasts(left, right) diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index 02bfd38a21..f7d697309a 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -6,11 +6,10 @@ import numpy as np from scipy.sparse import csr_matrix -from typing import Type, SupportsFloat +from typing import SupportsFloat import pybamm - -NumberType = Type[SupportsFloat] +from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType class Broadcast(pybamm.SpatialOperator): @@ -462,10 +461,10 @@ class FullBroadcast(Broadcast): def __init__( self, - child_input: NumberType | float | pybamm.Symbol, - broadcast_domain: list[str] | str | None = None, - auxiliary_domains: dict[str, str] | None = None, - broadcast_domains: dict[str, list[str] | str] | None = None, + child_input: SupportsFloat | pybamm.Symbol, + broadcast_domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + broadcast_domains: DomainsType = None, name: str | None = None, ): # Convert child to scalar if it is a number @@ -536,10 +535,10 @@ class FullBroadcastToEdges(FullBroadcast): def __init__( self, - child: NumberType | float | pybamm.Symbol, - broadcast_domain: list[str] | str | None = None, - auxiliary_domains: dict[str, str] | None = None, - broadcast_domains: dict[str, list[str] | str] | None = None, + child: SupportsFloat | pybamm.Symbol, + broadcast_domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + broadcast_domains: DomainsType = None, name: str | None = None, ): name = name or "broadcast to edges" diff --git a/pybamm/expression_tree/independent_variable.py b/pybamm/expression_tree/independent_variable.py index a33b4cd3e4..0dca6dba46 100644 --- a/pybamm/expression_tree/independent_variable.py +++ b/pybamm/expression_tree/independent_variable.py @@ -3,10 +3,11 @@ # from __future__ import annotations import sympy -from typing import Any +import numpy as np import pybamm from pybamm.util import have_optional_dependency +from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType KNOWN_COORD_SYS = ["cartesian", "cylindrical polar", "spherical polar"] @@ -35,9 +36,9 @@ class IndependentVariable(pybamm.Symbol): def __init__( self, name: str, - domain: list[str] | str | None = None, - auxiliary_domains: dict[str, str] | None = None, - domains: dict[str, list[str] | str] | None = None, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, ) -> None: super().__init__( name, domain=domain, auxiliary_domains=auxiliary_domains, domains=domains @@ -83,9 +84,9 @@ def create_copy(self): def _base_evaluate( self, t: float | None = None, - y: Any = None, - y_dot: Any = None, - inputs: Any = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" if t is None: @@ -128,9 +129,9 @@ class SpatialVariable(IndependentVariable): def __init__( self, name: str, - domain: list[str] | str | None = None, - auxiliary_domains: dict[str, str] | None = None, - domains: dict[str, list[str] | str] | None = None, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, coord_sys=None, ) -> None: self.coord_sys = coord_sys @@ -189,9 +190,9 @@ class SpatialVariableEdge(SpatialVariable): def __init__( self, name: str, - domain: list[str] | str | None = None, - auxiliary_domains: dict[str, str] | None = None, - domains: dict[str, list[str] | str] | None = None, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, coord_sys=None, ) -> None: super().__init__(name, domain, auxiliary_domains, domains, coord_sys) diff --git a/pybamm/expression_tree/input_parameter.py b/pybamm/expression_tree/input_parameter.py index d8e4a7364f..a1ee00a47b 100644 --- a/pybamm/expression_tree/input_parameter.py +++ b/pybamm/expression_tree/input_parameter.py @@ -7,6 +7,8 @@ import scipy.sparse import pybamm +from pybamm.type_definitions import DomainType + class InputParameter(pybamm.Symbol): """ @@ -29,7 +31,7 @@ class InputParameter(pybamm.Symbol): def __init__( self, name: str, - domain: list[str] | str | None = None, + domain: DomainType = None, expected_size: int | None = None, ) -> None: # Expected size defaults to 1 if no domain else None (gets set later) diff --git a/pybamm/expression_tree/matrix.py b/pybamm/expression_tree/matrix.py index e4bd62d1ba..435d06d84f 100644 --- a/pybamm/expression_tree/matrix.py +++ b/pybamm/expression_tree/matrix.py @@ -6,6 +6,7 @@ from scipy.sparse import csr_matrix, issparse import pybamm +from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType class Matrix(pybamm.Array): @@ -17,9 +18,9 @@ def __init__( self, entries: np.ndarray | list[float] | csr_matrix, name: str | None = None, - domain: list[str] | None = None, - auxiliary_domains: dict[str, str] | None = None, - domains: dict[str, list[str] | str] | None = None, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, entries_string: str | None = None, ) -> None: if isinstance(entries, list): diff --git a/pybamm/expression_tree/scalar.py b/pybamm/expression_tree/scalar.py index 0c9d349484..c9852d5fb2 100644 --- a/pybamm/expression_tree/scalar.py +++ b/pybamm/expression_tree/scalar.py @@ -4,7 +4,7 @@ from __future__ import annotations import numbers import numpy as np -from typing import Literal, Any +from typing import Literal import pybamm from pybamm.util import have_optional_dependency @@ -65,10 +65,10 @@ def set_id(self): def _base_evaluate( self, - t: Any = None, - y: Any = None, - y_dot: Any = None, - inputs: Any = None, + t: float | None = None, + y: np.ndarray | None = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" return self._value diff --git a/pybamm/expression_tree/state_vector.py b/pybamm/expression_tree/state_vector.py index 0861c6a56c..f52b41a99c 100644 --- a/pybamm/expression_tree/state_vector.py +++ b/pybamm/expression_tree/state_vector.py @@ -4,9 +4,9 @@ from __future__ import annotations import numpy as np from scipy.sparse import csr_matrix, vstack -from typing import Any import pybamm +from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType class StateVectorBase(pybamm.Symbol): @@ -39,9 +39,9 @@ def __init__( *y_slices: slice, base_name="y", name: str | None = None, - domain: list[str] | str | None = None, - auxiliary_domains: dict[str, str] | None = None, - domains: dict[str, list[str] | str] | None = None, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, evaluation_array: list[bool] | None = None, ): for y_slice in y_slices: @@ -259,9 +259,9 @@ def __init__( self, *y_slices: slice, name: str | None = None, - domain: list[str] | str | None = None, - auxiliary_domains: dict[str, str] | None = None, - domains: dict[str, list[str] | str] | None = None, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, evaluation_array: list[bool] | None = None, ): super().__init__( @@ -276,10 +276,10 @@ def __init__( def _base_evaluate( self, - t: Any = None, + t: float | None = None, y: np.ndarray | None = None, - y_dot: Any = None, - inputs: Any = None, + y_dot: np.ndarray | None = None, + inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" if y is None: @@ -343,9 +343,9 @@ def __init__( self, *y_slices: slice, name: str | None = None, - domain: list[str] | str | None = None, - auxiliary_domains: dict[str, str] | None = None, - domains: dict[str, list[str] | str] | None = None, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, evaluation_array: list[bool] | None = None, ): super().__init__( @@ -360,10 +360,10 @@ def __init__( def _base_evaluate( self, - t: Any = None, - y: Any = None, + t: float | None = None, + y: np.ndarray | None = None, y_dot: np.ndarray | None = None, - inputs: Any = None, + inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" if y_dot is None: diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 3e84eb64a9..8e569b112b 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -15,6 +15,13 @@ if TYPE_CHECKING: # pragma: no cover import casadi + from pybamm.type_definitions import ( + ChildSymbol, + ChildValue, + DomainType, + AuxiliaryDomainType, + DomainsType, + ) DOMAIN_LEVELS = ["primary", "secondary", "tertiary", "quaternary"] EMPTY_DOMAINS: dict[str, list] = {k: [] for k in DOMAIN_LEVELS} @@ -211,9 +218,9 @@ def __init__( self, name: str, children: Sequence[Symbol] | None = None, - domain: list[str] | str | None = None, - auxiliary_domains: dict[str, str] | None = None, - domains: dict[str, list[str] | str] | None = None, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, ): super().__init__() self.name = name @@ -400,9 +407,9 @@ def get_children_domains(self, children: Sequence[Symbol]): def read_domain_or_domains( self, - domain: list[str] | str | None, - auxiliary_domains: dict[str, str] | None, - domains: dict[str, list[str] | str] | None, + domain: DomainType, + auxiliary_domains: AuxiliaryDomainType, + domains: DomainsType, ): if domains is None: if isinstance(domain, str): @@ -572,51 +579,47 @@ def __repr__(self): {k: v for k, v in self.domains.items() if v != []}, ) - def __add__(self, other: Symbol | float | np.ndarray) -> pybamm.Addition: + def __add__(self, other: ChildSymbol) -> pybamm.Addition: """return an :class:`Addition` object.""" return pybamm.add(self, other) - def __radd__(self, other: Symbol | float | np.ndarray) -> pybamm.Addition: + def __radd__(self, other: ChildSymbol) -> pybamm.Addition: """return an :class:`Addition` object.""" return pybamm.add(other, self) - def __sub__(self, other: Symbol | float | np.ndarray) -> pybamm.Subtraction: + def __sub__(self, other: ChildSymbol) -> pybamm.Subtraction: """return a :class:`Subtraction` object.""" return pybamm.subtract(self, other) - def __rsub__(self, other: Symbol | float | np.ndarray) -> pybamm.Subtraction: + def __rsub__(self, other: ChildSymbol) -> pybamm.Subtraction: """return a :class:`Subtraction` object.""" return pybamm.subtract(other, self) - def __mul__(self, other: Symbol | float | np.ndarray) -> pybamm.Multiplication: + def __mul__(self, other: ChildSymbol) -> pybamm.Multiplication: """return a :class:`Multiplication` object.""" return pybamm.multiply(self, other) - def __rmul__(self, other: Symbol | float | np.ndarray) -> pybamm.Multiplication: + def __rmul__(self, other: ChildSymbol) -> pybamm.Multiplication: """return a :class:`Multiplication` object.""" return pybamm.multiply(other, self) - def __matmul__( - self, other: Symbol | float | np.ndarray - ) -> pybamm.MatrixMultiplication: + def __matmul__(self, other: ChildSymbol) -> pybamm.MatrixMultiplication: """return a :class:`MatrixMultiplication` object.""" return pybamm.matmul(self, other) - def __rmatmul__( - self, other: Symbol | float | np.ndarray - ) -> pybamm.MatrixMultiplication: + def __rmatmul__(self, other: ChildSymbol) -> pybamm.MatrixMultiplication: """return a :class:`MatrixMultiplication` object.""" return pybamm.matmul(other, self) - def __truediv__(self, other: Symbol | float | np.ndarray) -> pybamm.Division: + def __truediv__(self, other: ChildSymbol) -> pybamm.Division: """return a :class:`Division` object.""" return pybamm.divide(self, other) - def __rtruediv__(self, other: Symbol | float | np.ndarray) -> pybamm.Division: + def __rtruediv__(self, other: ChildSymbol) -> pybamm.Division: """return a :class:`Division` object.""" return pybamm.divide(other, self) - def __pow__(self, other: Symbol | float | np.ndarray) -> pybamm.Power: + def __pow__(self, other: ChildSymbol) -> pybamm.Power: """return a :class:`Power` object.""" return pybamm.simplified_power(self, other) @@ -795,7 +798,7 @@ def evaluate( y: np.ndarray | None = None, y_dot: np.ndarray | None = None, inputs: dict | str | None = None, - ) -> float | np.ndarray: + ) -> ChildValue: """Evaluate expression tree (wrapper to allow using dict of known values). Parameters diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index e95cf8e00c..19e530e825 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -568,8 +568,9 @@ class Integral(SpatialOperator): def __init__( self, child, - integration_variable: list[pybamm.IndependentVariable] - | pybamm.IndependentVariable, + integration_variable: ( + list[pybamm.IndependentVariable] | pybamm.IndependentVariable + ), ): if not isinstance(integration_variable, list): integration_variable = [integration_variable] diff --git a/pybamm/expression_tree/variable.py b/pybamm/expression_tree/variable.py index 98103d6d25..f7b8a7b8e9 100644 --- a/pybamm/expression_tree/variable.py +++ b/pybamm/expression_tree/variable.py @@ -6,6 +6,7 @@ import numbers import pybamm from pybamm.util import have_optional_dependency +from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType class VariableBase(pybamm.Symbol): @@ -50,9 +51,9 @@ class VariableBase(pybamm.Symbol): def __init__( self, name: str, - domain: list[str] | str | None = None, - auxiliary_domains: dict[str, str] | None = None, - domains: dict[str, list[str] | str] | None = None, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, bounds: tuple[pybamm.Symbol] | None = None, print_name: str | None = None, scale: float | pybamm.Symbol | None = 1, diff --git a/pybamm/expression_tree/vector.py b/pybamm/expression_tree/vector.py index 0b67d9b389..a1d8052c94 100644 --- a/pybamm/expression_tree/vector.py +++ b/pybamm/expression_tree/vector.py @@ -5,6 +5,7 @@ import numpy as np import pybamm +from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType class Vector(pybamm.Array): @@ -16,9 +17,9 @@ def __init__( self, entries: np.ndarray | list[float] | np.matrix, name: str | None = None, - domain: list[str] | str | None = None, - auxiliary_domains: dict[str, str] | None = None, - domains: dict[str, list[str] | str] | None = None, + domain: DomainType = None, + auxiliary_domains: AuxiliaryDomainType = None, + domains: DomainsType = None, entries_string: str | None = None, ) -> None: if isinstance(entries, (list, np.matrix)): diff --git a/pybamm/type_definitions.py b/pybamm/type_definitions.py new file mode 100644 index 0000000000..fc60ca2f0d --- /dev/null +++ b/pybamm/type_definitions.py @@ -0,0 +1,14 @@ +"""Common type definitions for PyBaMM""" + +from typing import Union +import numpy as np +import pybamm + + +# expression tree +ChildValue = Union[float, np.ndarray] +ChildSymbol = Union[float, np.ndarray, pybamm.Symbol] + +DomainType = Union[list[str], str, None] +AuxiliaryDomainType = Union[dict[str, str], None] +DomainsType = Union[dict[str, Union[list[str], str]], None] From 8b0a2aa1cd83cf2a6c26a7a31244bfac2fbd862f Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Thu, 8 Feb 2024 16:36:37 +0000 Subject: [PATCH 29/32] Replace numbers.Number type hints to work with static checkers --- pybamm/expression_tree/binary_operators.py | 15 +++++++----- pybamm/expression_tree/broadcasts.py | 28 ++++++++++++++-------- pybamm/expression_tree/functions.py | 3 +-- pybamm/expression_tree/parameter.py | 3 +-- pybamm/expression_tree/scalar.py | 4 ++-- pybamm/expression_tree/symbol.py | 10 ++++---- pybamm/expression_tree/unary_operators.py | 6 ++--- pybamm/expression_tree/variable.py | 9 +++++-- pybamm/type_definitions.py | 2 ++ 9 files changed, 47 insertions(+), 33 deletions(-) diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index c27f9267d6..d10844798b 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -11,22 +11,22 @@ import pybamm from pybamm.util import have_optional_dependency -from typing import Callable +from typing import Callable, cast # create type alias(s) -from pybamm.type_definitions import ChildSymbol, ChildValue +from pybamm.type_definitions import ChildSymbol, ChildValue, Numeric def _preprocess_binary( left: ChildSymbol, right: ChildSymbol ) -> tuple[pybamm.Symbol, pybamm.Symbol]: - if isinstance(left, numbers.Number): + if isinstance(left, (float, int, np.number)): left = pybamm.Scalar(left) elif isinstance(left, np.ndarray): if left.ndim > 1: raise ValueError("left must be a 1D array") left = pybamm.Vector(left) - if isinstance(right, numbers.Number): + if isinstance(right, (float, int, np.number)): right = pybamm.Scalar(right) elif isinstance(right, np.ndarray): if right.ndim > 1: @@ -1494,7 +1494,7 @@ def sigmoid( def source( - left: numbers.Number | pybamm.Symbol, + left: Numeric | pybamm.Symbol, right: pybamm.Symbol, boundary=False, ): @@ -1513,7 +1513,7 @@ def source( Parameters ---------- - left : :class:`Symbol` + left : :class:`Symbol`, numeric The left child node, which represents the expression for the source term. right : :class:`Symbol` The right child node. This is the symbol whose boundary conditions are @@ -1528,6 +1528,9 @@ def source( if isinstance(left, numbers.Number): left = pybamm.PrimaryBroadcast(left, "current collector") + # force type cast for mypy + left = cast(pybamm.Symbol, left) + if left.domain != ["current collector"] or right.domain != ["current collector"]: raise pybamm.DomainError( f"""'source' only implemented in the 'current collector' domain, diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index f7d697309a..e17f375102 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -2,14 +2,18 @@ # Unary operator classes and methods # from __future__ import annotations -import numbers import numpy as np from scipy.sparse import csr_matrix -from typing import SupportsFloat +from typing import cast import pybamm -from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType +from pybamm.type_definitions import ( + DomainType, + AuxiliaryDomainType, + DomainsType, + Numeric, +) class Broadcast(pybamm.SpatialOperator): @@ -85,7 +89,7 @@ class PrimaryBroadcast(Broadcast): Parameters ---------- - child : :class:`Symbol` + child : :class:`Symbol`, numeric child node broadcast_domain : iterable of str Primary domain for broadcast. This will become the domain of the symbol @@ -95,13 +99,17 @@ class PrimaryBroadcast(Broadcast): def __init__( self, - child: numbers.Number | pybamm.Symbol, + child: Numeric | pybamm.Symbol, broadcast_domain: list[str] | str, name: str | None = None, ): # Convert child to scalar if it is a number - if isinstance(child, numbers.Number): + if isinstance(child, (float, int, np.number)): child = pybamm.Scalar(child) + + # cast child to Symbol for mypy + child = cast(pybamm.Symbol, child) + # Convert domain to list if it's a string if isinstance(broadcast_domain, str): broadcast_domain = [broadcast_domain] @@ -189,7 +197,7 @@ class PrimaryBroadcastToEdges(PrimaryBroadcast): def __init__( self, - child: numbers.Number | pybamm.Symbol, + child: Numeric | pybamm.Symbol, broadcast_domain: list[str] | str, name: str | None = None, ): @@ -461,14 +469,14 @@ class FullBroadcast(Broadcast): def __init__( self, - child_input: SupportsFloat | pybamm.Symbol, + child_input: Numeric | pybamm.Symbol, broadcast_domain: DomainType = None, auxiliary_domains: AuxiliaryDomainType = None, broadcast_domains: DomainsType = None, name: str | None = None, ): # Convert child to scalar if it is a number - if isinstance(child_input, numbers.Number): + if isinstance(child_input, (float, int, np.number)): child: pybamm.Scalar = pybamm.Scalar(child_input) else: child: pybamm.Symbol = child_input # type: ignore[no-redef] @@ -535,7 +543,7 @@ class FullBroadcastToEdges(FullBroadcast): def __init__( self, - child: SupportsFloat | pybamm.Symbol, + child: Numeric | pybamm.Symbol, broadcast_domain: DomainType = None, auxiliary_domains: AuxiliaryDomainType = None, broadcast_domains: DomainsType = None, diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 3a73b05f9c..72c9d4074a 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -2,7 +2,6 @@ # Function classes and methods # from __future__ import annotations -import numbers import numpy as np from scipy import special @@ -43,7 +42,7 @@ def __init__( # Turn numbers into scalars children = list(children) for idx, child in enumerate(children): - if isinstance(child, numbers.Number): + if isinstance(child, (float, int, np.number)): children[idx] = pybamm.Scalar(child) if name is not None: diff --git a/pybamm/expression_tree/parameter.py b/pybamm/expression_tree/parameter.py index 911e0ca59f..e646ff234d 100644 --- a/pybamm/expression_tree/parameter.py +++ b/pybamm/expression_tree/parameter.py @@ -2,7 +2,6 @@ # Parameter classes # from __future__ import annotations -import numbers import sys import numpy as np @@ -107,7 +106,7 @@ def __init__( # Turn numbers into scalars for idx, child in enumerate(children_list): - if isinstance(child, numbers.Number): + if isinstance(child, (float, int, np.number)): children_list[idx] = pybamm.Scalar(child) domains = self.get_children_domains(children_list) diff --git a/pybamm/expression_tree/scalar.py b/pybamm/expression_tree/scalar.py index c9852d5fb2..26bbabfcf0 100644 --- a/pybamm/expression_tree/scalar.py +++ b/pybamm/expression_tree/scalar.py @@ -2,12 +2,12 @@ # Scalar class # from __future__ import annotations -import numbers import numpy as np from typing import Literal import pybamm from pybamm.util import have_optional_dependency +from pybamm.type_definitions import Numeric class Scalar(pybamm.Symbol): @@ -26,7 +26,7 @@ class Scalar(pybamm.Symbol): def __init__( self, - value: float | numbers.Number | np.bool_, + value: Numeric, name: str | None = None, ) -> None: # set default name if not provided diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 8e569b112b..70f4e82db6 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -168,11 +168,9 @@ def simplify_if_constant(symbol: pybamm.Symbol): or (isinstance(result, np.ndarray) and result.ndim == 0) or isinstance(result, np.bool_) ): - if isinstance(result, np.ndarray): # pragma: no cover - # type-narrow for Scalar - new_result = cast(float, result) - return pybamm.Scalar(new_result) - return pybamm.Scalar(result) + # type-narrow for Scalar + new_result = cast(float, result) + return pybamm.Scalar(new_result) elif isinstance(result, np.ndarray) or issparse(result): if result.ndim == 1 or result.shape[1] == 1: return pybamm.Vector(result, domains=symbol.domains) @@ -850,7 +848,7 @@ def is_constant(self): # Default behaviour is False return False - def evaluate_ignoring_errors(self, t: float | None = 0): # none + def evaluate_ignoring_errors(self, t: float | None = 0): """ Evaluates the expression. If a node exists in the tree that cannot be evaluated as a scalar or vector (e.g. Time, Parameter, Variable, StateVector), then None diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 19e530e825..319499a9fc 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -2,12 +2,12 @@ # Unary operator classes and methods # from __future__ import annotations -import numbers import numpy as np from scipy.sparse import csr_matrix, issparse import pybamm from pybamm.util import have_optional_dependency +from pybamm.type_definitions import DomainsType class UnaryOperator(pybamm.Symbol): @@ -29,9 +29,9 @@ def __init__( self, name: str, child: pybamm.Symbol, - domains: dict[str, list[str] | str] | None = None, + domains: DomainsType = None, ): - if isinstance(child, numbers.Number): + if isinstance(child, (float, int, np.number)): child = pybamm.Scalar(child) domains = domains or child.domains diff --git a/pybamm/expression_tree/variable.py b/pybamm/expression_tree/variable.py index f7b8a7b8e9..eb0d90cdb6 100644 --- a/pybamm/expression_tree/variable.py +++ b/pybamm/expression_tree/variable.py @@ -6,7 +6,12 @@ import numbers import pybamm from pybamm.util import have_optional_dependency -from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType +from pybamm.type_definitions import ( + DomainType, + AuxiliaryDomainType, + DomainsType, + Numeric, +) class VariableBase(pybamm.Symbol): @@ -83,7 +88,7 @@ def bounds(self): return self._bounds @bounds.setter - def bounds(self, values: tuple[numbers.Number, numbers.Number]): + def bounds(self, values: tuple[Numeric, Numeric]): if values is None: values = (-np.inf, np.inf) else: diff --git a/pybamm/type_definitions.py b/pybamm/type_definitions.py index fc60ca2f0d..95350a675b 100644 --- a/pybamm/type_definitions.py +++ b/pybamm/type_definitions.py @@ -4,6 +4,8 @@ import numpy as np import pybamm +# numbers.Number should not be used for type hints +Numeric = Union[int, float, np.number] # expression tree ChildValue = Union[float, np.ndarray] From 102c2d61d48abee0df9dd3c6d8e248a649b2d375 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Thu, 8 Feb 2024 16:51:50 +0000 Subject: [PATCH 30/32] Add from __future__ to type_definitions.py --- pybamm/type_definitions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pybamm/type_definitions.py b/pybamm/type_definitions.py index 95350a675b..700de96634 100644 --- a/pybamm/type_definitions.py +++ b/pybamm/type_definitions.py @@ -1,5 +1,7 @@ """Common type definitions for PyBaMM""" +from __future__ import annotations + from typing import Union import numpy as np import pybamm From 38a37137dd364e746cb3312a3c6727902d082a62 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 9 Feb 2024 10:09:36 +0000 Subject: [PATCH 31/32] Add TypeAlias hint --- pybamm/type_definitions.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/pybamm/type_definitions.py b/pybamm/type_definitions.py index 700de96634..4bccd4c3bd 100644 --- a/pybamm/type_definitions.py +++ b/pybamm/type_definitions.py @@ -1,18 +1,20 @@ -"""Common type definitions for PyBaMM""" - +# +# Common type definitions for PyBaMM +# from __future__ import annotations from typing import Union +from typing_extensions import TypeAlias import numpy as np import pybamm # numbers.Number should not be used for type hints -Numeric = Union[int, float, np.number] +Numeric: TypeAlias = Union[int, float, np.number] # expression tree -ChildValue = Union[float, np.ndarray] -ChildSymbol = Union[float, np.ndarray, pybamm.Symbol] +ChildValue: TypeAlias = Union[float, np.ndarray] +ChildSymbol: TypeAlias = Union[float, np.ndarray, pybamm.Symbol] -DomainType = Union[list[str], str, None] -AuxiliaryDomainType = Union[dict[str, str], None] -DomainsType = Union[dict[str, Union[list[str], str]], None] +DomainType: TypeAlias = Union[list[str], str, None] +AuxiliaryDomainType: TypeAlias = Union[dict[str, str], None] +DomainsType: TypeAlias = Union[dict[str, Union[list[str], str]], None] From ea03b4fd7e978084079fef6f222bdb5c4ee97de5 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 9 Feb 2024 10:58:51 +0000 Subject: [PATCH 32/32] Use typing List/Dict --- pybamm/type_definitions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pybamm/type_definitions.py b/pybamm/type_definitions.py index 4bccd4c3bd..c3d0a56faa 100644 --- a/pybamm/type_definitions.py +++ b/pybamm/type_definitions.py @@ -3,7 +3,7 @@ # from __future__ import annotations -from typing import Union +from typing import Union, List, Dict from typing_extensions import TypeAlias import numpy as np import pybamm @@ -15,6 +15,6 @@ ChildValue: TypeAlias = Union[float, np.ndarray] ChildSymbol: TypeAlias = Union[float, np.ndarray, pybamm.Symbol] -DomainType: TypeAlias = Union[list[str], str, None] -AuxiliaryDomainType: TypeAlias = Union[dict[str, str], None] -DomainsType: TypeAlias = Union[dict[str, Union[list[str], str]], None] +DomainType: TypeAlias = Union[List[str], str, None] +AuxiliaryDomainType: TypeAlias = Union[Dict[str, str], None] +DomainsType: TypeAlias = Union[Dict[str, Union[List[str], str]], None]