Skip to content

Commit

Permalink
Merge pull request #108 from bernardopacini/expandingMaskConverter
Browse files Browse the repository at this point in the history
Expanding mask converter
  • Loading branch information
timryanb authored Jul 14, 2022
2 parents c82edf6 + af83d1c commit d80d0cd
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 18 deletions.
87 changes: 70 additions & 17 deletions mphys/mask_converter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import openmdao.api as om


Expand Down Expand Up @@ -48,25 +49,48 @@ def setup(self):
mask = self.options['mask']

self.add_input(input.name, shape=input.shape, tags=input.tags, distributed=distributed)
self.add_output(output.name, shape=output.shape, tags=output.tags, val=self.options['init_output'], distributed=distributed)

if isinstance(output, list):
if len(output) != len(mask):
raise ValueError("Output length and mask length not equal")
for i in range(len(output)):
self.add_output(output[i].name, shape=output[i].shape, tags=output[i].tags, val=self.options['init_output'], distributed=distributed)
else:
self.add_output(output.name, shape=output.shape, tags=output.tags, val=self.options['init_output'], distributed=distributed)

def compute(self, inputs, outputs):
input = self.options['input']
output = self.options['output']
mask = self.options['mask']
outputs[output.name] = inputs[input.name][mask]
output = self.options['output']

if isinstance(output, list):
for i in range(len(output)):
outputs[output[i].name] = inputs[input.name][mask[i]]
else:
outputs[output.name] = inputs[input.name][mask]

def compute_jacvec_product(self, inputs, d_inputs, d_outputs, mode):
input = self.options['input']
output = self.options['output']
mask = self.options['mask']
if mode == 'fwd':
if input.name in d_inputs and output.name in d_outputs:
d_outputs[output.name] += d_inputs[input.name][mask]

if mode == 'rev':
if input.name in d_inputs and output.name in d_outputs:
d_inputs[input.name][mask] += d_outputs[output.name]
if isinstance(output, list):
for i in range(len(output)):
if mode == 'fwd':
if input.name in d_inputs and output[i].name in d_outputs:
d_outputs[output[i].name] += d_inputs[input.name][mask[i]]

if mode == 'rev':
if input.name in d_inputs and output[i].name in d_outputs:
d_inputs[input.name][mask[i]] += d_outputs[output[i].name]
else:
if mode == 'fwd':
if input.name in d_inputs and output.name in d_outputs:
d_outputs[output.name] += d_inputs[input.name][mask]

if mode == 'rev':
if input.name in d_inputs and output.name in d_outputs:
d_inputs[input.name][mask] += d_outputs[output.name]

class UnmaskedConverter(om.ExplicitComponent):
"""
Expand Down Expand Up @@ -103,7 +127,20 @@ def setup(self):
output = self.options['output']
mask = self.options['mask']

self.add_input(input.name, shape=input.shape, tags=input.tags, distributed=distributed)
if isinstance(input, list):
if len(input) != len(mask):
raise ValueError("Input length and mask length not equal")

for i in range(len(input)-1):
for j in range(i+1, len(input)):
if (np.any(np.logical_and(mask[i], mask[j]))):
raise RuntimeWarning("Overlapping masking arrays, values will conflict.")

for i in range(len(input)):
self.add_input(input[i].name, shape=input[i].shape, tags=input[i].tags, distributed=distributed)
else:
self.add_input(input.name, shape=input.shape, tags=input.tags, distributed=distributed)

self.add_output(output.name, shape=output.shape, tags=output.tags, distributed=distributed)

def compute(self, inputs, outputs):
Expand All @@ -112,16 +149,32 @@ def compute(self, inputs, outputs):
mask = self.options['mask']
def_vals = self.options['default_values']
outputs[output.name][:] = def_vals
outputs[output.name][mask] = inputs[input.name]

if isinstance(input, list):
for i in range(len(input)):
outputs[output.name][mask[i]] = inputs[input[i].name]
else:
outputs[output.name][mask] = inputs[input.name]

def compute_jacvec_product(self, inputs, d_inputs, d_outputs, mode):
input = self.options['input']
output = self.options['output']
mask = self.options['mask']
if mode == 'fwd':
if input.name in d_inputs and output.name in d_outputs:
d_outputs[output.name][mask] += d_inputs[input.name]

if mode == 'rev':
if input.name in d_inputs and output.name in d_outputs:
d_inputs[input.name] += d_outputs[output.name][mask]
if isinstance(input, list):
for i in range(len(input)):
if mode == 'fwd':
if input[i].name in d_inputs and output.name in d_outputs:
d_outputs[output.name][mask[i]] += d_inputs[input[i].name]

if mode == 'rev':
if input[i].name in d_inputs and output.name in d_outputs:
d_inputs[input[i].name] += d_outputs[output.name][mask[i]]
else:
if mode == 'fwd':
if input.name in d_inputs and output.name in d_outputs:
d_outputs[output.name][mask] += d_inputs[input.name]

if mode == 'rev':
if input.name in d_inputs and output.name in d_outputs:
d_inputs[input.name] += d_outputs[output.name][mask]
73 changes: 72 additions & 1 deletion tests/unit_tests/test_mask_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from openmdao.utils.assert_utils import assert_near_equal


class TestMaskConverter(unittest.TestCase):
class TestMaskConverterSingle(unittest.TestCase):
N_PROCS = 1 #TODO should be 2 or more but there is a bug in OM currently

def setUp(self):
Expand Down Expand Up @@ -59,5 +59,76 @@ def test_check_partials(self):
assert_near_equal(rel_error.forward, 0.0, tolerance=tol)
assert_near_equal(rel_error.forward_reverse, 0.0, tolerance=tol)

class TestMaskConverterMulti(unittest.TestCase):
N_PROCS = 1 #TODO should be 2 or more but there is a bug in OM currently

def setUp(self):
self.common = CommonMethods()
self.prob = om.Problem()
inputs = self.prob.model.add_subsystem('inputs', om.IndepVarComp())

inputs.add_output('unmasked_input', val=np.ones(10, dtype=float), distributed=True)
inputs.add_output('masked_input_1', val=np.arange(5, dtype=float), distributed=True)
inputs.add_output('masked_input_2', val=np.arange(5, dtype=float), distributed=True)

# Create a mask that masks every other entry of the input array
mask_input = MaskedVariableDescription('unmasked_input', shape=10, tags=['mphys_coordinates'])
mask_output = [
MaskedVariableDescription('masked_output_1', shape=5, tags=['mphys_coordinates']),
MaskedVariableDescription('masked_output_2', shape=5, tags=['mphys_coordinates']),
]
mask = [
np.zeros([10], dtype=bool),
np.zeros([10], dtype=bool),
]
mask[0][0:5] = True
mask[0][5:10] = False
mask[1][0:5] = False
mask[1][5:10] = True
masker = MaskedConverter(input=mask_input, output=mask_output, mask=mask, distributed=True)

self.prob.model.add_subsystem('masker', masker)

unmask_input = [
MaskedVariableDescription('masked_input_1', shape=5, tags=['mphys_coordinates']),
MaskedVariableDescription('masked_input_2', shape=5, tags=['mphys_coordinates']),
]
unmask_output = MaskedVariableDescription('unmasked_output', shape=10, tags=['mphys_coordinates'])
unmasker = UnmaskedConverter(input=unmask_input, output=unmask_output, mask=mask, distributed=True,
default_values=1.0)

self.prob.model.add_subsystem('unmasker', unmasker)

self.prob.model.connect('inputs.unmasked_input', 'masker.unmasked_input')
self.prob.model.connect('inputs.masked_input_1', 'unmasker.masked_input_1')
self.prob.model.connect('inputs.masked_input_2', 'unmasker.masked_input_2')

self.prob.setup(force_alloc_complex=True)

def test_run_model(self):
self.common.test_run_model(self, write_n2=False)

def test_check_partials(self):
partials = self.prob.check_partials(compact_print=True, method='cs')
tol = 1e-9

rel_error = partials['masker'][('masked_output_1', 'unmasked_input')]['rel error']
assert_near_equal(rel_error.reverse, 0.0, tolerance=tol)
assert_near_equal(rel_error.forward, 0.0, tolerance=tol)
assert_near_equal(rel_error.forward_reverse, 0.0, tolerance=tol)
rel_error = partials['masker'][('masked_output_2', 'unmasked_input')]['rel error']
assert_near_equal(rel_error.reverse, 0.0, tolerance=tol)
assert_near_equal(rel_error.forward, 0.0, tolerance=tol)
assert_near_equal(rel_error.forward_reverse, 0.0, tolerance=tol)

rel_error = partials['unmasker'][('unmasked_output', 'masked_input_1')]['rel error']
assert_near_equal(rel_error.reverse, 0.0, tolerance=tol)
assert_near_equal(rel_error.forward, 0.0, tolerance=tol)
assert_near_equal(rel_error.forward_reverse, 0.0, tolerance=tol)
rel_error = partials['unmasker'][('unmasked_output', 'masked_input_2')]['rel error']
assert_near_equal(rel_error.reverse, 0.0, tolerance=tol)
assert_near_equal(rel_error.forward, 0.0, tolerance=tol)
assert_near_equal(rel_error.forward_reverse, 0.0, tolerance=tol)

if __name__ == '__main__':
unittest.main()

0 comments on commit d80d0cd

Please sign in to comment.