Skip to content

Commit

Permalink
Merge pull request #1789 from pybamm-team/issue-1768-concat
Browse files Browse the repository at this point in the history
Issue 1768 concat
  • Loading branch information
valentinsulzer authored Nov 8, 2021
2 parents d0712e2 + 2e37286 commit dd77bec
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 27 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
- Half-cell SPM and SPMe have been implemented ( [#1731](https://github.com/pybamm-team/PyBaMM/pull/1731))
## Bug fixes

- Fixed finite volume discretization in spherical polar coordinates ([#1782](https://github.com/pybamm-team/PyBaMM/pull/1782))
- Fixed `sympy` operators for `Arctan` and `Exponential` ([#1786](https://github.com/pybamm-team/PyBaMM/pull/1786))
- Fixed finite volume discretization in spherical polar coordinates ([#1782](https://github.com/pybamm-team/PyBaMM/pull/1782))

## Breaking changes

- Raise error if `Concatenation` is used directly with `Variable` objects (`concatenation` should be used instead) ([#1789](https://github.com/pybamm-team/PyBaMM/pull/1789))
# [v21.10](https://github.com/pybamm-team/PyBaMM/tree/v21.9) - 2021-10-31

## Features
Expand Down
28 changes: 19 additions & 9 deletions pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ class Concatenation(pybamm.Symbol):
"""

def __init__(self, *children, name=None, check_domain=True, concat_fun=None):
# The second condition checks whether this is the base Concatenation class
# or a subclass of Concatenation
# (ConcatenationVariable, NumpyConcatenation, ...)
if all(isinstance(child, pybamm.Variable) for child in children) and issubclass(
Concatenation, type(self)
):
raise TypeError(
"'ConcatenationVariable' should be used for concatenating 'Variable' "
"objects. We recommend using the 'concatenation' function, which will "
"automatically choose the best form."
)
if name is None:
name = "concatenation"
if check_domain:
Expand All @@ -46,10 +57,8 @@ def __str__(self):
return out

def _diff(self, variable):
""" See :meth:`pybamm.Symbol._diff()`. """
children_diffs = [
child.diff(variable) for child in self.cached_children
]
"""See :meth:`pybamm.Symbol._diff()`."""
children_diffs = [child.diff(variable) for child in self.cached_children]
if len(children_diffs) == 1:
diff = children_diffs[0]
else:
Expand Down Expand Up @@ -411,15 +420,17 @@ def simplified_concatenation(*children):
"""Perform simplifications on a concatenation."""
# remove children that are None
children = list(filter(lambda x: x is not None, children))
# Create Concatenation to easily read domains
concat = Concatenation(*children)
# Simplify concatenation of broadcasts all with the same child to a single
# broadcast across all domains
if len(children) == 0:
raise ValueError("Cannot create empty concatenation")
elif len(children) == 1:
return children[0]
elif all(isinstance(child, pybamm.Variable) for child in children):
return pybamm.ConcatenationVariable(*children)
else:
# Create Concatenation to easily read domains
concat = Concatenation(*children)
if all(
isinstance(child, pybamm.Broadcast)
and child.child.id == children[0].child.id
Expand All @@ -432,9 +443,8 @@ def simplified_concatenation(*children):
return pybamm.FullBroadcast(
unique_child, concat.domain, concat.auxiliary_domains
)
elif all(isinstance(child, pybamm.Variable) for child in children):
return pybamm.ConcatenationVariable(*children)
return concat
else:
return concat


def concatenation(*children):
Expand Down
37 changes: 21 additions & 16 deletions tests/integration/test_models/standard_model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class StandardModelTest(object):
""" Basic processing test for the models. """
"""Basic processing test for the models."""

def __init__(
self,
Expand Down Expand Up @@ -62,8 +62,9 @@ def test_processing_disc(self, disc=None):
# Model should still be well-posed after processing
self.model.check_well_posedness(post_discretisation=True)

def test_solving(self, solver=None, t_eval=None, inputs=None,
calculate_sensitivities=False):
def test_solving(
self, solver=None, t_eval=None, inputs=None, calculate_sensitivities=False
):
# Overwrite solver if given
if solver is not None:
self.solver = solver
Expand All @@ -82,7 +83,9 @@ def test_solving(self, solver=None, t_eval=None, inputs=None,
t_eval = np.linspace(0, 3600 / Crate, 100)

self.solution = self.solver.solve(
self.model, t_eval, inputs=inputs,
self.model,
t_eval,
inputs=inputs,
)

def test_outputs(self):
Expand All @@ -92,8 +95,9 @@ def test_outputs(self):
)
std_out_test.test_all()

def test_sensitivities(self, param_name, param_value,
output_name='Terminal voltage [V]'):
def test_sensitivities(
self, param_name, param_value, output_name="Terminal voltage [V]"
):

self.parameter_values.update({param_name: param_value})
Crate = abs(
Expand All @@ -114,8 +118,7 @@ def test_sensitivities(self, param_name, param_value,
self.solver.atol = 1e-8

self.solution = self.solver.solve(
self.model, t_eval, inputs=inputs,
calculate_sensitivities=True
self.model, t_eval, inputs=inputs, calculate_sensitivities=True
)
output_sens = self.solution[output_name].sensitivities[param_name]

Expand All @@ -124,18 +127,20 @@ def test_sensitivities(self, param_name, param_value,
inputs_plus = {param_name: (param_value + 0.5 * h)}
inputs_neg = {param_name: (param_value - 0.5 * h)}
sol_plus = self.solver.solve(
self.model, t_eval, inputs=inputs_plus,
self.model,
t_eval,
inputs=inputs_plus,
)
output_plus = sol_plus[output_name](t=t_eval)
sol_neg = self.solver.solve(
self.model, t_eval, inputs=inputs_neg
)
sol_neg = self.solver.solve(self.model, t_eval, inputs=inputs_neg)
output_neg = sol_neg[output_name](t=t_eval)
fd = ((np.array(output_plus) - np.array(output_neg)) / h)
fd = (np.array(output_plus) - np.array(output_neg)) / h
fd = fd.transpose().reshape(-1, 1)
np.testing.assert_allclose(
output_sens, fd,
rtol=1e-2, atol=1e-6,
output_sens,
fd,
rtol=1e-2,
atol=1e-6,
)

def test_all(
Expand All @@ -156,7 +161,7 @@ def test_all(


class OptimisationsTest(object):
""" Test that the optimised models give the same result as the original model. """
"""Test that the optimised models give the same result as the original model."""

def __init__(self, model, parameter_values=None, disc=None):
# Set parameter values
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/test_expression_tree/test_concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ def test_base_concatenation(self):
# concatenation of lenght 1
self.assertEqual(pybamm.concatenation(a), a)

a = pybamm.Variable("a", domain="test a")
b = pybamm.Variable("b", domain="test b")
with self.assertRaisesRegex(TypeError, "ConcatenationVariable"):
pybamm.Concatenation(a, b)

def test_concatenation_domains(self):
a = pybamm.Symbol("a", domain=["negative electrode"])
b = pybamm.Symbol("b", domain=["separator", "positive electrode"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def test_domain_concatenation_2D(self):
np.testing.assert_allclose(result, expr.evaluate(y=y))

# check that concatenating a single domain is consistent
expr = disc.process_symbol(pybamm.Concatenation(a))
expr = disc.process_symbol(pybamm.concatenation(a))
evaluator = pybamm.EvaluatorPython(expr)
result = evaluator.evaluate(y=y)
np.testing.assert_allclose(result, expr.evaluate(y=y))
Expand Down

0 comments on commit dd77bec

Please sign in to comment.