Skip to content

Commit

Permalink
#709 fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Jan 31, 2020
1 parent 07d6c34 commit ce574ed
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 80 deletions.
24 changes: 12 additions & 12 deletions pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def _function_diff(self, children, idx):

def arcsinh(child):
" Returns arcsinh function of child. "
return Arcsinh(child)
return pybamm.simplify_if_constant(Arcsinh(child), keep_domains=True)


class Cos(SpecificFunction):
Expand All @@ -276,7 +276,7 @@ def _function_diff(self, children, idx):

def cos(child):
" Returns cosine function of child. "
return Cos(child)
return pybamm.simplify_if_constant(Cos(child), keep_domains=True)


class Cosh(SpecificFunction):
Expand All @@ -292,7 +292,7 @@ def _function_diff(self, children, idx):

def cosh(child):
" Returns hyperbolic cosine function of child. "
return Cosh(child)
return pybamm.simplify_if_constant(Cosh(child), keep_domains=True)


class Exponential(SpecificFunction):
Expand All @@ -308,7 +308,7 @@ def _function_diff(self, children, idx):

def exp(child):
" Returns exponential function of child. "
return Exponential(child)
return pybamm.simplify_if_constant(Exponential(child), keep_domains=True)


class Log(SpecificFunction):
Expand All @@ -330,7 +330,7 @@ def _function_diff(self, children, idx):
def log(child, base="e"):
" Returns logarithmic function of child (any base, default 'e'). "
if base == "e":
return Log(child)
return pybamm.simplify_if_constant(Log(child), keep_domains=True)
else:
return Log(child) / np.log(base)

Expand All @@ -342,17 +342,17 @@ def log10(child):

def max(child):
" Returns max function of child. "
return Function(np.max, child)
return pybamm.simplify_if_constant(Function(np.max, child), keep_domains=True)


def min(child):
" Returns min function of child. "
return Function(np.min, child)
return pybamm.simplify_if_constant(Function(np.min, child), keep_domains=True)


def sech(child):
" Returns hyperbolic sec function of child. "
return 1 / Cosh(child)
return pybamm.simplify_if_constant(1 / Cosh(child), keep_domains=True)


class Sin(SpecificFunction):
Expand All @@ -368,7 +368,7 @@ def _function_diff(self, children, idx):

def sin(child):
" Returns sine function of child. "
return Sin(child)
return pybamm.simplify_if_constant(Sin(child), keep_domains=True)


class Sinh(SpecificFunction):
Expand All @@ -384,7 +384,7 @@ def _function_diff(self, children, idx):

def sinh(child):
" Returns hyperbolic sine function of child. "
return Sinh(child)
return pybamm.simplify_if_constant(Sinh(child), keep_domains=True)


class Sqrt(SpecificFunction):
Expand All @@ -405,7 +405,7 @@ def _function_diff(self, children, idx):

def sqrt(child):
" Returns square root function of child. "
return Sqrt(child)
return pybamm.simplify_if_constant(Sqrt(child), keep_domains=True)


class Tanh(SpecificFunction):
Expand All @@ -421,4 +421,4 @@ def _function_diff(self, children, idx):

def tanh(child):
" Returns hyperbolic tan function of child. "
return Tanh(child)
return pybamm.simplify_if_constant(Tanh(child), keep_domains=True)
5 changes: 4 additions & 1 deletion pybamm/expression_tree/operations/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np
import numbers
from scipy.sparse import issparse
from scipy.sparse import issparse, csr_matrix


def simplify_if_constant(symbol, keep_domains=False):
Expand All @@ -32,6 +32,9 @@ def simplify_if_constant(symbol, keep_domains=False):
result, domain=domain, auxiliary_domains=auxiliary_domains
)
else:
# Turn matrix of zeros into sparse matrix
if isinstance(result, np.ndarray) and np.all(result == 0):
result = csr_matrix(result)
return pybamm.Matrix(
result, domain=domain, auxiliary_domains=auxiliary_domains
)
Expand Down
100 changes: 56 additions & 44 deletions tests/unit/test_expression_tree/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_multi_var_function_cube(arg1, arg2):
class TestFunction(unittest.TestCase):
def test_number_input(self):
# with numbers
log = pybamm.log(10)
log = pybamm.Function(np.log, 10)
self.assertIsInstance(log.children[0], pybamm.Scalar)
self.assertEqual(log.evaluate(), np.log(10))

Expand Down Expand Up @@ -127,27 +127,29 @@ def test_function_unnamed(self):

class TestSpecificFunctions(unittest.TestCase):
def test_arcsinh(self):
a = pybamm.Scalar(3)
a = pybamm.InputParameter("a")
fun = pybamm.arcsinh(a)
self.assertIsInstance(fun, pybamm.Arcsinh)
self.assertEqual(fun.evaluate(), np.arcsinh(3))
self.assertEqual(fun.evaluate(u={"a": 3}), np.arcsinh(3))
h = 0.0000001
self.assertAlmostEqual(
fun.diff(a).evaluate(),
(pybamm.arcsinh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
fun.diff(a).evaluate(u={"a": 3}),
(pybamm.arcsinh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
/ h,
places=5,
)

def test_cos(self):
a = pybamm.Scalar(3)
a = pybamm.InputParameter("a")
fun = pybamm.cos(a)
self.assertIsInstance(fun, pybamm.Cos)
self.assertEqual(fun.children[0].id, a.id)
self.assertEqual(fun.evaluate(), np.cos(3))
self.assertEqual(fun.evaluate(u={"a": 3}), np.cos(3))
h = 0.0000001
self.assertAlmostEqual(
fun.diff(a).evaluate(),
(pybamm.cos(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
fun.diff(a).evaluate(u={"a": 3}),
(pybamm.cos(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
/ h,
places=5,
)

Expand All @@ -157,110 +159,120 @@ def test_cos(self):
self.assertEqual(fun.id, fun.simplify().id)

def test_cosh(self):
a = pybamm.Scalar(3)
a = pybamm.InputParameter("a")
fun = pybamm.cosh(a)
self.assertIsInstance(fun, pybamm.Cosh)
self.assertEqual(fun.children[0].id, a.id)
self.assertEqual(fun.evaluate(), np.cosh(3))
self.assertEqual(fun.evaluate(u={"a": 3}), np.cosh(3))
h = 0.0000001
self.assertAlmostEqual(
fun.diff(a).evaluate(),
(pybamm.cosh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
fun.diff(a).evaluate(u={"a": 3}),
(pybamm.cosh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
/ h,
places=5,
)

def test_exp(self):
a = pybamm.Scalar(3)
a = pybamm.InputParameter("a")
fun = pybamm.exp(a)
self.assertIsInstance(fun, pybamm.Exponential)
self.assertEqual(fun.children[0].id, a.id)
self.assertEqual(fun.evaluate(), np.exp(3))
self.assertEqual(fun.evaluate(u={"a": 3}), np.exp(3))
h = 0.0000001
self.assertAlmostEqual(
fun.diff(a).evaluate(),
(pybamm.exp(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
fun.diff(a).evaluate(u={"a": 3}),
(pybamm.exp(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
/ h,
places=5,
)

def test_log(self):
a = pybamm.Scalar(3)
a = pybamm.InputParameter("a")
fun = pybamm.log(a)
self.assertEqual(fun.evaluate(), np.log(3))
self.assertEqual(fun.evaluate(u={"a": 3}), np.log(3))
h = 0.0000001
self.assertAlmostEqual(
fun.diff(a).evaluate(),
(pybamm.log(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
fun.diff(a).evaluate(u={"a": 3}),
(pybamm.log(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
/ h,
places=5,
)

# Base 10
fun = pybamm.log10(a)
self.assertEqual(fun.evaluate(), np.log10(3))
self.assertEqual(fun.evaluate(u={"a": 3}), np.log10(3))
h = 0.0000001
self.assertAlmostEqual(
fun.diff(a).evaluate(),
(pybamm.log10(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
fun.diff(a).evaluate(u={"a": 3}),
(pybamm.log10(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
/ h,
places=5,
)

def test_max(self):
a = pybamm.Vector(np.array([1, 2, 3]))
a = pybamm.StateVector(slice(0, 3))
y_test = np.array([1, 2, 3])
fun = pybamm.max(a)
self.assertIsInstance(fun, pybamm.Function)
self.assertEqual(fun.evaluate(), 3)
self.assertEqual(fun.evaluate(y=y_test), 3)

def test_min(self):
a = pybamm.Vector(np.array([1, 2, 3]))
a = pybamm.StateVector(slice(0, 3))
y_test = np.array([1, 2, 3])
fun = pybamm.min(a)
self.assertIsInstance(fun, pybamm.Function)
self.assertEqual(fun.evaluate(), 1)
self.assertEqual(fun.evaluate(y=y_test), 1)

def test_sin(self):
a = pybamm.Scalar(3)
a = pybamm.InputParameter("a")
fun = pybamm.sin(a)
self.assertIsInstance(fun, pybamm.Sin)
self.assertEqual(fun.children[0].id, a.id)
self.assertEqual(fun.evaluate(), np.sin(3))
self.assertEqual(fun.evaluate(u={"a": 3}), np.sin(3))
h = 0.0000001
self.assertAlmostEqual(
fun.diff(a).evaluate(),
(pybamm.sin(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
fun.diff(a).evaluate(u={"a": 3}),
(pybamm.sin(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
/ h,
places=5,
)

def test_sinh(self):
a = pybamm.Scalar(3)
a = pybamm.InputParameter("a")
fun = pybamm.sinh(a)
self.assertIsInstance(fun, pybamm.Sinh)
self.assertEqual(fun.children[0].id, a.id)
self.assertEqual(fun.evaluate(), np.sinh(3))
self.assertEqual(fun.evaluate(u={"a": 3}), np.sinh(3))
h = 0.0000001
self.assertAlmostEqual(
fun.diff(a).evaluate(),
(pybamm.sinh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
fun.diff(a).evaluate(u={"a": 3}),
(pybamm.sinh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
/ h,
places=5,
)

def test_sqrt(self):
a = pybamm.Scalar(3)
a = pybamm.InputParameter("a")
fun = pybamm.sqrt(a)
self.assertIsInstance(fun, pybamm.Sqrt)
self.assertEqual(fun.evaluate(), np.sqrt(3))
self.assertEqual(fun.evaluate(u={"a": 3}), np.sqrt(3))
h = 0.0000001
self.assertAlmostEqual(
fun.diff(a).evaluate(),
(pybamm.sqrt(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
fun.diff(a).evaluate(u={"a": 3}),
(pybamm.sqrt(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
/ h,
places=5,
)

def test_tanh(self):
a = pybamm.Scalar(3)
a = pybamm.InputParameter("a")
fun = pybamm.tanh(a)
self.assertEqual(fun.evaluate(), np.tanh(3))
self.assertEqual(fun.evaluate(u={"a": 3}), np.tanh(3))
h = 0.0000001
self.assertAlmostEqual(
fun.diff(a).evaluate(),
(pybamm.tanh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
fun.diff(a).evaluate(u={"a": 3}),
(pybamm.tanh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
/ h,
places=5,
)

Expand Down
7 changes: 0 additions & 7 deletions tests/unit/test_expression_tree/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,6 @@ def test_matrix_operations(self):
(self.mat @ self.vect).evaluate(), np.array([[5], [2], [3]])
)

def test_matrix_modification(self):
exp = self.mat @ self.mat + self.mat
self.A[0, 0] = -1
self.assertTrue(exp.children[1]._entries[0, 0], -1)
self.assertTrue(exp.children[0].children[0]._entries[0, 0], -1)
self.assertTrue(exp.children[0].children[1]._entries[0, 0], -1)


class TestArray(unittest.TestCase):
def test_name(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def test_matrix_divide_simplify(self):
expr3 = (m / a).simplify()
self.assertIsInstance(expr3, pybamm.Matrix)
self.assertEqual(expr3.shape, m.shape)
np.testing.assert_array_equal(expr3.evaluate(), np.zeros((10, 10)))
np.testing.assert_array_equal(expr3.evaluate().toarray(), np.zeros((10, 10)))

def test_domain_concatenation_simplify(self):
# create discretisation
Expand Down
7 changes: 0 additions & 7 deletions tests/unit/test_expression_tree/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,6 @@ def test_vector_operations(self):
(self.vect * self.vect).evaluate(), np.array([[1], [4], [9]])
)

def test_vector_modification(self):
exp = self.vect * self.vect + self.vect
self.x[0] = -1
self.assertTrue(exp.children[1]._entries[0], -1)
self.assertTrue(exp.children[0].children[0]._entries[0], -1)
self.assertTrue(exp.children[0].children[1]._entries[0], -1)

def test_wrong_size_entries(self):
with self.assertRaisesRegex(
ValueError, "Entries must have 1 dimension or be column vector"
Expand Down
1 change: 0 additions & 1 deletion tests/unit/test_parameters/test_geometric_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def test_macroscale_parameters(self):
self.assertEqual(
(L_n_eval + L_s_eval + L_p_eval).evaluate(), L_x_eval.evaluate()
)
self.assertEqual((L_n_eval + L_s_eval + L_p_eval).id, L_x_eval.id)
l_n_eval = parameter_values.process_symbol(l_n)
l_s_eval = parameter_values.process_symbol(l_s)
l_p_eval = parameter_values.process_symbol(l_p)
Expand Down
10 changes: 4 additions & 6 deletions tests/unit/test_parameters/test_parameter_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,25 +279,23 @@ def test_process_function_parameter(self):
"const": 254,
}
)
a = pybamm.Parameter("a")
a = pybamm.InputParameter("a")

# process function
func = pybamm.FunctionParameter("func", a)
processed_func = parameter_values.process_symbol(func)
self.assertEqual(processed_func.evaluate(), 369)
self.assertEqual(processed_func.evaluate(u={"a": 3}), 369)

# process constant function
const = pybamm.FunctionParameter("const", a)
processed_const = parameter_values.process_symbol(const)
self.assertIsInstance(processed_const, pybamm.Multiplication)
self.assertIsInstance(processed_const.left, pybamm.Scalar)
self.assertIsInstance(processed_const.right, pybamm.Scalar)
self.assertIsInstance(processed_const, pybamm.Scalar)
self.assertEqual(processed_const.evaluate(), 254)

# process differentiated function parameter
diff_func = func.diff(a)
processed_diff_func = parameter_values.process_symbol(diff_func)
self.assertEqual(processed_diff_func.evaluate(), 123)
self.assertEqual(processed_diff_func.evaluate(u={"a": 3}), 123)

def test_process_inline_function_parameters(self):
def D(c):
Expand Down
Loading

0 comments on commit ce574ed

Please sign in to comment.