Skip to content

Commit

Permalink
avoid crash in alchemlyb.visualisation.plot_convergence with NaN (#319)
Browse files Browse the repository at this point in the history
- fix #318
- only plot final error estimate in  alchemlyb.visualisation.plot_convergence if it is not NaN

Co-authored-by: Zhiyi Wu <zwu@exscientia.co.uk>
  • Loading branch information
xiki-tempula and xiki-tempula authored Jun 1, 2023
1 parent 4e590cc commit 064c4fe
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 7 deletions.
9 changes: 9 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ The rules for this file:
* release numbers follow "Semantic Versioning" https://semver.org

------------------------------------------------------------------------------
*/*/2023 xiki-tempula

* 2.1.0

Fixes
- Fix the case where visualisation.plot_convergence would fail when the final
error is NaN (issue #318, PR#317).


06/04/2023 xiki-tempula

* 2.0.1
Expand Down
16 changes: 16 additions & 0 deletions src/alchemlyb/tests/test_visualisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,22 @@ def test_plot_convergence(gmx_benzene_Coulomb_u_nk):
plt.close(ax.figure)


def test_plot_convergence_final_nan():
"""Test the case where the Error of the final estimate is NaN."""
df = pd.DataFrame(
data={
"Forward": [1, 2],
"Forward_Error": [np.nan, np.nan],
"Backward": [1, 2],
"Backward_Error": [np.nan, np.nan],
}
)
df.attrs = {"temperature": 300, "energy_unit": "kT"}
ax = plot_convergence(df)
assert isinstance(ax, matplotlib.axes.Axes)
plt.close(ax.figure)


class Test_Units:
@staticmethod
@pytest.fixture()
Expand Down
15 changes: 8 additions & 7 deletions src/alchemlyb/visualisation/convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,14 @@ def plot_convergence(dataframe, units=None, final_error=None, ax=None):
if final_error is None:
final_error = backward_error[-1]

line0 = ax.fill_between(
[0, 1],
backward[-1] - final_error,
backward[-1] + final_error,
color="#D2B9D3",
zorder=1,
)
if np.isfinite(backward[-1]) and np.isfinite(final_error):
line0 = ax.fill_between(
[0, 1],
backward[-1] - final_error,
backward[-1] + final_error,
color="#D2B9D3",
zorder=1,
)
line1 = ax.errorbar(
f_ts,
forward,
Expand Down

0 comments on commit 064c4fe

Please sign in to comment.