diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 5be8e15a1a..2601034b98 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -260,7 +260,7 @@ def _function_diff(self, children, idx): def arcsinh(child): " Returns arcsinh function of child. " - return Arcsinh(child) + return pybamm.simplify_if_constant(Arcsinh(child), keep_domains=True) class Cos(SpecificFunction): @@ -276,7 +276,7 @@ def _function_diff(self, children, idx): def cos(child): " Returns cosine function of child. " - return Cos(child) + return pybamm.simplify_if_constant(Cos(child), keep_domains=True) class Cosh(SpecificFunction): @@ -292,7 +292,7 @@ def _function_diff(self, children, idx): def cosh(child): " Returns hyperbolic cosine function of child. " - return Cosh(child) + return pybamm.simplify_if_constant(Cosh(child), keep_domains=True) class Exponential(SpecificFunction): @@ -308,7 +308,7 @@ def _function_diff(self, children, idx): def exp(child): " Returns exponential function of child. " - return Exponential(child) + return pybamm.simplify_if_constant(Exponential(child), keep_domains=True) class Log(SpecificFunction): @@ -330,7 +330,7 @@ def _function_diff(self, children, idx): def log(child, base="e"): " Returns logarithmic function of child (any base, default 'e'). " if base == "e": - return Log(child) + return pybamm.simplify_if_constant(Log(child), keep_domains=True) else: return Log(child) / np.log(base) @@ -342,17 +342,17 @@ def log10(child): def max(child): " Returns max function of child. " - return Function(np.max, child) + return pybamm.simplify_if_constant(Function(np.max, child), keep_domains=True) def min(child): " Returns min function of child. " - return Function(np.min, child) + return pybamm.simplify_if_constant(Function(np.min, child), keep_domains=True) def sech(child): " Returns hyperbolic sec function of child. " - return 1 / Cosh(child) + return pybamm.simplify_if_constant(1 / Cosh(child), keep_domains=True) class Sin(SpecificFunction): @@ -368,7 +368,7 @@ def _function_diff(self, children, idx): def sin(child): " Returns sine function of child. " - return Sin(child) + return pybamm.simplify_if_constant(Sin(child), keep_domains=True) class Sinh(SpecificFunction): @@ -384,7 +384,7 @@ def _function_diff(self, children, idx): def sinh(child): " Returns hyperbolic sine function of child. " - return Sinh(child) + return pybamm.simplify_if_constant(Sinh(child), keep_domains=True) class Sqrt(SpecificFunction): @@ -405,7 +405,7 @@ def _function_diff(self, children, idx): def sqrt(child): " Returns square root function of child. " - return Sqrt(child) + return pybamm.simplify_if_constant(Sqrt(child), keep_domains=True) class Tanh(SpecificFunction): @@ -421,4 +421,4 @@ def _function_diff(self, children, idx): def tanh(child): " Returns hyperbolic tan function of child. " - return Tanh(child) + return pybamm.simplify_if_constant(Tanh(child), keep_domains=True) diff --git a/pybamm/expression_tree/operations/simplify.py b/pybamm/expression_tree/operations/simplify.py index acfeef195e..6e090db82d 100644 --- a/pybamm/expression_tree/operations/simplify.py +++ b/pybamm/expression_tree/operations/simplify.py @@ -5,7 +5,7 @@ import numpy as np import numbers -from scipy.sparse import issparse +from scipy.sparse import issparse, csr_matrix def simplify_if_constant(symbol, keep_domains=False): @@ -32,6 +32,9 @@ def simplify_if_constant(symbol, keep_domains=False): result, domain=domain, auxiliary_domains=auxiliary_domains ) else: + # Turn matrix of zeros into sparse matrix + if isinstance(result, np.ndarray) and np.all(result == 0): + result = csr_matrix(result) return pybamm.Matrix( result, domain=domain, auxiliary_domains=auxiliary_domains ) diff --git a/tests/unit/test_expression_tree/test_functions.py b/tests/unit/test_expression_tree/test_functions.py index 585fdbc7b1..8180723295 100644 --- a/tests/unit/test_expression_tree/test_functions.py +++ b/tests/unit/test_expression_tree/test_functions.py @@ -23,7 +23,7 @@ def test_multi_var_function_cube(arg1, arg2): class TestFunction(unittest.TestCase): def test_number_input(self): # with numbers - log = pybamm.log(10) + log = pybamm.Function(np.log, 10) self.assertIsInstance(log.children[0], pybamm.Scalar) self.assertEqual(log.evaluate(), np.log(10)) @@ -127,27 +127,29 @@ def test_function_unnamed(self): class TestSpecificFunctions(unittest.TestCase): def test_arcsinh(self): - a = pybamm.Scalar(3) + a = pybamm.InputParameter("a") fun = pybamm.arcsinh(a) self.assertIsInstance(fun, pybamm.Arcsinh) - self.assertEqual(fun.evaluate(), np.arcsinh(3)) + self.assertEqual(fun.evaluate(u={"a": 3}), np.arcsinh(3)) h = 0.0000001 self.assertAlmostEqual( - fun.diff(a).evaluate(), - (pybamm.arcsinh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h, + fun.diff(a).evaluate(u={"a": 3}), + (pybamm.arcsinh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3})) + / h, places=5, ) def test_cos(self): - a = pybamm.Scalar(3) + a = pybamm.InputParameter("a") fun = pybamm.cos(a) self.assertIsInstance(fun, pybamm.Cos) self.assertEqual(fun.children[0].id, a.id) - self.assertEqual(fun.evaluate(), np.cos(3)) + self.assertEqual(fun.evaluate(u={"a": 3}), np.cos(3)) h = 0.0000001 self.assertAlmostEqual( - fun.diff(a).evaluate(), - (pybamm.cos(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h, + fun.diff(a).evaluate(u={"a": 3}), + (pybamm.cos(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3})) + / h, places=5, ) @@ -157,110 +159,120 @@ def test_cos(self): self.assertEqual(fun.id, fun.simplify().id) def test_cosh(self): - a = pybamm.Scalar(3) + a = pybamm.InputParameter("a") fun = pybamm.cosh(a) self.assertIsInstance(fun, pybamm.Cosh) self.assertEqual(fun.children[0].id, a.id) - self.assertEqual(fun.evaluate(), np.cosh(3)) + self.assertEqual(fun.evaluate(u={"a": 3}), np.cosh(3)) h = 0.0000001 self.assertAlmostEqual( - fun.diff(a).evaluate(), - (pybamm.cosh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h, + fun.diff(a).evaluate(u={"a": 3}), + (pybamm.cosh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3})) + / h, places=5, ) def test_exp(self): - a = pybamm.Scalar(3) + a = pybamm.InputParameter("a") fun = pybamm.exp(a) self.assertIsInstance(fun, pybamm.Exponential) self.assertEqual(fun.children[0].id, a.id) - self.assertEqual(fun.evaluate(), np.exp(3)) + self.assertEqual(fun.evaluate(u={"a": 3}), np.exp(3)) h = 0.0000001 self.assertAlmostEqual( - fun.diff(a).evaluate(), - (pybamm.exp(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h, + fun.diff(a).evaluate(u={"a": 3}), + (pybamm.exp(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3})) + / h, places=5, ) def test_log(self): - a = pybamm.Scalar(3) + a = pybamm.InputParameter("a") fun = pybamm.log(a) - self.assertEqual(fun.evaluate(), np.log(3)) + self.assertEqual(fun.evaluate(u={"a": 3}), np.log(3)) h = 0.0000001 self.assertAlmostEqual( - fun.diff(a).evaluate(), - (pybamm.log(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h, + fun.diff(a).evaluate(u={"a": 3}), + (pybamm.log(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3})) + / h, places=5, ) # Base 10 fun = pybamm.log10(a) - self.assertEqual(fun.evaluate(), np.log10(3)) + self.assertEqual(fun.evaluate(u={"a": 3}), np.log10(3)) h = 0.0000001 self.assertAlmostEqual( - fun.diff(a).evaluate(), - (pybamm.log10(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h, + fun.diff(a).evaluate(u={"a": 3}), + (pybamm.log10(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3})) + / h, places=5, ) def test_max(self): - a = pybamm.Vector(np.array([1, 2, 3])) + a = pybamm.StateVector(slice(0, 3)) + y_test = np.array([1, 2, 3]) fun = pybamm.max(a) self.assertIsInstance(fun, pybamm.Function) - self.assertEqual(fun.evaluate(), 3) + self.assertEqual(fun.evaluate(y=y_test), 3) def test_min(self): - a = pybamm.Vector(np.array([1, 2, 3])) + a = pybamm.StateVector(slice(0, 3)) + y_test = np.array([1, 2, 3]) fun = pybamm.min(a) self.assertIsInstance(fun, pybamm.Function) - self.assertEqual(fun.evaluate(), 1) + self.assertEqual(fun.evaluate(y=y_test), 1) def test_sin(self): - a = pybamm.Scalar(3) + a = pybamm.InputParameter("a") fun = pybamm.sin(a) self.assertIsInstance(fun, pybamm.Sin) self.assertEqual(fun.children[0].id, a.id) - self.assertEqual(fun.evaluate(), np.sin(3)) + self.assertEqual(fun.evaluate(u={"a": 3}), np.sin(3)) h = 0.0000001 self.assertAlmostEqual( - fun.diff(a).evaluate(), - (pybamm.sin(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h, + fun.diff(a).evaluate(u={"a": 3}), + (pybamm.sin(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3})) + / h, places=5, ) def test_sinh(self): - a = pybamm.Scalar(3) + a = pybamm.InputParameter("a") fun = pybamm.sinh(a) self.assertIsInstance(fun, pybamm.Sinh) self.assertEqual(fun.children[0].id, a.id) - self.assertEqual(fun.evaluate(), np.sinh(3)) + self.assertEqual(fun.evaluate(u={"a": 3}), np.sinh(3)) h = 0.0000001 self.assertAlmostEqual( - fun.diff(a).evaluate(), - (pybamm.sinh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h, + fun.diff(a).evaluate(u={"a": 3}), + (pybamm.sinh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3})) + / h, places=5, ) def test_sqrt(self): - a = pybamm.Scalar(3) + a = pybamm.InputParameter("a") fun = pybamm.sqrt(a) self.assertIsInstance(fun, pybamm.Sqrt) - self.assertEqual(fun.evaluate(), np.sqrt(3)) + self.assertEqual(fun.evaluate(u={"a": 3}), np.sqrt(3)) h = 0.0000001 self.assertAlmostEqual( - fun.diff(a).evaluate(), - (pybamm.sqrt(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h, + fun.diff(a).evaluate(u={"a": 3}), + (pybamm.sqrt(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3})) + / h, places=5, ) def test_tanh(self): - a = pybamm.Scalar(3) + a = pybamm.InputParameter("a") fun = pybamm.tanh(a) - self.assertEqual(fun.evaluate(), np.tanh(3)) + self.assertEqual(fun.evaluate(u={"a": 3}), np.tanh(3)) h = 0.0000001 self.assertAlmostEqual( - fun.diff(a).evaluate(), - (pybamm.tanh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h, + fun.diff(a).evaluate(u={"a": 3}), + (pybamm.tanh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3})) + / h, places=5, ) diff --git a/tests/unit/test_expression_tree/test_matrix.py b/tests/unit/test_expression_tree/test_matrix.py index 1575c8d19e..b135bd75cd 100644 --- a/tests/unit/test_expression_tree/test_matrix.py +++ b/tests/unit/test_expression_tree/test_matrix.py @@ -29,13 +29,6 @@ def test_matrix_operations(self): (self.mat @ self.vect).evaluate(), np.array([[5], [2], [3]]) ) - def test_matrix_modification(self): - exp = self.mat @ self.mat + self.mat - self.A[0, 0] = -1 - self.assertTrue(exp.children[1]._entries[0, 0], -1) - self.assertTrue(exp.children[0].children[0]._entries[0, 0], -1) - self.assertTrue(exp.children[0].children[1]._entries[0, 0], -1) - class TestArray(unittest.TestCase): def test_name(self): diff --git a/tests/unit/test_expression_tree/test_operations/test_simplify.py b/tests/unit/test_expression_tree/test_operations/test_simplify.py index 9c47da4401..129aef5e98 100644 --- a/tests/unit/test_expression_tree/test_operations/test_simplify.py +++ b/tests/unit/test_expression_tree/test_operations/test_simplify.py @@ -518,7 +518,7 @@ def test_matrix_divide_simplify(self): expr3 = (m / a).simplify() self.assertIsInstance(expr3, pybamm.Matrix) self.assertEqual(expr3.shape, m.shape) - np.testing.assert_array_equal(expr3.evaluate(), np.zeros((10, 10))) + np.testing.assert_array_equal(expr3.evaluate().toarray(), np.zeros((10, 10))) def test_domain_concatenation_simplify(self): # create discretisation diff --git a/tests/unit/test_expression_tree/test_vector.py b/tests/unit/test_expression_tree/test_vector.py index 86e54a0526..924876dd83 100644 --- a/tests/unit/test_expression_tree/test_vector.py +++ b/tests/unit/test_expression_tree/test_vector.py @@ -31,13 +31,6 @@ def test_vector_operations(self): (self.vect * self.vect).evaluate(), np.array([[1], [4], [9]]) ) - def test_vector_modification(self): - exp = self.vect * self.vect + self.vect - self.x[0] = -1 - self.assertTrue(exp.children[1]._entries[0], -1) - self.assertTrue(exp.children[0].children[0]._entries[0], -1) - self.assertTrue(exp.children[0].children[1]._entries[0], -1) - def test_wrong_size_entries(self): with self.assertRaisesRegex( ValueError, "Entries must have 1 dimension or be column vector" diff --git a/tests/unit/test_parameters/test_geometric_parameters.py b/tests/unit/test_parameters/test_geometric_parameters.py index 64b87695c6..2711c76430 100644 --- a/tests/unit/test_parameters/test_geometric_parameters.py +++ b/tests/unit/test_parameters/test_geometric_parameters.py @@ -31,7 +31,6 @@ def test_macroscale_parameters(self): self.assertEqual( (L_n_eval + L_s_eval + L_p_eval).evaluate(), L_x_eval.evaluate() ) - self.assertEqual((L_n_eval + L_s_eval + L_p_eval).id, L_x_eval.id) l_n_eval = parameter_values.process_symbol(l_n) l_s_eval = parameter_values.process_symbol(l_s) l_p_eval = parameter_values.process_symbol(l_p) diff --git a/tests/unit/test_parameters/test_parameter_values.py b/tests/unit/test_parameters/test_parameter_values.py index 5d6d4e9115..0c8750f2cc 100644 --- a/tests/unit/test_parameters/test_parameter_values.py +++ b/tests/unit/test_parameters/test_parameter_values.py @@ -279,25 +279,23 @@ def test_process_function_parameter(self): "const": 254, } ) - a = pybamm.Parameter("a") + a = pybamm.InputParameter("a") # process function func = pybamm.FunctionParameter("func", a) processed_func = parameter_values.process_symbol(func) - self.assertEqual(processed_func.evaluate(), 369) + self.assertEqual(processed_func.evaluate(u={"a": 3}), 369) # process constant function const = pybamm.FunctionParameter("const", a) processed_const = parameter_values.process_symbol(const) - self.assertIsInstance(processed_const, pybamm.Multiplication) - self.assertIsInstance(processed_const.left, pybamm.Scalar) - self.assertIsInstance(processed_const.right, pybamm.Scalar) + self.assertIsInstance(processed_const, pybamm.Scalar) self.assertEqual(processed_const.evaluate(), 254) # process differentiated function parameter diff_func = func.diff(a) processed_diff_func = parameter_values.process_symbol(diff_func) - self.assertEqual(processed_diff_func.evaluate(), 123) + self.assertEqual(processed_diff_func.evaluate(u={"a": 3}), 123) def test_process_inline_function_parameters(self): def D(c): diff --git a/tests/unit/test_spatial_methods/test_finite_volume/test_finite_volume.py b/tests/unit/test_spatial_methods/test_finite_volume/test_finite_volume.py index 96de8693ed..af36bdcfc5 100644 --- a/tests/unit/test_spatial_methods/test_finite_volume/test_finite_volume.py +++ b/tests/unit/test_spatial_methods/test_finite_volume/test_finite_volume.py @@ -917,7 +917,7 @@ def test_discretise_spatial_variable(self): r = 3 * pybamm.SpatialVariable("r", ["negative particle"]) r_disc = disc.process_symbol(r) - self.assertIsInstance(r_disc.children[1], pybamm.Vector) + self.assertIsInstance(r_disc, pybamm.Vector) np.testing.assert_array_equal( r_disc.evaluate(), 3 * disc.mesh["negative particle"][0].nodes[:, np.newaxis],