diff --git a/mphys/mask_converter.py b/mphys/mask_converter.py index 72d3aa38..ab16a9b4 100644 --- a/mphys/mask_converter.py +++ b/mphys/mask_converter.py @@ -1,3 +1,4 @@ +import numpy as np import openmdao.api as om @@ -48,25 +49,48 @@ def setup(self): mask = self.options['mask'] self.add_input(input.name, shape=input.shape, tags=input.tags, distributed=distributed) - self.add_output(output.name, shape=output.shape, tags=output.tags, val=self.options['init_output'], distributed=distributed) + + if isinstance(output, list): + if len(output) != len(mask): + raise ValueError("Output length and mask length not equal") + for i in range(len(output)): + self.add_output(output[i].name, shape=output[i].shape, tags=output[i].tags, val=self.options['init_output'], distributed=distributed) + else: + self.add_output(output.name, shape=output.shape, tags=output.tags, val=self.options['init_output'], distributed=distributed) def compute(self, inputs, outputs): input = self.options['input'] - output = self.options['output'] mask = self.options['mask'] - outputs[output.name] = inputs[input.name][mask] + output = self.options['output'] + + if isinstance(output, list): + for i in range(len(output)): + outputs[output[i].name] = inputs[input.name][mask[i]] + else: + outputs[output.name] = inputs[input.name][mask] def compute_jacvec_product(self, inputs, d_inputs, d_outputs, mode): input = self.options['input'] output = self.options['output'] mask = self.options['mask'] - if mode == 'fwd': - if input.name in d_inputs and output.name in d_outputs: - d_outputs[output.name] += d_inputs[input.name][mask] - if mode == 'rev': - if input.name in d_inputs and output.name in d_outputs: - d_inputs[input.name][mask] += d_outputs[output.name] + if isinstance(output, list): + for i in range(len(output)): + if mode == 'fwd': + if input.name in d_inputs and output[i].name in d_outputs: + d_outputs[output[i].name] += d_inputs[input.name][mask[i]] + + if mode == 'rev': + if input.name in d_inputs and output[i].name in d_outputs: + d_inputs[input.name][mask[i]] += d_outputs[output[i].name] + else: + if mode == 'fwd': + if input.name in d_inputs and output.name in d_outputs: + d_outputs[output.name] += d_inputs[input.name][mask] + + if mode == 'rev': + if input.name in d_inputs and output.name in d_outputs: + d_inputs[input.name][mask] += d_outputs[output.name] class UnmaskedConverter(om.ExplicitComponent): """ @@ -103,7 +127,20 @@ def setup(self): output = self.options['output'] mask = self.options['mask'] - self.add_input(input.name, shape=input.shape, tags=input.tags, distributed=distributed) + if isinstance(input, list): + if len(input) != len(mask): + raise ValueError("Input length and mask length not equal") + + for i in range(len(input)-1): + for j in range(i+1, len(input)): + if (np.any(np.logical_and(mask[i], mask[j]))): + raise RuntimeWarning("Overlapping masking arrays, values will conflict.") + + for i in range(len(input)): + self.add_input(input[i].name, shape=input[i].shape, tags=input[i].tags, distributed=distributed) + else: + self.add_input(input.name, shape=input.shape, tags=input.tags, distributed=distributed) + self.add_output(output.name, shape=output.shape, tags=output.tags, distributed=distributed) def compute(self, inputs, outputs): @@ -112,16 +149,32 @@ def compute(self, inputs, outputs): mask = self.options['mask'] def_vals = self.options['default_values'] outputs[output.name][:] = def_vals - outputs[output.name][mask] = inputs[input.name] + + if isinstance(input, list): + for i in range(len(input)): + outputs[output.name][mask[i]] = inputs[input[i].name] + else: + outputs[output.name][mask] = inputs[input.name] def compute_jacvec_product(self, inputs, d_inputs, d_outputs, mode): input = self.options['input'] output = self.options['output'] mask = self.options['mask'] - if mode == 'fwd': - if input.name in d_inputs and output.name in d_outputs: - d_outputs[output.name][mask] += d_inputs[input.name] - if mode == 'rev': - if input.name in d_inputs and output.name in d_outputs: - d_inputs[input.name] += d_outputs[output.name][mask] + if isinstance(input, list): + for i in range(len(input)): + if mode == 'fwd': + if input[i].name in d_inputs and output.name in d_outputs: + d_outputs[output.name][mask[i]] += d_inputs[input[i].name] + + if mode == 'rev': + if input[i].name in d_inputs and output.name in d_outputs: + d_inputs[input[i].name] += d_outputs[output.name][mask[i]] + else: + if mode == 'fwd': + if input.name in d_inputs and output.name in d_outputs: + d_outputs[output.name][mask] += d_inputs[input.name] + + if mode == 'rev': + if input.name in d_inputs and output.name in d_outputs: + d_inputs[input.name] += d_outputs[output.name][mask] diff --git a/tests/unit_tests/test_mask_converter.py b/tests/unit_tests/test_mask_converter.py index 54abcf71..722ad4e2 100644 --- a/tests/unit_tests/test_mask_converter.py +++ b/tests/unit_tests/test_mask_converter.py @@ -9,7 +9,7 @@ from openmdao.utils.assert_utils import assert_near_equal -class TestMaskConverter(unittest.TestCase): +class TestMaskConverterSingle(unittest.TestCase): N_PROCS = 1 #TODO should be 2 or more but there is a bug in OM currently def setUp(self): @@ -59,5 +59,76 @@ def test_check_partials(self): assert_near_equal(rel_error.forward, 0.0, tolerance=tol) assert_near_equal(rel_error.forward_reverse, 0.0, tolerance=tol) +class TestMaskConverterMulti(unittest.TestCase): + N_PROCS = 1 #TODO should be 2 or more but there is a bug in OM currently + + def setUp(self): + self.common = CommonMethods() + self.prob = om.Problem() + inputs = self.prob.model.add_subsystem('inputs', om.IndepVarComp()) + + inputs.add_output('unmasked_input', val=np.ones(10, dtype=float), distributed=True) + inputs.add_output('masked_input_1', val=np.arange(5, dtype=float), distributed=True) + inputs.add_output('masked_input_2', val=np.arange(5, dtype=float), distributed=True) + + # Create a mask that masks every other entry of the input array + mask_input = MaskedVariableDescription('unmasked_input', shape=10, tags=['mphys_coordinates']) + mask_output = [ + MaskedVariableDescription('masked_output_1', shape=5, tags=['mphys_coordinates']), + MaskedVariableDescription('masked_output_2', shape=5, tags=['mphys_coordinates']), + ] + mask = [ + np.zeros([10], dtype=bool), + np.zeros([10], dtype=bool), + ] + mask[0][0:5] = True + mask[0][5:10] = False + mask[1][0:5] = False + mask[1][5:10] = True + masker = MaskedConverter(input=mask_input, output=mask_output, mask=mask, distributed=True) + + self.prob.model.add_subsystem('masker', masker) + + unmask_input = [ + MaskedVariableDescription('masked_input_1', shape=5, tags=['mphys_coordinates']), + MaskedVariableDescription('masked_input_2', shape=5, tags=['mphys_coordinates']), + ] + unmask_output = MaskedVariableDescription('unmasked_output', shape=10, tags=['mphys_coordinates']) + unmasker = UnmaskedConverter(input=unmask_input, output=unmask_output, mask=mask, distributed=True, + default_values=1.0) + + self.prob.model.add_subsystem('unmasker', unmasker) + + self.prob.model.connect('inputs.unmasked_input', 'masker.unmasked_input') + self.prob.model.connect('inputs.masked_input_1', 'unmasker.masked_input_1') + self.prob.model.connect('inputs.masked_input_2', 'unmasker.masked_input_2') + + self.prob.setup(force_alloc_complex=True) + + def test_run_model(self): + self.common.test_run_model(self, write_n2=False) + + def test_check_partials(self): + partials = self.prob.check_partials(compact_print=True, method='cs') + tol = 1e-9 + + rel_error = partials['masker'][('masked_output_1', 'unmasked_input')]['rel error'] + assert_near_equal(rel_error.reverse, 0.0, tolerance=tol) + assert_near_equal(rel_error.forward, 0.0, tolerance=tol) + assert_near_equal(rel_error.forward_reverse, 0.0, tolerance=tol) + rel_error = partials['masker'][('masked_output_2', 'unmasked_input')]['rel error'] + assert_near_equal(rel_error.reverse, 0.0, tolerance=tol) + assert_near_equal(rel_error.forward, 0.0, tolerance=tol) + assert_near_equal(rel_error.forward_reverse, 0.0, tolerance=tol) + + rel_error = partials['unmasker'][('unmasked_output', 'masked_input_1')]['rel error'] + assert_near_equal(rel_error.reverse, 0.0, tolerance=tol) + assert_near_equal(rel_error.forward, 0.0, tolerance=tol) + assert_near_equal(rel_error.forward_reverse, 0.0, tolerance=tol) + rel_error = partials['unmasker'][('unmasked_output', 'masked_input_2')]['rel error'] + assert_near_equal(rel_error.reverse, 0.0, tolerance=tol) + assert_near_equal(rel_error.forward, 0.0, tolerance=tol) + assert_near_equal(rel_error.forward_reverse, 0.0, tolerance=tol) + if __name__ == '__main__': unittest.main()