From 9ec7ba7cd329638e759c3cb804241c15ec5bfbd4 Mon Sep 17 00:00:00 2001 From: Ian Williamson Date: Thu, 30 Sep 2021 12:56:14 -0700 Subject: [PATCH] Fix incorrect backward `EigenmodeCoefficient` gradients and update test case (#1593) * Fix backward gradients in EigenmodeCoefficient and update JAX adjoint test to check gradients in both forward and backward directions Flip sign of `amp` based on group velocity direction Original fix * Scale kpoint by abs(center_frequency) and apply yapf formatting * Fix parenthesis in center frequency calculation, fix eigenmode coefficient creation comprehension in test, fix design resolution calculation in test --- python/adjoint/objective.py | 70 +++-- python/tests/test_adjoint_jax.py | 445 ++++++++++++++++--------------- 2 files changed, 281 insertions(+), 234 deletions(-) diff --git a/python/adjoint/objective.py b/python/adjoint/objective.py index e1d0159fb..11d78e778 100644 --- a/python/adjoint/objective.py +++ b/python/adjoint/objective.py @@ -115,22 +115,39 @@ def _create_time_profile(self, fwidth_frac=0.1): class EigenmodeCoefficient(ObjectiveQuantity): + """A frequency-dependent eigenmode coefficient. + Attributes: + volume: the volume over which the eigenmode coefficient is calculated. + mode: the eigenmode number. + forward: whether the forward or backward mode coefficient is returned as + the result of the evaluation. + kpoint_func: an optional k-point function to use when evaluating the eigenmode + coefficient. When specified, this overrides the effect of `forward`. + kpoint_func_overlap_idx: the index of the mode coefficient to return when + specifying `kpoint_func`. When specified, this overrides the effect of + `forward` and should have a value of either 0 or 1. + """ def __init__(self, sim, volume, mode, forward=True, kpoint_func=None, + kpoint_func_overlap_idx=0, decimation_factor=0, **kwargs): super().__init__(sim) + if kpoint_func_overlap_idx not in [0, 1]: + raise ValueError( + '`kpoint_func_overlap_idx` should be either 0 or 1, but got %d' + % (kpoint_func_overlap_idx, )) self.volume = volume self.mode = mode self.forward = forward self.kpoint_func = kpoint_func + self.kpoint_func_overlap_idx = kpoint_func_overlap_idx self.eigenmode_kwargs = kwargs self._monitor = None - self._normal_direction = None self._cscale = None self.decimation_factor = decimation_factor @@ -142,28 +159,26 @@ def register_monitors(self, frequencies): yee_grid=True, decimation_factor=self.decimation_factor, ) - self._normal_direction = self._monitor.normal_direction return self._monitor def place_adjoint_source(self, dJ): dJ = np.atleast_1d(dJ) - direction_scalar = -1 if self.forward else 1 - time_src = self._create_time_profile() - if self.kpoint_func is None: - if self._normal_direction == 0: - k0 = direction_scalar * mp.Vector3(x=1) - elif self._normal_direction == 1: - k0 = direction_scalar * mp.Vector3(y=1) - elif self._normal_direction == 2: - k0 = direction_scalar * mp.Vector3(z=1) - else: - k0 = direction_scalar * self.kpoint_func(time_src.frequency, 1) if dJ.ndim == 2: dJ = np.sum(dJ, axis=1) - da_dE = 0.5 * self._cscale # scalar popping out of derivative - + time_src = self._create_time_profile() + da_dE = 0.5 * self._cscale scale = self._adj_src_scale() + if self.kpoint_func: + eig_kpoint = -1 * self.kpoint_func(time_src.frequency, self.mode) + else: + center_frequency = 0.5 * (np.min(self.frequencies) + np.max( + self.frequencies)) + direction = mp.Vector3( + *(np.eye(3)[self._monitor.normal_direction] * + np.abs(center_frequency))) + eig_kpoint = -1 * direction if self.forward else direction + if self._frequencies.size == 1: amp = da_dE * dJ * scale src = time_src @@ -176,12 +191,11 @@ def place_adjoint_source(self, dJ): self.sim.fields.dt, ) amp = 1 - source = mp.EigenModeSource( src, eig_band=self.mode, direction=mp.NO_DIRECTION, - eig_kpoint=k0, + eig_kpoint=eig_kpoint, amplitude=amp, eig_match_freq=True, size=self.volume.size, @@ -191,17 +205,27 @@ def place_adjoint_source(self, dJ): return [source] def __call__(self): - direction = mp.NO_DIRECTION if self.kpoint_func else mp.AUTOMATIC + if self.kpoint_func: + kpoint_func = self.kpoint_func + overlap_idx = self.kpoint_func_overlap_idx + else: + center_frequency = 0.5 * (np.min(self.frequencies) + np.max( + self.frequencies)) + kpoint = mp.Vector3(*(np.eye(3)[self._monitor.normal_direction] * + np.abs(center_frequency))) + kpoint_func = lambda *not_used: kpoint if self.forward else -1 * kpoint + overlap_idx = 0 ob = self.sim.get_eigenmode_coefficients( self._monitor, [self.mode], - direction=direction, - kpoint_func=self.kpoint_func, + direction=mp.NO_DIRECTION, + kpoint_func=kpoint_func, **self.eigenmode_kwargs, ) - # record eigenmode coefficients for scaling - self._eval = np.squeeze(ob.alpha[:, :, int(not self.forward)]) - self._cscale = ob.cscale # pull scaling factor + overlaps = ob.alpha.squeeze(axis=0) + assert overlaps.ndim == 2 + self._eval = overlaps[:, overlap_idx] + self._cscale = ob.cscale return self._eval diff --git a/python/tests/test_adjoint_jax.py b/python/tests/test_adjoint_jax.py index 0cc927b3f..e3bcad587 100644 --- a/python/tests/test_adjoint_jax.py +++ b/python/tests/test_adjoint_jax.py @@ -16,234 +16,257 @@ _FD_STEP = 1e-4 # The tolerance for the adjoint and finite difference gradient comparison -_TOL = 0.1 if mp.is_single_precision() else 0.02 +_TOL = 0.1 if mp.is_single_precision() else 0.025 mp.verbosity(0) def build_straight_wg_simulation( - wg_width=0.5, - wg_padding=1.0, - wg_length=1.0, - pml_width=1.0, - source_to_pml=0.5, - source_to_monitor=0.1, - frequencies=[1 / 1.55], - gaussian_rel_width=0.2, - sim_resolution=20, - design_region_resolution=20, + wg_width=0.5, + wg_padding=1.0, + wg_length=1.0, + pml_width=1.0, + source_to_pml=0.5, + source_to_monitor=0.1, + frequencies=[1 / 1.55], + gaussian_rel_width=0.2, + sim_resolution=20, + design_region_resolution=20, ): - """Builds a simulation of a straight waveguide with a design region segment.""" - design_region_shape = (1.0, wg_width) - - # Simulation domain size - sx = 2 * pml_width + 2 * wg_length + design_region_shape[0] - sy = 2 * pml_width + 2 * wg_padding + max( - wg_width, - design_region_shape[1], - ) - - # Mean / center frequency - fmean = onp.mean(frequencies) - - si = mp.Medium(index=3.4) - sio2 = mp.Medium(index=1.44) - - sources = [ - mp.EigenModeSource( - mp.GaussianSource(frequency=fmean, fwidth=fmean * gaussian_rel_width), - eig_band=1, - direction=mp.NO_DIRECTION, - eig_kpoint=mp.Vector3(1, 0, 0), - size=mp.Vector3(0, wg_width + 2 * wg_padding, 0), - center=[-sx / 2 + pml_width + source_to_pml, 0, 0], + """Builds a simulation of a straight waveguide with a design region segment.""" + design_region_shape = (1.0, wg_width) + + # Simulation domain size + sx = 2 * pml_width + 2 * wg_length + design_region_shape[0] + sy = 2 * pml_width + 2 * wg_padding + max( + wg_width, + design_region_shape[1], ) - ] - - nx = int(design_region_resolution * design_region_shape[0]) - ny = int(design_region_resolution * design_region_shape[1]) - mat_grid = mp.MaterialGrid( - mp.Vector3(nx, ny), - sio2, - si, - grid_type='U_DEFAULT', - ) - - design_regions = [ - mpa.DesignRegion( - mat_grid, - volume=mp.Volume( - center=mp.Vector3(), - size=mp.Vector3( - design_region_shape[0], - design_region_shape[1], - 0, + + # Mean / center frequency + fmean = onp.mean(frequencies) + + si = mp.Medium(index=3.4) + sio2 = mp.Medium(index=1.44) + + sources = [ + mp.EigenModeSource( + mp.GaussianSource(frequency=fmean, + fwidth=fmean * gaussian_rel_width), + eig_band=1, + direction=mp.NO_DIRECTION, + eig_kpoint=mp.Vector3(1, 0, 0), + size=mp.Vector3(0, wg_width + 2 * wg_padding, 0), + center=[-sx / 2 + pml_width + source_to_pml, 0, 0], + ), + mp.EigenModeSource( + mp.GaussianSource(frequency=fmean, + fwidth=fmean * gaussian_rel_width), + eig_band=1, + direction=mp.NO_DIRECTION, + eig_kpoint=mp.Vector3(-1, 0, 0), + size=mp.Vector3(0, wg_width + 2 * wg_padding, 0), + center=[sx / 2 - pml_width - source_to_pml, 0, 0], ), - ), + ] + + nx = int(design_region_resolution * design_region_shape[0]) + 1 + ny = int(design_region_resolution * design_region_shape[1]) + 1 + mat_grid = mp.MaterialGrid( + mp.Vector3(nx, ny), + sio2, + si, + grid_type='U_DEFAULT', ) - ] - - geometry = [ - mp.Block( - center=mp.Vector3(x=-design_region_shape[0] / 2 - wg_length / 2 - pml_width / 2), - material=si, - size=mp.Vector3(wg_length + pml_width, wg_width, 0)), # left wg - mp.Block( - center=mp.Vector3(x=+design_region_shape[0] / 2 + wg_length / 2 + pml_width / 2), - material=si, - size=mp.Vector3(wg_length + pml_width, wg_width, 0)), # right wg - mp.Block( - center=design_regions[0].center, - size=design_regions[0].size, - material=mat_grid), # design region - ] - - simulation = mp.Simulation( - cell_size=mp.Vector3(sx, sy), - boundary_layers=[mp.PML(pml_width)], - geometry=geometry, - sources=sources, - resolution=sim_resolution, - ) - - monitor_centers = [ - mp.Vector3(-sx / 2 + pml_width + source_to_pml + source_to_monitor), - mp.Vector3(sx / 2 - pml_width - source_to_pml - source_to_monitor), - ] - monitor_size = mp.Vector3(y=wg_width + 2 * wg_padding) - - monitors = [ - mpa.EigenmodeCoefficient( - simulation, - mp.Volume(center=center, size=monitor_size), - mode=1, - forward=True, - decimation_factor=1) for center in monitor_centers - ] - return simulation, sources, monitors, design_regions, frequencies - -class UtilsTest(unittest.TestCase): - def setUp(self): - super().setUp() - ( - self.simulation, - self.sources, - self.monitors, - self.design_regions, - self.frequencies, - ) = build_straight_wg_simulation() - - def test_mode_monitor_helpers(self): - mpa.utils.register_monitors(self.monitors, self.frequencies) - self.simulation.run(until=100) - monitor_values = mpa.utils.gather_monitor_values(self.monitors) - self.assertEqual(monitor_values.dtype, onp.complex128) - self.assertEqual(monitor_values.shape, - (len(self.monitors), len(self.frequencies))) - - def test_design_region_monitor_helpers(self): - design_region_monitors = mpa.utils.install_design_region_monitors( - self.simulation, - self.design_regions, - self.frequencies, - ) - self.simulation.run(until=100) - design_region_fields = mpa.utils.gather_design_region_fields( - self.simulation, - design_region_monitors, - self.frequencies, + design_regions = [ + mpa.DesignRegion( + mat_grid, + volume=mp.Volume( + center=mp.Vector3(), + size=mp.Vector3( + design_region_shape[0], + design_region_shape[1], + 0, + ), + ), + ) + ] + + geometry = [ + mp.Block(center=mp.Vector3(x=-design_region_shape[0] / 2 - + wg_length / 2 - pml_width / 2), + material=si, + size=mp.Vector3(wg_length + pml_width, wg_width, + 0)), # left wg + mp.Block(center=mp.Vector3(x=+design_region_shape[0] / 2 + + wg_length / 2 + pml_width / 2), + material=si, + size=mp.Vector3(wg_length + pml_width, wg_width, + 0)), # right wg + mp.Block(center=design_regions[0].center, + size=design_regions[0].size, + material=mat_grid), # design region + ] + + simulation = mp.Simulation( + cell_size=mp.Vector3(sx, sy), + boundary_layers=[mp.PML(pml_width)], + geometry=geometry, + sources=sources, + resolution=sim_resolution, ) - self.assertIsInstance(design_region_fields, list) - self.assertEqual(len(design_region_fields), len(self.design_regions)) - - self.assertIsInstance(design_region_fields[0], list) - self.assertEqual(len(design_region_fields[0]), len(mpa.utils._ADJOINT_FIELD_COMPONENTS)) + monitor_centers = [ + mp.Vector3(-sx / 2 + pml_width + source_to_pml + source_to_monitor), + mp.Vector3(sx / 2 - pml_width - source_to_pml - source_to_monitor), + ] + monitor_size = mp.Vector3(y=wg_width + 2 * wg_padding) - for value in design_region_fields[0]: - self.assertIsInstance(value, onp.ndarray) - self.assertEqual(value.ndim, 4) # dims: freq, x, y, pad - self.assertEqual(value.dtype, onp.complex128) + monitors = [ + mpa.EigenmodeCoefficient(simulation, + mp.Volume(center=center, size=monitor_size), + mode=1, + forward=forward, + decimation_factor=5) + for center in monitor_centers for forward in [True, False] + ] + return simulation, sources, monitors, design_regions, frequencies -class WrapperTest(ApproxComparisonTestCase): +class UtilsTest(unittest.TestCase): + def setUp(self): + super().setUp() + ( + self.simulation, + self.sources, + self.monitors, + self.design_regions, + self.frequencies, + ) = build_straight_wg_simulation() + + def test_mode_monitor_helpers(self): + mpa.utils.register_monitors(self.monitors, self.frequencies) + self.simulation.run(until=100) + monitor_values = mpa.utils.gather_monitor_values(self.monitors) + self.assertEqual(monitor_values.dtype, onp.complex128) + self.assertEqual(monitor_values.shape, + (len(self.monitors), len(self.frequencies))) + + def test_design_region_monitor_helpers(self): + design_region_monitors = mpa.utils.install_design_region_monitors( + self.simulation, + self.design_regions, + self.frequencies, + ) + self.simulation.run(until=100) + design_region_fields = mpa.utils.gather_design_region_fields( + self.simulation, + design_region_monitors, + self.frequencies, + ) + + self.assertIsInstance(design_region_fields, list) + self.assertEqual(len(design_region_fields), len(self.design_regions)) + + self.assertIsInstance(design_region_fields[0], list) + self.assertEqual(len(design_region_fields[0]), + len(mpa.utils._ADJOINT_FIELD_COMPONENTS)) + + for value in design_region_fields[0]: + self.assertIsInstance(value, onp.ndarray) + self.assertEqual(value.ndim, 4) # dims: freq, x, y, pad + self.assertEqual(value.dtype, onp.complex128) - @parameterized.parameterized.expand([ - ('1500_1550bw_01relative_gaussian', onp.linspace(1 / 1.50, 1 / 1.55, 3).tolist(), 0.1, 1.0), - ('1550_1600bw_02relative_gaussian', onp.linspace(1 / 1.55, 1 / 1.60, 3).tolist(), 0.2, 1.0), - ('1500_1600bw_03relative_gaussian', onp.linspace(1 / 1.50, 1 / 1.60, 4).tolist(), 0.3, 1.0), - ]) - def test_wrapper_gradients(self, _, frequencies, gaussian_rel_width, design_variable_fill_value): - """Tests gradient from the JAX-Meep wrapper against finite differences.""" - ( - simulation, - sources, - monitors, - design_regions, - frequencies, - ) = build_straight_wg_simulation(frequencies=frequencies, gaussian_rel_width=gaussian_rel_width) - - wrapped_meep = mpa.MeepJaxWrapper( - simulation, - sources, - monitors, - design_regions, - frequencies, - measurement_interval=50.0, - dft_field_components=(mp.Ez,), - dft_threshold=1e-6, - minimum_run_time=0, - maximum_run_time=onp.inf, - until_after_sources=True - ) - design_shape = tuple(int(i) for i in design_regions[0].design_parameters.grid_size)[:2] - x = onp.ones(design_shape) * design_variable_fill_value - - # Define a loss function - def loss_fn(x): - monitor_values = wrapped_meep([x]) - t = monitor_values[1, :] / monitor_values[0, :] - # Mean transmission vs wavelength - return jnp.mean(jnp.square(jnp.abs(t))) - - value, adjoint_grad = jax.value_and_grad(loss_fn)(x) - - projection = [] - fd_projection = [] - - # Project along 5 random directions in the design parameter space. - for seed in range(5): - # Create dp - random_perturbation_vector = _FD_STEP * jax.random.normal( - jax.random.PRNGKey(seed), - x.shape, - ) - - # Calculate p + dp - x_perturbed = x + random_perturbation_vector - - # Calculate T(p + dp) - value_perturbed = loss_fn(x_perturbed) - - projection.append( - onp.dot( - random_perturbation_vector.ravel(), - adjoint_grad.ravel(), - )) - fd_projection.append(value_perturbed - value) - - projection = onp.stack(projection) - fd_projection = onp.stack(fd_projection) - - # Check that dp . ∇T ~ T(p + dp) - T(p) - self.assertClose( - projection, - fd_projection, - epsilon=_TOL, - ) +class WrapperTest(ApproxComparisonTestCase): + @parameterized.parameterized.expand([ + ('1500_1550bw_01relative_gaussian_port1', + onp.linspace(1 / 1.50, 1 / 1.55, 3).tolist(), 0.1, 1.0, 0), + ('1550_1600bw_02relative_gaussian_port1', + onp.linspace(1 / 1.55, 1 / 1.60, 3).tolist(), 0.2, 1.0, 0), + ('1500_1600bw_03relative_gaussian_port1', + onp.linspace(1 / 1.50, 1 / 1.60, 4).tolist(), 0.3, 1.0, 0), + ('1500_1550bw_01relative_gaussian_port2', + onp.linspace(1 / 1.50, 1 / 1.55, 3).tolist(), 0.1, 1.0, 1), + ('1550_1600bw_02relative_gaussian_port2', + onp.linspace(1 / 1.55, 1 / 1.60, 3).tolist(), 0.2, 1.0, 1), + ('1500_1600bw_03relative_gaussian_port2', + onp.linspace(1 / 1.50, 1 / 1.60, 4).tolist(), 0.3, 1.0, 1), + ]) + def test_wrapper_gradients(self, _, frequencies, gaussian_rel_width, + design_variable_fill_value, excite_port_idx): + """Tests gradient from the JAX-Meep wrapper against finite differences.""" + ( + simulation, + sources, + monitors, + design_regions, + frequencies, + ) = build_straight_wg_simulation(frequencies=frequencies, + gaussian_rel_width=gaussian_rel_width) + + design_shape = tuple( + int(i) for i in design_regions[0].design_parameters.grid_size)[:2] + x = onp.ones(design_shape) * design_variable_fill_value + + # Define a loss function + def loss_fn(x, excite_port_idx=0): + wrapped_meep = mpa.MeepJaxWrapper( + simulation, + [sources[excite_port_idx]], + monitors, + design_regions, + frequencies, + ) + monitor_values = wrapped_meep([x]) + s1p, s1m, s2m, s2p = monitor_values + if excite_port_idx == 0: + t = s2m / s1p + else: + t = s1m / s2p + # Mean transmission vs wavelength + t_mean = jnp.mean(jnp.square(jnp.abs(t))) + return t_mean + + value, adjoint_grad = jax.value_and_grad(loss_fn)( + x, excite_port_idx=excite_port_idx) + + projection = [] + fd_projection = [] + + # Project along 5 random directions in the design parameter space. + for seed in range(5): + # Create dp + random_perturbation_vector = _FD_STEP * jax.random.normal( + jax.random.PRNGKey(seed), + x.shape, + ) + + # Calculate p + dp + x_perturbed = x + random_perturbation_vector + + # Calculate T(p + dp) + value_perturbed = loss_fn(x_perturbed, + excite_port_idx=excite_port_idx) + + projection.append( + onp.dot( + random_perturbation_vector.ravel(), + adjoint_grad.ravel(), + )) + fd_projection.append(value_perturbed - value) + + projection = onp.stack(projection) + fd_projection = onp.stack(fd_projection) + + # Check that dp . ∇T ~ T(p + dp) - T(p) + self.assertClose( + projection, + fd_projection, + epsilon=_TOL, + ) if __name__ == '__main__': - unittest.main() + unittest.main()