Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alternative approach for stop_when_dft_decayed #1847

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/adjoint/optimization_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
fcen=None,
df=None,
nf=None,
decay_by=1e-11,
decay_by=7.5e-7,
decimation_factor=0,
minimum_run_time=0,
maximum_run_time=None,
Expand Down
2 changes: 1 addition & 1 deletion python/adjoint/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(
monitors: List[EigenmodeCoefficient],
design_regions: List[DesignRegion],
frequencies: List[float],
dft_threshold: float = 1e-11,
dft_threshold: float = 6e-7,
minimum_run_time: float = 0,
maximum_run_time: float = onp.inf,
until_after_sources: bool = True,
Expand Down
16 changes: 5 additions & 11 deletions python/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4545,7 +4545,7 @@ def _stop(sim):

return _stop

def stop_when_dft_decayed(tol=1e-11, minimum_run_time=0, maximum_run_time=None):
def stop_when_dft_decayed(tol=7.5e-7, minimum_run_time=0, maximum_run_time=None):
"""
Return a `condition` function, suitable for passing to `Simulation.run` as the `until`
or `until_after_sources` parameter, that checks the `Simulation`'s DFT objects every $t$
Expand All @@ -4557,7 +4557,7 @@ def stop_when_dft_decayed(tol=1e-11, minimum_run_time=0, maximum_run_time=None):
"""

# Record data in closure so that we can persistently edit
closure = {'previous_fields':0, 't0':0, 'dt':0, 'maxchange':0}
closure = {'t0':0, 'dt':0, 'maxchange':0}
def _stop(_sim):
if _sim.fields.t == 0:
closure['dt'] = max(1/_sim.fields.dft_maxfreq()/_sim.fields.dt,_sim.fields.max_decimation())
Expand All @@ -4566,17 +4566,11 @@ def _stop(_sim):
elif _sim.fields.t <= closure['dt'] + closure['t0']:
return False
else:
previous_fields = closure['previous_fields']
current_fields = _sim.fields.dft_norm()
change = np.abs(previous_fields-current_fields)
change = _sim.fields.dft_time_fields_norm()
closure['maxchange'] = max(closure['maxchange'],change)

if previous_fields == 0:
closure['previous_fields'] = current_fields
return False

closure['previous_fields'] = current_fields
closure['t0'] = _sim.fields.t
if closure['maxchange'] == 0:
return False
if verbosity.meep > 1:
fmt = "DFT fields decay(t = {0:0.2f}): {1:0.4e}"
print(fmt.format(_sim.meep_time(), np.real(change/closure['maxchange'])))
Expand Down
8 changes: 5 additions & 3 deletions python/tests/test_adjoint_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ def J(mode_mon):
objective_functions=J,
objective_arguments=obj_list,
design_regions=[matgrid_region],
frequencies=frequencies)
frequencies=frequencies,
minimum_run_time=150)

f, dJ_du = opt([design_params])

Expand Down Expand Up @@ -255,7 +256,8 @@ def J(dft_mon):
objective_functions=J,
objective_arguments=obj_list,
design_regions=[matgrid_region],
frequencies=frequencies)
frequencies=frequencies,
minimum_run_time=150)

f, dJ_du = opt([design_params])

Expand Down Expand Up @@ -470,7 +472,7 @@ def test_complex_fields(self):

## compare objective results
print("Ez2 -- adjoint solver: {}, traditional simulation: {}".format(adjsol_obj,Ez2_unperturbed))
self.assertClose(adjsol_obj,Ez2_unperturbed,epsilon=1e-6)
self.assertClose(adjsol_obj,Ez2_unperturbed,epsilon=2e-6)

## compute perturbed |Ez|^2
Ez2_perturbed = forward_simulation_complex_fields(p+dp, frequencies)
Expand Down
40 changes: 26 additions & 14 deletions src/dft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,36 +282,48 @@ void dft_chunk::update_dft(double time) {
}
}

/* Return the L2 norm of the DFTs themselves. This is useful
/* Return the L2 norm of all of the fields used to update DFTs. This is useful
to check whether the simulation is finished (whether all relevant fields have decayed).
(Collective operation.) */
double fields::dft_norm() {
(Collective operation.)
In this case, we are only looking at the *update* (e.g. independent
of the phase term in the DTFT inner product). This is a more conservative approach
that by definition looks at *all* the simulation's frequencies (not just the ones
we care about).
*/

static double sqr(std::complex<realnum> x) { return (x*std::conj(x)).real(); }

double fields::dft_time_fields_norm() {
am_now_working_on(Other);
double sum = 0.0;
for (int i = 0; i < num_chunks; i++)
if (chunks[i]->is_mine()) sum += chunks[i]->dft_norm2();
if (chunks[i]->is_mine()) sum += chunks[i]->dft_time_fields_norm2();
finished_working();
return std::sqrt(sum_to_all(sum));
}

double fields_chunk::dft_norm2() const {
double fields_chunk::dft_time_fields_norm2() const {
double sum = 0.0;
for (dft_chunk *cur = dft_chunks; cur; cur = cur->next_in_chunk)
sum += cur->norm2();
sum += cur->dft_time_fields_norm2();
return sum;
}

static double sqr(std::complex<realnum> x) { return (x*std::conj(x)).real(); }

double dft_chunk::norm2() const {
double dft_chunk::dft_time_fields_norm2() const {
if (!fc->f[c][0]) return 0.0;
int numcmp = fc->f[c][1] ? 2 : 1;
double sum = 0.0;
size_t idx_dft = 0;
const int Nomega = omega.size();
LOOP_OVER_IVECS(fc->gv, is, ie, idx) {
for (int i = 0; i < Nomega; ++i)
sum += sqr(dft[Nomega * idx_dft + i]);
idx_dft++;
if (avg2)
for (int cmp = 0; cmp < numcmp; ++cmp)
sum += sqr(0.25 * (fc->f[c][cmp][idx] + fc->f[c][cmp][idx + avg1] +
fc->f[c][cmp][idx + avg2] + fc->f[c][cmp][idx + (avg1 + avg2)]));
else if (avg1)
for (int cmp = 0; cmp < numcmp; ++cmp)
sum += sqr(0.5 * (fc->f[c][cmp][idx] + fc->f[c][cmp][idx + avg1]));
else
for (int cmp = 0; cmp < numcmp; ++cmp)
sum += sqr(fc->f[c][cmp][idx]);
}
return sum;
}
Expand Down
6 changes: 3 additions & 3 deletions src/meep.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,7 @@ class dft_chunk {
~dft_chunk();

void update_dft(double time);
double norm2() const;
double dft_time_fields_norm2() const;
double maxomega() const;

void scale_dft(std::complex<double> scale);
Expand Down Expand Up @@ -1574,7 +1574,7 @@ class fields_chunk {
void initialize_with_nth_tm(int n, double kz);
// dft.cpp
void update_dfts(double timeE, double timeH, int current_step);
double dft_norm2() const;
double dft_time_fields_norm2() const;
double dft_maxfreq() const;
int max_decimation() const;

Expand Down Expand Up @@ -2004,7 +2004,7 @@ class fields {
dft_chunk *add_dft(const volume_list *where, const std::vector<double> &freq,
bool include_dV = true);
void update_dfts();
double dft_norm();
double dft_time_fields_norm();
double dft_maxfreq() const;
int max_decimation() const;

Expand Down