Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix incorrect backward EigenmodeCoefficient gradients and update test case #1593

Merged
merged 3 commits into from
Sep 30, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 41 additions & 24 deletions python/adjoint/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,22 +115,39 @@ def _create_time_profile(self, fwidth_frac=0.1):


class EigenmodeCoefficient(ObjectiveQuantity):
"""A frequency-dependent eigenmode coefficient.
Attributes:
volume: the volume over which the eigenmode coefficient is calculated.
mode: the eigenmode number.
forward: whether the forward or backward mode coefficient is returned as
the result of the evaluation.
kpoint_func: an optional k-point function to use when evaluating the eigenmode
coefficient. When specified, this overrides the effect of `forward`.
kpoint_func_overlap_idx: the index of the mode coefficient to return when
specifying `kpoint_func`. When specified, this overrides the effect of
`forward` and should have a value of either 0 or 1.
"""
def __init__(self,
sim,
volume,
mode,
forward=True,
kpoint_func=None,
kpoint_func_overlap_idx=0,
decimation_factor=0,
**kwargs):
super().__init__(sim)
if kpoint_func_overlap_idx not in [0, 1]:
raise ValueError(
'`kpoint_func_overlap_idx` should be either 0 or 1, but got %d'
% (kpoint_func_overlap_idx, ))
self.volume = volume
self.mode = mode
self.forward = forward
self.kpoint_func = kpoint_func
self.kpoint_func_overlap_idx = kpoint_func_overlap_idx
self.eigenmode_kwargs = kwargs
self._monitor = None
self._normal_direction = None
self._cscale = None
self.decimation_factor = decimation_factor

Expand All @@ -142,28 +159,22 @@ def register_monitors(self, frequencies):
yee_grid=True,
decimation_factor=self.decimation_factor,
)
self._normal_direction = self._monitor.normal_direction
return self._monitor

def place_adjoint_source(self, dJ):
dJ = np.atleast_1d(dJ)
direction_scalar = -1 if self.forward else 1
time_src = self._create_time_profile()
if self.kpoint_func is None:
if self._normal_direction == 0:
k0 = direction_scalar * mp.Vector3(x=1)
elif self._normal_direction == 1:
k0 = direction_scalar * mp.Vector3(y=1)
elif self._normal_direction == 2:
k0 = direction_scalar * mp.Vector3(z=1)
else:
k0 = direction_scalar * self.kpoint_func(time_src.frequency, 1)
if dJ.ndim == 2:
dJ = np.sum(dJ, axis=1)
da_dE = 0.5 * self._cscale # scalar popping out of derivative

time_src = self._create_time_profile()
da_dE = 0.5 * self._cscale
scale = self._adj_src_scale()

if self.kpoint_func:
eig_kpoint = -1 * self.kpoint_func(time_src.frequency, self.mode)
else:
direction = mp.Vector3(*[float(v == 0) for v in self.volume.size])
eig_kpoint = -1 * direction if self.forward else direction
ianwilliamson marked this conversation as resolved.
Show resolved Hide resolved

if self._frequencies.size == 1:
amp = da_dE * dJ * scale
src = time_src
Expand All @@ -176,12 +187,11 @@ def place_adjoint_source(self, dJ):
self.sim.fields.dt,
)
amp = 1

source = mp.EigenModeSource(
src,
eig_band=self.mode,
direction=mp.NO_DIRECTION,
eig_kpoint=k0,
direction=mp.AUTOMATIC,
eig_kpoint=eig_kpoint,
amplitude=amp,
eig_match_freq=True,
size=self.volume.size,
Expand All @@ -191,17 +201,24 @@ def place_adjoint_source(self, dJ):
return [source]

def __call__(self):
direction = mp.NO_DIRECTION if self.kpoint_func else mp.AUTOMATIC
if self.kpoint_func:
kpoint_func = self.kpoint_func
overlap_idx = self.kpoint_func_overlap_idx
else:
direction = mp.Vector3(*[float(v == 0) for v in self.volume.size])
kpoint_func = lambda *not_used: direction if self.forward else -1 * direction
overlap_idx = 0
ob = self.sim.get_eigenmode_coefficients(
self._monitor,
[self.mode],
direction=direction,
kpoint_func=self.kpoint_func,
direction=mp.AUTOMATIC,
kpoint_func=kpoint_func,
**self.eigenmode_kwargs,
)
# record eigenmode coefficients for scaling
self._eval = np.squeeze(ob.alpha[:, :, int(not self.forward)])
self._cscale = ob.cscale # pull scaling factor
overlaps = ob.alpha.squeeze(axis=0)
assert overlaps.ndim == 2
self._eval = overlaps[:, overlap_idx]
self._cscale = ob.cscale
return self._eval


Expand Down
59 changes: 34 additions & 25 deletions python/tests/test_adjoint_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,15 @@ def build_straight_wg_simulation(
eig_kpoint=mp.Vector3(1, 0, 0),
size=mp.Vector3(0, wg_width + 2 * wg_padding, 0),
center=[-sx / 2 + pml_width + source_to_pml, 0, 0],
)
),
mp.EigenModeSource(
mp.GaussianSource(frequency=fmean, fwidth=fmean * gaussian_rel_width),
eig_band=1,
direction=mp.NO_DIRECTION,
eig_kpoint=mp.Vector3(-1, 0, 0),
size=mp.Vector3(0, wg_width + 2 * wg_padding, 0),
center=[sx / 2 - pml_width - source_to_pml, 0, 0],
),
]

nx = int(design_region_resolution * design_region_shape[0])
Expand Down Expand Up @@ -118,7 +126,7 @@ def build_straight_wg_simulation(
mp.Volume(center=center, size=monitor_size),
mode=1,
forward=True,
decimation_factor=1) for center in monitor_centers
decimation_factor=1) for center in monitor_centers for forward in [True, False]
]
return simulation, sources, monitors, design_regions, frequencies

Expand Down Expand Up @@ -170,11 +178,14 @@ def test_design_region_monitor_helpers(self):
class WrapperTest(ApproxComparisonTestCase):

@parameterized.parameterized.expand([
('1500_1550bw_01relative_gaussian', onp.linspace(1 / 1.50, 1 / 1.55, 3).tolist(), 0.1, 1.0),
('1550_1600bw_02relative_gaussian', onp.linspace(1 / 1.55, 1 / 1.60, 3).tolist(), 0.2, 1.0),
('1500_1600bw_03relative_gaussian', onp.linspace(1 / 1.50, 1 / 1.60, 4).tolist(), 0.3, 1.0),
('1500_1550bw_01relative_gaussian_port1', onp.linspace(1 / 1.50, 1 / 1.55, 3).tolist(), 0.1, 1.0, 0),
('1550_1600bw_02relative_gaussian_port1', onp.linspace(1 / 1.55, 1 / 1.60, 3).tolist(), 0.2, 1.0, 0),
('1500_1600bw_03relative_gaussian_port1', onp.linspace(1 / 1.50, 1 / 1.60, 4).tolist(), 0.3, 1.0, 0),
('1500_1550bw_01relative_gaussian_port2', onp.linspace(1 / 1.50, 1 / 1.55, 3).tolist(), 0.1, 1.0, 1),
('1550_1600bw_02relative_gaussian_port2', onp.linspace(1 / 1.55, 1 / 1.60, 3).tolist(), 0.2, 1.0, 1),
('1500_1600bw_03relative_gaussian_port2', onp.linspace(1 / 1.50, 1 / 1.60, 4).tolist(), 0.3, 1.0, 1),
])
def test_wrapper_gradients(self, _, frequencies, gaussian_rel_width, design_variable_fill_value):
def test_wrapper_gradients(self, _, frequencies, gaussian_rel_width, design_variable_fill_value, excite_port_idx):
"""Tests gradient from the JAX-Meep wrapper against finite differences."""
(
simulation,
Expand All @@ -184,31 +195,29 @@ def test_wrapper_gradients(self, _, frequencies, gaussian_rel_width, design_vari
frequencies,
) = build_straight_wg_simulation(frequencies=frequencies, gaussian_rel_width=gaussian_rel_width)

wrapped_meep = mpa.MeepJaxWrapper(
simulation,
sources,
monitors,
design_regions,
frequencies,
measurement_interval=50.0,
dft_field_components=(mp.Ez,),
dft_threshold=1e-6,
minimum_run_time=0,
maximum_run_time=onp.inf,
until_after_sources=True
)

design_shape = tuple(int(i) for i in design_regions[0].design_parameters.grid_size)[:2]
x = onp.ones(design_shape) * design_variable_fill_value

# Define a loss function
def loss_fn(x):
def loss_fn(x, excite_port_idx=0):
wrapped_meep = mpa.MeepJaxWrapper(
simulation,
[sources[excite_port_idx]],
monitors,
design_regions,
frequencies,
)
monitor_values = wrapped_meep([x])
t = monitor_values[1, :] / monitor_values[0, :]
s1p, s1m, s2m, s2p = monitor_values
if excite_port_idx == 0:
t = s2m / s1p
else:
t = s1m / s2p
# Mean transmission vs wavelength
return jnp.mean(jnp.square(jnp.abs(t)))
t_mean = jnp.mean(jnp.square(jnp.abs(t)))
return t_mean

value, adjoint_grad = jax.value_and_grad(loss_fn)(x)
value, adjoint_grad = jax.value_and_grad(loss_fn)(x, excite_port_idx=excite_port_idx)

projection = []
fd_projection = []
Expand All @@ -225,7 +234,7 @@ def loss_fn(x):
x_perturbed = x + random_perturbation_vector

# Calculate T(p + dp)
value_perturbed = loss_fn(x_perturbed)
value_perturbed = loss_fn(x_perturbed, excite_port_idx=excite_port_idx)

projection.append(
onp.dot(
Expand Down