Skip to content

Commit

Permalink
Merge pull request #1912 from pybamm-team/remove-anytree
Browse files Browse the repository at this point in the history
Remove anytree
  • Loading branch information
valentinsulzer authored Jan 23, 2022
2 parents b86b0c7 + d8c069a commit dc4ff98
Show file tree
Hide file tree
Showing 14 changed files with 48 additions and 114 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
- Added an option to force install compatible versions of jax and jaxlib if already installed using CLI ([#1881](https://github.com/pybamm-team/PyBaMM/pull/1881))
- Allow pybamm.Solution.save_data() to return a string if filename is None, and added json to_format option ([#1909](https://github.com/pybamm-team/PyBaMM/pull/1909)

## Optimizations

- The `Symbol` nodes no longer subclasses `anytree.NodeMixIn`. This removes some checks that were not really needed ([#1912](https://github.com/pybamm-team/PyBaMM/pull/1912))

## Bug fixes

- Parameters can now be imported from any given path in `Windows` ([#1900](https://github.com/pybamm-team/PyBaMM/pull/1900))
Expand Down
14 changes: 5 additions & 9 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,18 +1043,14 @@ def _process_symbol(self, symbol):
# Return a new copy of the input parameter, but set the expected size
# according to the domain of the input parameter
expected_size = self._get_variable_size(symbol)
new_input_parameter = symbol.new_copy()
new_input_parameter.set_expected_size(expected_size)
new_input_parameter = pybamm.InputParameter(
symbol.name, symbol.domain, expected_size
)
return new_input_parameter

else:
# Backup option: return new copy of the object
try:
return symbol.new_copy()
except NotImplementedError:
raise NotImplementedError(
"Cannot discretise symbol of type '{}'".format(type(symbol))
)
# Backup option: return the object
return symbol

def concatenate(self, *symbols, sparse=False):
if sparse:
Expand Down
20 changes: 5 additions & 15 deletions pybamm/expression_tree/averages.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,7 @@ def x_average(symbol):
)
for domain in symbol.domains.values()
):
new_symbol = symbol.new_copy()
new_symbol.parent = None
return new_symbol
return symbol
# If symbol is a broadcast, reduce by one dimension
if isinstance(
symbol,
Expand Down Expand Up @@ -217,9 +215,7 @@ def z_average(symbol):
)
# If symbol doesn't have a domain, its average value is itself
if symbol.domain == []:
new_symbol = symbol.new_copy()
new_symbol.parent = None
return new_symbol
return symbol
# If symbol is a Broadcast, its average value is its child
elif isinstance(symbol, pybamm.Broadcast):
return symbol.orphans[0]
Expand Down Expand Up @@ -252,9 +248,7 @@ def yz_average(symbol):
)
# If symbol doesn't have a domain, its average value is itself
if symbol.domain == []:
new_symbol = symbol.new_copy()
new_symbol.parent = None
return new_symbol
return symbol
# If symbol is a Broadcast, its average value is its child
elif isinstance(symbol, pybamm.Broadcast):
return symbol.orphans[0]
Expand Down Expand Up @@ -287,9 +281,7 @@ def r_average(symbol):
["negative particle"],
["working particle"],
]:
new_symbol = symbol.new_copy()
new_symbol.parent = None
return new_symbol
return symbol
# If symbol is a secondary broadcast onto "negative electrode" or
# "positive electrode", take the r-average of the child then broadcast back
elif isinstance(symbol, pybamm.SecondaryBroadcast) and symbol.domains[
Expand Down Expand Up @@ -334,9 +326,7 @@ def size_average(symbol, f_a_dist=None):
domain in [["negative particle size"], ["positive particle size"]]
for domain in list(symbol.domains.values())
):
new_symbol = symbol.new_copy()
new_symbol.parent = None
return new_symbol
return symbol

# If symbol is a primary broadcast to "particle size", take the orphan
elif isinstance(symbol, pybamm.PrimaryBroadcast) and symbol.domain in [
Expand Down
8 changes: 4 additions & 4 deletions pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __str__(self):

def _diff(self, variable):
"""See :meth:`pybamm.Symbol._diff()`."""
children_diffs = [child.diff(variable) for child in self.cached_children]
children_diffs = [child.diff(variable) for child in self.children]
if len(children_diffs) == 1:
diff = children_diffs[0]
else:
Expand Down Expand Up @@ -92,7 +92,7 @@ def _concatenation_evaluate(self, children_eval):

def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
"""See :meth:`pybamm.Symbol.evaluate()`."""
children = self.cached_children
children = self.children
if known_evals is not None:
if self.id not in known_evals:
children_eval = [None] * len(children)
Expand Down Expand Up @@ -189,7 +189,7 @@ def __init__(self, *children):

def _concatenation_jac(self, children_jacs):
"""See :meth:`pybamm.Concatenation.concatenation_jac()`."""
children = self.cached_children
children = self.children
if len(children) == 0:
return pybamm.Scalar(0)
else:
Expand Down Expand Up @@ -252,7 +252,7 @@ def __init__(self, children, full_mesh, copy_this=None):

# create disc of domain => slice for each child
self._children_slices = [
self.create_slices(child) for child in self.cached_children
self.create_slices(child) for child in self.children
]
else:
self._full_mesh = copy.copy(copy_this._full_mesh)
Expand Down
21 changes: 7 additions & 14 deletions pybamm/expression_tree/input_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,21 @@ class InputParameter(pybamm.Symbol):
domain : iterable of str, or str
list of domains over which the node is valid (empty list indicates the symbol
is valid over all domains)
expected_size : int
The size of the input parameter expected, defaults to 1 (scalar input)
"""

def __init__(self, name, domain=None):
# Expected shape defaults to 1
self._expected_size = 1
def __init__(self, name, domain=None, expected_size=1):
self._expected_size = expected_size
super().__init__(name, domain=domain)

def create_copy(self):
"""See :meth:`pybamm.Symbol.new_copy()`."""
new_input_parameter = InputParameter(self.name, self.domain)
new_input_parameter._expected_size = self._expected_size
new_input_parameter = InputParameter(
self.name, self.domain, expected_size=self._expected_size
)
return new_input_parameter

def set_expected_size(self, size):
"""Specify the size that the input parameter should be."""
self._expected_size = size

# We also need to update the saved size and shape
self._saved_size = size
self._saved_shape = (size, 1)
self._saved_evaluate_for_shape = self._evaluate_for_shape()

def _evaluate_for_shape(self):
"""
Returns the scalar 'NaN' to represent the shape of a parameter.
Expand Down
6 changes: 2 additions & 4 deletions pybamm/expression_tree/operations/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def jac(self, symbol, variable):
return jac

def _jac(self, symbol, variable):
""" See :meth:`Jacobian.jac()`. """
"""See :meth:`Jacobian.jac()`."""

if isinstance(symbol, pybamm.BinaryOperator):
left, right = symbol.children
Expand All @@ -76,9 +76,7 @@ def _jac(self, symbol, variable):
jac = symbol._function_jac(children_jacs)

elif isinstance(symbol, pybamm.Concatenation):
children_jacs = [
self.jac(child, variable) for child in symbol.cached_children
]
children_jacs = [self.jac(child, variable) for child in symbol.children]
if len(children_jacs) == 1:
jac = children_jacs[0]
else:
Expand Down
5 changes: 2 additions & 3 deletions pybamm/expression_tree/operations/replace_symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,5 @@ def _process_symbol(self, symbol):
else:
# Only other option is that the symbol is a leaf (doesn't have children)
# In this case, since we have already ruled out that the symbol is one of
# the symbols that needs to be replaced, we can just return a new copy of
# the symbol
return symbol.new_copy()
# the symbols that needs to be replaced, we can just return the symbol
return symbol
18 changes: 4 additions & 14 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#
# Base Symbol Class for the expression tree
#
import copy
import numbers

import anytree
Expand Down Expand Up @@ -179,7 +178,7 @@ def simplify_if_constant(symbol):
return symbol


class Symbol(anytree.NodeMixin):
class Symbol:
"""
Base node class for the expression tree.
Expand Down Expand Up @@ -210,19 +209,10 @@ def __init__(self, name, children=None, domain=None, auxiliary_domains=None):
if children is None:
children = []

# Store "orphans", which are separate from children as they do not have a
# parent node, so they do not cause tree corruption errors when used again
# in a different part of the tree
self._children = children
# Keep a separate "oprhans" attribute for backwards compatibility
self._orphans = children

for child in children:
# copy child before adding
# this also adds copy.copy(child) to self.children
copy.copy(child).parent = self

# cache children
self.cached_children = super(Symbol, self).children

# Set auxiliary domains
self._domains = {"primary": None}
self.auxiliary_domains = auxiliary_domains
Expand Down Expand Up @@ -250,7 +240,7 @@ def children(self):
Note: it is assumed that children of a node are not modified after initial
creation
"""
return self.cached_children
return self._children

@property
def name(self):
Expand Down
10 changes: 3 additions & 7 deletions pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,11 @@ def __init__(self, child):

def diff(self, variable):
"""See :meth:`pybamm.Symbol.diff()`."""
child = self.child.new_copy()
return sign(child) * child.diff(variable)
return sign(self.child) * self.child.diff(variable)

def _unary_jac(self, child_jac):
"""See :meth:`pybamm.UnaryOperator._unary_jac()`."""
child = self.child.new_copy()
return sign(child) * child_jac
return sign(self.child) * child_jac

def _unary_evaluate(self, child):
"""See :meth:`UnaryOperator._unary_evaluate()`."""
Expand Down Expand Up @@ -1273,9 +1271,7 @@ def boundary_value(symbol, side):

# If symbol doesn't have a domain, its boundary value is itself
if symbol.domain == []:
new_symbol = symbol.new_copy()
new_symbol.parent = None
return new_symbol
return symbol
# If symbol is a primary or full broadcast, reduce by one dimension
if isinstance(symbol, (pybamm.PrimaryBroadcast, pybamm.FullBroadcast)):
return symbol.reduce_one_dimension()
Expand Down
13 changes: 3 additions & 10 deletions pybamm/parameters/parameter_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ def _process_symbol(self, symbol):
):
# Wrap with NotConstant to avoid simplification,
# which would stop symbolic diff from working properly
new_child = pybamm.NotConstant(child.new_copy())
new_child = pybamm.NotConstant(child)
new_children.append(self.process_symbol(new_child))
else:
new_children.append(self.process_symbol(child))
Expand Down Expand Up @@ -766,15 +766,8 @@ def _process_symbol(self, symbol):
return symbol._concatenation_new_copy(new_children)

else:
# Backup option: return new copy of the object
try:
return symbol.new_copy()
except NotImplementedError:
raise NotImplementedError(
"Cannot process parameters for symbol of type '{}'".format(
type(symbol)
)
)
# Backup option: return the object
return symbol

def evaluate(self, symbol):
"""
Expand Down
18 changes: 5 additions & 13 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,23 +480,17 @@ def jacp(*args, **kwargs):
found_t = True
# Dimensional
elif symbol.right.id == (pybamm.t * model.timescale_eval).id:
expr = (
symbol.left.new_copy() / symbol.right.right.new_copy()
)
expr = symbol.left / symbol.right.right
found_t = True
elif symbol.left.id == (pybamm.t * model.timescale_eval).id:
expr = (
symbol.right.new_copy() / symbol.left.right.new_copy()
)
expr = symbol.right / symbol.left.right
found_t = True

# Update the events if the heaviside function depended on t
if found_t:
model.events.append(
pybamm.Event(
str(symbol),
expr.new_copy(),
pybamm.EventType.DISCONTINUITY,
str(symbol), expr, pybamm.EventType.DISCONTINUITY
)
)
elif isinstance(symbol, pybamm.Modulo):
Expand All @@ -507,9 +501,7 @@ def jacp(*args, **kwargs):
found_t = True
# Dimensional
elif symbol.left.id == (pybamm.t * model.timescale_eval).id:
expr = (
symbol.right.new_copy() / symbol.left.right.new_copy()
)
expr = symbol.right / symbol.left.right
found_t = True

# Update the events if the modulo function depended on t
Expand All @@ -523,7 +515,7 @@ def jacp(*args, **kwargs):
model.events.append(
pybamm.Event(
str(symbol),
expr.new_copy() * pybamm.Scalar(i + 1),
expr * pybamm.Scalar(i + 1),
pybamm.EventType.DISCONTINUITY,
)
)
Expand Down
7 changes: 3 additions & 4 deletions tests/unit/test_expression_tree/test_input_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@ def test_input_parameter_init(self):
self.assertEqual(a.evaluate(inputs={"a": 1}), 1)
self.assertEqual(a.evaluate(inputs={"a": 5}), 5)

def test_set_expected_size(self):
a = pybamm.InputParameter("a")
a.set_expected_size(10)
a = pybamm.InputParameter("a", expected_size=10)
self.assertEqual(a._expected_size, 10)
np.testing.assert_array_equal(
a.evaluate(inputs="shape test"), np.nan * np.ones((10, 1))
)
y = np.linspace(0, 1, 10)
np.testing.assert_array_equal(a.evaluate(inputs={"a": y}), y[:, np.newaxis])

with self.assertRaisesRegex(
ValueError,
"Input parameter 'a' was given an object of size '1' but was expecting an "
Expand All @@ -34,7 +33,7 @@ def test_evaluate_for_shape(self):
self.assertTrue(np.isnan(a.evaluate_for_shape()))
self.assertEqual(a.shape, ())

a.set_expected_size(10)
a = pybamm.InputParameter("a", expected_size=10)
self.assertEqual(a.shape, (10, 1))
np.testing.assert_equal(a.evaluate_for_shape(), np.nan * np.ones((10, 1)))
self.assertEqual(a.evaluate_for_shape().shape, (10, 1))
Expand Down
Loading

0 comments on commit dc4ff98

Please sign in to comment.