Skip to content

Commit

Permalink
Firedrake constants change (#98)
Browse files Browse the repository at this point in the history
Also  Update Hessian mixed convergence test code

---------

Co-authored-by: Connor Ward <c.ward20@imperial.ac.uk>
Co-authored-by: David A. Ham <david.ham@imperial.ac.uk>
  • Loading branch information
3 people authored Jun 14, 2023
1 parent 9e1725a commit 66567ba
Show file tree
Hide file tree
Showing 14 changed files with 130 additions and 127 deletions.
14 changes: 7 additions & 7 deletions dolfin_adjoint_common/blocks/assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepar
output = self.compat.assemble_adjoint_value(dform)
return [[adj_input * output, V]]

if isinstance(c, self.backend.Function):
space = c.function_space()
elif isinstance(c, self.backend.Constant):
if self.compat.isconstant(c):
mesh = self.compat.extract_mesh_from_form(self.form)
space = c._ad_function_space(mesh)
elif isinstance(c, self.backend.Function):
space = c.function_space()
elif isinstance(c, self.compat.MeshType):
c_rep = self.backend.SpatialCoordinate(c_rep)
space = c._ad_function_space()
Expand Down Expand Up @@ -146,14 +146,14 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_v
c1 = block_variable.output
c1_rep = block_variable.saved_output

if isinstance(c1, self.backend.Function):
if self.compat.isconstant(c1):
mesh = self.compat.extract_mesh_from_form(form)
space = c1._ad_function_space(mesh)
elif isinstance(c1, self.backend.Function):
space = c1.function_space()
elif isinstance(c1, self.compat.ExpressionType):
mesh = form.ufl_domain().ufl_cargo()
space = c1._ad_function_space(mesh)
elif isinstance(c1, self.backend.Constant):
mesh = self.compat.extract_mesh_from_form(form)
space = c1._ad_function_space(mesh)
elif isinstance(c1, self.compat.MeshType):
c1_rep = self.backend.SpatialCoordinate(c1)
space = c1._ad_function_space()
Expand Down
4 changes: 2 additions & 2 deletions dolfin_adjoint_common/blocks/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepar
values = numpy.zeros(self.value.shape)
for i, tlm_input in enumerate(tlm_inputs):
values.flat[self.dependency_to_index[i]] = tlm_input
elif isinstance(values, self.backend.Constant):
elif self.compat.isconstant(values):
values = values.values()
return constant_from_values(block_variable.output, values)

Expand All @@ -78,6 +78,6 @@ def recompute_component(self, inputs, block_variable, idx, prepared):
for i, inp in enumerate(inputs):
self.value[self.dependency_to_index[i]] = inp
values = self.value
elif isinstance(values, self.backend.Constant):
elif self.compat.isconstant(values):
values = values.values()
return constant_from_values(block_variable.output, values)
2 changes: 1 addition & 1 deletion dolfin_adjoint_common/blocks/dirichlet_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepar
adj_inputs = adj_inputs[0]
adj_output = None
for adj_input in adj_inputs:
if isinstance(c, self.backend.Constant):
if self.compat.isconstant(c):
adj_value = self.backend.Function(self.parent_space)
adj_input.apply(adj_value.vector())
if self.function_space != self.parent_space:
Expand Down
25 changes: 18 additions & 7 deletions dolfin_adjoint_common/blocks/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ def __init__(self, func, other, ad_block_tag=None):
super().__init__(ad_block_tag=ad_block_tag)
self.other = None
self.expr = None
if isinstance(other, float) or isinstance(other, int):
other = AdjFloat(other)
if isinstance(other, OverloadedType):
self.add_dependency(other, no_duplicates=True)
elif isinstance(other, float) or isinstance(other, int):
other = AdjFloat(other)
self.add_dependency(other, no_duplicates=True)
elif not (isinstance(other, float) or isinstance(other, int)):
# Assume that this is a point-wise evaluated UFL expression (firedrake only)
for op in traverse_unique_terminals(other):
Expand Down Expand Up @@ -43,8 +44,13 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
prepared=None):
if self.expr is None:
if isinstance(block_variable.output, AdjFloat):
return adj_inputs[0].sum()
elif isinstance(block_variable.output, self.backend.Constant):
try:
# Adjoint of a broadcast is just a sum
return adj_inputs[0].sum()
except AttributeError:
# Catch the case where adj_inputs[0] is just a float
return adj_inputs[0]
elif self.compat.isconstant(block_variable.output):
R = block_variable.output._ad_function_space(prepared.function_space().mesh())
return self._adj_assign_constant(prepared, R)
else:
Expand All @@ -56,19 +62,24 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
# Linear combination
expr, adj_input_func = prepared
adj_output = self.backend.Function(adj_input_func.function_space())
if not isinstance(block_variable.output, self.backend.Constant):
if not self.compat.isconstant(block_variable.output):
diff_expr = ufl.algorithms.expand_derivatives(
ufl.derivative(expr, block_variable.saved_output, adj_input_func)
)
adj_output.assign(diff_expr)
else:
mesh = adj_output.function_space().mesh()
diff_expr = ufl.algorithms.expand_derivatives(
ufl.derivative(expr, block_variable.saved_output, self.backend.Constant(1.))
ufl.derivative(
expr,
block_variable.saved_output,
self.compat.create_constant(1., domain=mesh)
)
)
adj_output.assign(diff_expr)
return adj_output.vector().inner(adj_input_func.vector())

if isinstance(block_variable.output, self.backend.Constant):
if self.compat.isconstant(block_variable.output):
R = block_variable.output._ad_function_space(adj_output.function_space().mesh())
return self._adj_assign_constant(adj_output, R)
else:
Expand Down
8 changes: 4 additions & 4 deletions dolfin_adjoint_common/blocks/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,11 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepar
c = block_variable.output
c_rep = block_variable.saved_output

if isinstance(c, self.backend.Function):
trial_function = self.backend.TrialFunction(c.function_space())
elif isinstance(c, self.backend.Constant):
if self.compat.isconstant(c):
mesh = self.compat.extract_mesh_from_form(F_form)
trial_function = self.backend.TrialFunction(c._ad_function_space(mesh))
elif isinstance(c, self.backend.Function):
trial_function = self.backend.TrialFunction(c.function_space())
elif isinstance(c, self.compat.ExpressionType):
mesh = F_form.ufl_domain().ufl_cargo()
c_fs = c._ad_function_space(mesh)
Expand Down Expand Up @@ -376,7 +376,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_v
tmp_bc = self.compat.create_bc(c, value=self.compat.extract_subfunction(adj_sol2_bdy, c.function_space()))
return [tmp_bc]

if isinstance(c_rep, self.backend.Constant):
if self.compat.isconstant(c_rep):
mesh = self.compat.extract_mesh_from_form(F_form)
W = c._ad_function_space(mesh)
elif isinstance(c, self.compat.ExpressionType):
Expand Down
28 changes: 28 additions & 0 deletions dolfin_adjoint_common/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,27 @@ def type_cast_function(obj, cls):
return function_from_vector(obj.function_space(), obj.vector(), cls=cls)
compat.type_cast_function = type_cast_function

def create_constant(*args, **kwargs):
"""Initialises a firedrake.Constant object and returns it."""
from firedrake import Constant
return Constant(*args, **kwargs)
compat.create_constant = create_constant

def create_function(*args, **kwargs):
"""Initialises a firedrake.Function object and returns it."""
from firedrake import Function
return Function(*args, **kwargs)
compat.create_function = create_function

def isconstant(expr):
"""Check whether expression is constant type.
In firedrake this is a function in the real space
Ie: `firedrake.Function(FunctionSpace(mesh, "R"))`"""
if isinstance(expr, backend.Constant):
raise ValueError("Firedrake Constant requires a domain to work with pyadjoint")
return isinstance(expr, backend.Function) and expr.ufl_element().family() == "Real"
compat.isconstant = isconstant

else:
compat.Expression = backend.Expression
compat.MatrixType = (backend.cpp.la.Matrix, backend.cpp.la.GenericMatrix)
Expand Down Expand Up @@ -328,10 +343,23 @@ def type_cast_function(obj, cls):
return cls(obj.function_space(), obj._cpp_object)
compat.type_cast_function = type_cast_function

def create_constant(*args, **kwargs):
"""Initialise a fenics_adjoint.Constant object and return it."""
from fenics_adjoint import Constant
# Dolfin constants do not have domains
_ = kwargs.pop("domain", None)
return Constant(*args, **kwargs)
compat.create_constant = create_constant

def create_function(*args, **kwargs):
"""Initialises a fenics_adjoint.Function object and returns it."""
from fenics_adjoint import Function
return Function(*args, **kwargs)
compat.create_function = create_function

def isconstant(expr):
"""Check whether expression is constant type."""
return isinstance(expr, backend.Constant)
compat.isconstant = isconstant

return compat
8 changes: 4 additions & 4 deletions tests/firedrake_adjoint/test_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_assign_tlm_with_constant():
x = SpatialCoordinate(mesh)
f = interpolate(x[0], V)
g = interpolate(sin(x[0]), V)
c = Constant(5.0)
c = Constant(5.0, domain=mesh)

u = Function(V)
u.interpolate(c * f**2)
Expand Down Expand Up @@ -145,8 +145,8 @@ def test_assign_with_constant():

x = SpatialCoordinate(mesh)
f = interpolate(x[0], V)
c = Constant(3.0)
d = Constant(2.0)
c = Constant(3.0, domain=mesh)
d = Constant(2.0, domain=mesh)
u = Function(V)

u.assign(c*f+d**3)
Expand Down Expand Up @@ -199,7 +199,7 @@ def test_assign_constant_scale():
V = VectorFunctionSpace(mesh, "CG", 1)

f = Function(V)
c = Constant(2.0)
c = Constant(2.0, domain=mesh)
x, y = SpatialCoordinate(mesh)
g = interpolate(as_vector([sin(y)+x, cos(x)*y]), V)

Expand Down
Loading

0 comments on commit 66567ba

Please sign in to comment.