Skip to content

Commit

Permalink
Using np.testing.assert_* to print more useful info when asserts fail
Browse files Browse the repository at this point in the history
  • Loading branch information
landreman committed Jan 13, 2025
1 parent 1f085b5 commit 35553ae
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tests/geo/test_curve_optimizable.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def subtest_curve_length_optimisation(self, rotated):
print(' Final curve length: ', obj.J())
print(' Expected final length: ', 2 * np.pi * x0[0])
print(' objective function: ', prob.objective())
assert abs(obj.J() - 2 * np.pi * x0[0]) < 1e-8
np.testing.assert_allclose(obj.J(), 2 * np.pi * x0[0], rtol=0, atol=1e-8)

def test_curve_first_derivative(self):
for rotated in [True, False]:
Expand Down
2 changes: 1 addition & 1 deletion tests/geo/test_pm_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def test_Bn(self):
Nnorms = np.ravel(np.sqrt(np.sum(s.normal() ** 2, axis=-1)))
Ngrid = nphi * ntheta
Bn_Am = (pm_opt.A_obj.dot(pm_opt.m) - pm_opt.b_obj) * np.sqrt(Ngrid / Nnorms)
assert np.allclose(Bn_Am.reshape(nphi, ntheta), np.sum((bs.B() + b_dipole.B()).reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2))
np.testing.assert_allclose(Bn_Am.reshape(nphi, ntheta), np.sum((bs.B() + b_dipole.B()).reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2), atol=1e-15)

# check <Bn>
B_opt = np.mean(np.abs(pm_opt.A_obj.dot(pm_opt.m) - pm_opt.b_obj) * np.sqrt(Ngrid / Nnorms))
Expand Down
2 changes: 1 addition & 1 deletion tests/geo/test_surface_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def taylor_test1(f, df, x, epsilons=None, direction=None):
dfest = (fpluseps-fminuseps)/(2*eps)
err = np.linalg.norm(dfest - dfx)
print("taylor test1: ", err, err/err_old)
assert err < 1e-9 or err < 0.3 * err_old
np.testing.assert_array_less(err, max(1e-9, 0.3 * err_old))
err_old = err
print("###################################################################")

Expand Down

0 comments on commit 35553ae

Please sign in to comment.