Skip to content

Commit

Permalink
#609 add broadcast to edges and pass tests
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Mar 14, 2020
1 parent 1c1d864 commit 59a989c
Show file tree
Hide file tree
Showing 11 changed files with 225 additions and 146 deletions.
11 changes: 11 additions & 0 deletions docs/source/expression_tree/broadcasts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,14 @@ Broadcasting Operators

.. autoclass:: pybamm.SecondaryBroadcast
:members:

.. autoclass:: pybamm.FullBroadcastToEdges
:members:

.. autoclass:: pybamm.PrimaryBroadcastToEdges
:members:

.. autoclass:: pybamm.SecondaryBroadcastToEdges
:members:

.. autofunction:: pybamm.ones_like
4 changes: 2 additions & 2 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def set_variable_slices(self, variables):
for child, mesh in meshes.items():
for domain_mesh in mesh:
submesh = domain_mesh[i]
end += submesh.npts_for_broadcast
end += submesh.npts_for_broadcast_to_nodes
y_slices[child.id].append(slice(start, end))
start = end
else:
Expand All @@ -249,7 +249,7 @@ def _get_variable_size(self, variable):
size = 0
for dom in variable.domain:
for submesh in self.spatial_methods[dom].mesh[dom]:
size += submesh.npts_for_broadcast
size += submesh.npts_for_broadcast_to_nodes
return size

def _preprocess_external_variables(self, model):
Expand Down
59 changes: 51 additions & 8 deletions pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
child,
broadcast_domain,
broadcast_auxiliary_domains=None,
broadcast_type="full",
broadcast_type="full to nodes",
name=None,
):
# Convert child to scalar if it is a number
Expand Down Expand Up @@ -84,7 +84,9 @@ class PrimaryBroadcast(Broadcast):
"""

def __init__(self, child, broadcast_domain, name=None):
super().__init__(child, broadcast_domain, broadcast_type="primary", name=name)
super().__init__(
child, broadcast_domain, broadcast_type="primary to nodes", name=name
)

def check_and_set_domains(
self, child, broadcast_type, broadcast_domain, broadcast_auxiliary_domains
Expand Down Expand Up @@ -127,8 +129,8 @@ def check_and_set_domains(
return domain, auxiliary_domains

def _unary_new_copy(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """
return PrimaryBroadcast(child, self.broadcast_domain)
""" See :meth:`pybamm.UnaryOperator._unary_new_copy()`. """
return self.__class__(child, self.broadcast_domain)

def _evaluate_for_shape(self):
"""
Expand All @@ -140,6 +142,18 @@ def _evaluate_for_shape(self):
return np.outer(child_eval, vec).reshape(-1, 1)


class PrimaryBroadcastToEdges(PrimaryBroadcast):
"A primary broadcast onto the edges of the domain"

def __init__(self, child, broadcast_domain, name=None):
name = name or "broadcast to edges"
super().__init__(child, broadcast_domain, name)
self.broadcast_type = "primary to edges"

def evaluates_on_edges(self):
return True


class SecondaryBroadcast(Broadcast):
"""A node in the expression tree representing a primary broadcasting operator.
Broadcasts in a `secondary` dimension only. That is, makes explicit copies of the
Expand All @@ -162,7 +176,9 @@ class SecondaryBroadcast(Broadcast):
"""

def __init__(self, child, broadcast_domain, name=None):
super().__init__(child, broadcast_domain, broadcast_type="secondary", name=name)
super().__init__(
child, broadcast_domain, broadcast_type="secondary to nodes", name=name
)

def check_and_set_domains(
self, child, broadcast_type, broadcast_domain, broadcast_auxiliary_domains
Expand Down Expand Up @@ -207,7 +223,7 @@ def check_and_set_domains(
return domain, auxiliary_domains

def _unary_new_copy(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """
""" See :meth:`pybamm.UnaryOperator._unary_new_copy()`. """
return SecondaryBroadcast(child, self.broadcast_domain)

def _evaluate_for_shape(self):
Expand All @@ -220,6 +236,18 @@ def _evaluate_for_shape(self):
return np.outer(vec, child_eval).reshape(-1, 1)


class SecondaryBroadcastToEdges(SecondaryBroadcast):
"A secondary broadcast onto the edges of a domain"

def __init__(self, child, broadcast_domain, name=None):
name = name or "broadcast to edges"
super().__init__(child, broadcast_domain, name)
self.broadcast_type = "secondary to edges"

def evaluates_on_edges(self):
return True


class FullBroadcast(Broadcast):
"A class for full broadcasts"

Expand All @@ -230,7 +258,7 @@ def __init__(self, child, broadcast_domain, auxiliary_domains, name=None):
child,
broadcast_domain,
broadcast_auxiliary_domains=auxiliary_domains,
broadcast_type="full",
broadcast_type="full to nodes",
name=name,
)

Expand All @@ -250,7 +278,7 @@ def check_and_set_domains(
return domain, auxiliary_domains

def _unary_new_copy(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """
""" See :meth:`pybamm.UnaryOperator._unary_new_copy()`. """
return FullBroadcast(child, self.broadcast_domain, self.auxiliary_domains)

def _evaluate_for_shape(self):
Expand All @@ -266,6 +294,21 @@ def _evaluate_for_shape(self):
return child_eval * vec


class FullBroadcastToEdges(FullBroadcast):
"""
A full broadcast onto the edges of a domain (edges of primary dimension, nodes of
other dimensions)
"""

def __init__(self, child, broadcast_domain, auxiliary_domains, name=None):
name = name or "broadcast to edges"
super().__init__(child, broadcast_domain, auxiliary_domains, name)
self.broadcast_type = "full to edges"

def evaluates_on_edges(self):
return True


def ones_like(*symbols):
"""
Create a symbol with the same shape as the input symbol and with constant value '1',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _get_neg_pos_coupled_variables(self, variables):
phi_e = phi_s - delta_phi

variables.update(self._get_domain_potential_variables(phi_e))

variables.update({"test": pybamm.x_average(phi_s)})
return variables

def _get_sep_coupled_variables(self, variables):
Expand Down
2 changes: 1 addition & 1 deletion pybamm/spatial_methods/finite_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def build(self, mesh):
# add npts_for_broadcast to mesh domains for this particular discretisation
for dom in mesh.keys():
for i in range(len(mesh[dom])):
mesh[dom][i].npts_for_broadcast = mesh[dom][i].npts
mesh[dom][i].npts_for_broadcast_to_nodes = mesh[dom][i].npts

def spatial_variable(self, symbol):
"""
Expand Down
2 changes: 1 addition & 1 deletion pybamm/spatial_methods/scikit_finite_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def build(self, mesh):
# add npts_for_broadcast to mesh domains for this particular discretisation
for dom in mesh.keys():
for i in range(len(mesh[dom])):
mesh[dom][i].npts_for_broadcast = mesh[dom][i].npts
mesh[dom][i].npts_for_broadcast_to_nodes = mesh[dom][i].npts

def spatial_variable(self, symbol):
"""
Expand Down
31 changes: 19 additions & 12 deletions pybamm/spatial_methods/spatial_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def build(self, mesh):
# add npts_for_broadcast to mesh domains for this particular discretisation
for dom in mesh.keys():
for i in range(len(mesh[dom])):
mesh[dom][i].npts_for_broadcast = mesh[dom][i].npts
mesh[dom][i].npts_for_broadcast_to_nodes = mesh[dom][i].npts
self._mesh = mesh

@property
Expand Down Expand Up @@ -81,7 +81,8 @@ def broadcast(self, symbol, domain, auxiliary_domains, broadcast_type):
domain : iterable of strings
The domain to broadcast to
broadcast_type : str
The type of broadcast, either: 'primary' or 'full'
The type of broadcast: 'primary to node', 'primary to edges', 'secondary to
nodes', 'secondary to edges', 'full to nodes' or 'full to edges'
Returns
-------
Expand All @@ -90,14 +91,24 @@ def broadcast(self, symbol, domain, auxiliary_domains, broadcast_type):
"""

primary_domain_size = sum(
self.mesh[dom][0].npts_for_broadcast for dom in domain
self.mesh[dom][0].npts_for_broadcast_to_nodes for dom in domain
)

secondary_domain_size = sum(
self.mesh[dom][0].npts_for_broadcast_to_nodes
for dom in auxiliary_domains.get("secondary", [])
) # returns empty list if auxiliary_domains doesn't have "secondary" key
full_domain_size = sum(
subdom.npts_for_broadcast for dom in domain for subdom in self.mesh[dom]
subdom.npts_for_broadcast_to_nodes
for dom in domain
for subdom in self.mesh[dom]
)
if broadcast_type.endswith("to edges"):
# add one point to each domain for broadcasting to edges
primary_domain_size += 1
secondary_domain_size += 1
full_domain_size += 1

if broadcast_type == "primary":
if broadcast_type.startswith("primary"):
# Make copies of the child stacked on top of each other
sub_vector = np.ones((primary_domain_size, 1))
if symbol.shape_for_testing == ():
Expand All @@ -107,11 +118,7 @@ def broadcast(self, symbol, domain, auxiliary_domains, broadcast_type):
matrix = csr_matrix(kron(eye(symbol.shape_for_testing[0]), sub_vector))
out = pybamm.Matrix(matrix) @ symbol
out.domain = domain
elif broadcast_type == "secondary":
secondary_domain_size = sum(
self.mesh[dom][0].npts_for_broadcast
for dom in auxiliary_domains["secondary"]
)
elif broadcast_type.startswith("secondary"):
kron_size = full_domain_size // primary_domain_size
# Symbol may be on edges so need to calculate size carefully
symbol_primary_size = symbol.shape[0] // kron_size
Expand All @@ -121,7 +128,7 @@ def broadcast(self, symbol, domain, auxiliary_domains, broadcast_type):
# Repeat for secondary points
matrix = csr_matrix(kron(eye(kron_size), sub_matrix))
out = pybamm.Matrix(matrix) @ symbol
elif broadcast_type == "full":
elif broadcast_type.startswith("full"):
out = symbol * pybamm.Vector(np.ones(full_domain_size), domain=domain)

out.auxiliary_domains = auxiliary_domains
Expand Down
Loading

0 comments on commit 59a989c

Please sign in to comment.