Skip to content

Commit

Permalink
Fix incorrect backward EigenmodeCoefficient gradients and update te…
Browse files Browse the repository at this point in the history
…st case (#1593)

* Fix backward gradients in EigenmodeCoefficient and update JAX adjoint test to check gradients in both forward and backward directions

Flip sign of `amp` based on group velocity direction

Original fix

* Scale kpoint by abs(center_frequency) and apply yapf formatting

* Fix parenthesis in center frequency calculation, fix eigenmode coefficient creation comprehension in test, fix design resolution calculation in test
  • Loading branch information
ianwilliamson authored Sep 30, 2021
1 parent 34ce8d0 commit 9ec7ba7
Show file tree
Hide file tree
Showing 2 changed files with 281 additions and 234 deletions.
70 changes: 47 additions & 23 deletions python/adjoint/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,22 +115,39 @@ def _create_time_profile(self, fwidth_frac=0.1):


class EigenmodeCoefficient(ObjectiveQuantity):
"""A frequency-dependent eigenmode coefficient.
Attributes:
volume: the volume over which the eigenmode coefficient is calculated.
mode: the eigenmode number.
forward: whether the forward or backward mode coefficient is returned as
the result of the evaluation.
kpoint_func: an optional k-point function to use when evaluating the eigenmode
coefficient. When specified, this overrides the effect of `forward`.
kpoint_func_overlap_idx: the index of the mode coefficient to return when
specifying `kpoint_func`. When specified, this overrides the effect of
`forward` and should have a value of either 0 or 1.
"""
def __init__(self,
sim,
volume,
mode,
forward=True,
kpoint_func=None,
kpoint_func_overlap_idx=0,
decimation_factor=0,
**kwargs):
super().__init__(sim)
if kpoint_func_overlap_idx not in [0, 1]:
raise ValueError(
'`kpoint_func_overlap_idx` should be either 0 or 1, but got %d'
% (kpoint_func_overlap_idx, ))
self.volume = volume
self.mode = mode
self.forward = forward
self.kpoint_func = kpoint_func
self.kpoint_func_overlap_idx = kpoint_func_overlap_idx
self.eigenmode_kwargs = kwargs
self._monitor = None
self._normal_direction = None
self._cscale = None
self.decimation_factor = decimation_factor

Expand All @@ -142,28 +159,26 @@ def register_monitors(self, frequencies):
yee_grid=True,
decimation_factor=self.decimation_factor,
)
self._normal_direction = self._monitor.normal_direction
return self._monitor

def place_adjoint_source(self, dJ):
dJ = np.atleast_1d(dJ)
direction_scalar = -1 if self.forward else 1
time_src = self._create_time_profile()
if self.kpoint_func is None:
if self._normal_direction == 0:
k0 = direction_scalar * mp.Vector3(x=1)
elif self._normal_direction == 1:
k0 = direction_scalar * mp.Vector3(y=1)
elif self._normal_direction == 2:
k0 = direction_scalar * mp.Vector3(z=1)
else:
k0 = direction_scalar * self.kpoint_func(time_src.frequency, 1)
if dJ.ndim == 2:
dJ = np.sum(dJ, axis=1)
da_dE = 0.5 * self._cscale # scalar popping out of derivative

time_src = self._create_time_profile()
da_dE = 0.5 * self._cscale
scale = self._adj_src_scale()

if self.kpoint_func:
eig_kpoint = -1 * self.kpoint_func(time_src.frequency, self.mode)
else:
center_frequency = 0.5 * (np.min(self.frequencies) + np.max(
self.frequencies))
direction = mp.Vector3(
*(np.eye(3)[self._monitor.normal_direction] *
np.abs(center_frequency)))
eig_kpoint = -1 * direction if self.forward else direction

if self._frequencies.size == 1:
amp = da_dE * dJ * scale
src = time_src
Expand All @@ -176,12 +191,11 @@ def place_adjoint_source(self, dJ):
self.sim.fields.dt,
)
amp = 1

source = mp.EigenModeSource(
src,
eig_band=self.mode,
direction=mp.NO_DIRECTION,
eig_kpoint=k0,
eig_kpoint=eig_kpoint,
amplitude=amp,
eig_match_freq=True,
size=self.volume.size,
Expand All @@ -191,17 +205,27 @@ def place_adjoint_source(self, dJ):
return [source]

def __call__(self):
direction = mp.NO_DIRECTION if self.kpoint_func else mp.AUTOMATIC
if self.kpoint_func:
kpoint_func = self.kpoint_func
overlap_idx = self.kpoint_func_overlap_idx
else:
center_frequency = 0.5 * (np.min(self.frequencies) + np.max(
self.frequencies))
kpoint = mp.Vector3(*(np.eye(3)[self._monitor.normal_direction] *
np.abs(center_frequency)))
kpoint_func = lambda *not_used: kpoint if self.forward else -1 * kpoint
overlap_idx = 0
ob = self.sim.get_eigenmode_coefficients(
self._monitor,
[self.mode],
direction=direction,
kpoint_func=self.kpoint_func,
direction=mp.NO_DIRECTION,
kpoint_func=kpoint_func,
**self.eigenmode_kwargs,
)
# record eigenmode coefficients for scaling
self._eval = np.squeeze(ob.alpha[:, :, int(not self.forward)])
self._cscale = ob.cscale # pull scaling factor
overlaps = ob.alpha.squeeze(axis=0)
assert overlaps.ndim == 2
self._eval = overlaps[:, overlap_idx]
self._cscale = ob.cscale
return self._eval


Expand Down
Loading

0 comments on commit 9ec7ba7

Please sign in to comment.