Skip to content

Commit

Permalink
Merge pull request #1537 from pybamm-team/double-negation
Browse files Browse the repository at this point in the history
double negation case
  • Loading branch information
valentinsulzer committed Jul 14, 2021
2 parents d62c862 + e78f727 commit bae3720
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 3 deletions.
8 changes: 7 additions & 1 deletion pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,7 +1104,10 @@ def simplified_multiplication(left, right):
return (left * r_left) + (left * r_right)

# Negation simplifications
if isinstance(left, pybamm.Negate) and right.is_constant():
if isinstance(left, pybamm.Negate) and isinstance(right, pybamm.Negate):
# Double negation cancels out
return left.orphans[0] * right.orphans[0]
elif isinstance(left, pybamm.Negate) and right.is_constant():
# Simplify (-a) * b to a * (-b) if (-b) is constant
return left.orphans[0] * (-right)
elif isinstance(right, pybamm.Negate) and left.is_constant():
Expand Down Expand Up @@ -1193,6 +1196,9 @@ def simplified_division(left, right):
return l_left * new_right

# Negation simplifications
if isinstance(left, pybamm.Negate) and isinstance(right, pybamm.Negate):
# Double negation cancels out
return left.orphans[0] / right.orphans[0]
elif isinstance(left, pybamm.Negate) and right.is_constant():
# Simplify (-a) / b to a / (-b) if (-b) is constant
return left.orphans[0] / (-right)
Expand Down
Empty file.
11 changes: 9 additions & 2 deletions pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,8 +1116,15 @@ def div(symbol):
# Divergence commutes with Negate operator
if isinstance(symbol, pybamm.Negate):
return -div(symbol.orphans[0])
else:
return Divergence(symbol)
elif isinstance(symbol, (pybamm.Multiplication, pybamm.Division)):
left, right = symbol.orphans
if isinstance(left, pybamm.Negate):
return -div(symbol._binary_new_copy(left.orphans[0], right))
# elif isinstance(right, pybamm.Negate):
# return -div(symbol._binary_new_copy(left, right.orphans[0]))

# Last resort
return Divergence(symbol)


def laplacian(symbol):
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_expression_tree/test_binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def test_binary_simplifications(self):
self.assertEqual((c * -1).id, (-c).id)
self.assertEqual((-1 * c).id, (-c).id)
# multiplication with a negation
self.assertEqual((-c * -f).id, (c * f).id)
self.assertEqual((-c * 4).id, (c * -4).id)
self.assertEqual((4 * -c).id, (-4 * c).id)
# multiplication with broadcasts
Expand Down Expand Up @@ -532,6 +533,7 @@ def test_binary_simplifications(self):
self.assertEqual((c / c).id, pybamm.Scalar(1).id)
self.assertEqual((broad2 / broad2).id, broad1.id)
# division with a negation
self.assertEqual((-c / -f).id, (c / f).id)
self.assertEqual((-c / 4).id, (c / -4).id)
self.assertEqual((4 / -c).id, (-4 / c).id)
# division with broadcasts
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/test_expression_tree/test_unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,12 @@ def test_div(self):
div = pybamm.div(-pybamm.Gradient(a))
self.assertEqual(div.id, (-pybamm.Divergence(pybamm.Gradient(a))).id)

div = pybamm.div(-a * pybamm.Gradient(a))
self.assertEqual(div.id, (-pybamm.Divergence(a * pybamm.Gradient(a))).id)

# div = pybamm.div(a * -pybamm.Gradient(a))
# self.assertEqual(div.id, (-pybamm.Divergence(a * pybamm.Gradient(a))).id)

def test_integral(self):
# space integral
a = pybamm.Symbol("a", domain=["negative electrode"])
Expand Down

0 comments on commit bae3720

Please sign in to comment.