From 35553ae44d38e99207ac54ade4011100fd2d4f42 Mon Sep 17 00:00:00 2001 From: Matt Landreman Date: Mon, 13 Jan 2025 12:30:01 -0500 Subject: [PATCH] Using np.testing.assert_* to print more useful info when asserts fail --- tests/geo/test_curve_optimizable.py | 2 +- tests/geo/test_pm_grid.py | 2 +- tests/geo/test_surface_objectives.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/geo/test_curve_optimizable.py b/tests/geo/test_curve_optimizable.py index fc8eb4606..4b2da0e22 100644 --- a/tests/geo/test_curve_optimizable.py +++ b/tests/geo/test_curve_optimizable.py @@ -71,7 +71,7 @@ def subtest_curve_length_optimisation(self, rotated): print(' Final curve length: ', obj.J()) print(' Expected final length: ', 2 * np.pi * x0[0]) print(' objective function: ', prob.objective()) - assert abs(obj.J() - 2 * np.pi * x0[0]) < 1e-8 + np.testing.assert_allclose(obj.J(), 2 * np.pi * x0[0], rtol=0, atol=1e-8) def test_curve_first_derivative(self): for rotated in [True, False]: diff --git a/tests/geo/test_pm_grid.py b/tests/geo/test_pm_grid.py index 0a847f64b..01ff60771 100644 --- a/tests/geo/test_pm_grid.py +++ b/tests/geo/test_pm_grid.py @@ -216,7 +216,7 @@ def test_Bn(self): Nnorms = np.ravel(np.sqrt(np.sum(s.normal() ** 2, axis=-1))) Ngrid = nphi * ntheta Bn_Am = (pm_opt.A_obj.dot(pm_opt.m) - pm_opt.b_obj) * np.sqrt(Ngrid / Nnorms) - assert np.allclose(Bn_Am.reshape(nphi, ntheta), np.sum((bs.B() + b_dipole.B()).reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)) + np.testing.assert_allclose(Bn_Am.reshape(nphi, ntheta), np.sum((bs.B() + b_dipole.B()).reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2), atol=1e-15) # check B_opt = np.mean(np.abs(pm_opt.A_obj.dot(pm_opt.m) - pm_opt.b_obj) * np.sqrt(Ngrid / Nnorms)) diff --git a/tests/geo/test_surface_objectives.py b/tests/geo/test_surface_objectives.py index 5abb41c3e..dbd690b26 100644 --- a/tests/geo/test_surface_objectives.py +++ b/tests/geo/test_surface_objectives.py @@ -28,7 +28,7 @@ def taylor_test1(f, df, x, epsilons=None, direction=None): dfest = (fpluseps-fminuseps)/(2*eps) err = np.linalg.norm(dfest - dfx) print("taylor test1: ", err, err/err_old) - assert err < 1e-9 or err < 0.3 * err_old + np.testing.assert_array_less(err, max(1e-9, 0.3 * err_old)) err_old = err print("###################################################################")