diff --git a/src/simsopt/solve/mpi.py b/src/simsopt/solve/mpi.py index 0f6a6988b..86976b470 100644 --- a/src/simsopt/solve/mpi.py +++ b/src/simsopt/solve/mpi.py @@ -23,6 +23,7 @@ MPI = None from .._core.optimizable import Optimizable +from .._core.util import Struct from ..util.mpi import MpiPartition from .._core.finite_difference import MPIFiniteDifference from ..objectives.least_squares import LeastSquaresProblem @@ -209,8 +210,13 @@ def _f_proc0(x): x0 = np.copy(prob.x) logger.info("Using finite difference method implemented in " "SIMSOPT for evaluating gradient") - result = least_squares(_f_proc0, x0, jac=fd.jac, verbose=2, - **kwargs) + try: + result = least_squares(_f_proc0, x0, jac=fd.jac, verbose=2, + **kwargs) + except: + print("Failure on proc0_world") + result = Struct() + result.x = x0 else: leaders_action = lambda mpi, data: None @@ -230,8 +236,9 @@ def _f_proc0(x): if mpi.proc0_world: x = result.x - objective_file.close() - if save_residuals: + if objective_file is not None: + objective_file.close() + if save_residuals and residuals_file is not None: residuals_file.close() datalog_started = False diff --git a/tests/solve/test_mpi.py b/tests/solve/test_mpi.py index e98574714..d0bfb4b37 100755 --- a/tests/solve/test_mpi.py +++ b/tests/solve/test_mpi.py @@ -9,6 +9,7 @@ MPI = None from simsopt._core.optimizable import Optimizable +from simsopt._core import ObjectiveFailure from simsopt.objectives.least_squares import LeastSquaresProblem if MPI is not None: from simsopt.util.mpi import MpiPartition @@ -87,6 +88,12 @@ def f1(self): return_fn_map = {'f0': f0, 'f1': f1} +class FailingOptimizable(Optimizable): + def residuals(self): + raise ObjectiveFailure("foo") + return self.x - np.array([10, 9, 8, 7]) + + @unittest.skipIf(MPI is None, "Requires mpi4py") class MPISolveTests(unittest.TestCase): @@ -137,3 +144,18 @@ def test_parallel_optimization_with_grad(self): self.assertAlmostEqual(prob.x[0], 1) self.assertAlmostEqual(prob.x[1], 1) + def test_objective_failure_with_mpi(self): + """ + If the objective function fails on the first evaluation, make sure the code does not hang. + """ + with ScratchDir("."): + for ngroups in range(1, 4): + mpi = MpiPartition(ngroups) + + opt = FailingOptimizable(x0=np.array([5, 6, 7, 8.0])) + + prob = LeastSquaresProblem.from_tuples( + [(opt.residuals, 0, 1)] + ) + + least_squares_mpi_solve(prob, mpi, grad=True) \ No newline at end of file