Skip to content

Commit

Permalink
split model pred loading from CSV into new module matbench_discovery/…
Browse files Browse the repository at this point in the history
…preds.py
  • Loading branch information
janosh committed Jun 20, 2023
1 parent 24f6868 commit 0e9e5dc
Show file tree
Hide file tree
Showing 14 changed files with 78 additions and 74 deletions.
36 changes: 6 additions & 30 deletions matbench_discovery/metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
"""Centralize data-loading and computing metrics for plotting scripts"""

from __future__ import annotations

from collections.abc import Sequence
Expand All @@ -8,7 +6,12 @@
import pandas as pd
from sklearn.metrics import r2_score

from matbench_discovery.data import load_df_wbm_preds
"""Functions to classify energy above convex hull predictions as true/false
positive/negative and compute performance metrics.
"""

__author__ = "Janosh Riebesell"
__date__ = "2023-02-01"


def classify_stable(
Expand Down Expand Up @@ -98,30 +101,3 @@ def stable_metrics(
RMSE=((true - pred) ** 2).mean() ** 0.5,
R2=r2_score(true, pred),
)


models = sorted(
"Wrenformer, CGCNN, Voronoi Random Forest, MEGNet, M3GNet + MEGNet, "
"BOWSR + MEGNet".split(", ")
)
e_form_col = "e_form_per_atom_mp2020_corrected"
each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
each_pred_col = "e_above_hull_pred"

df_wbm = load_df_wbm_preds(models).round(3)

for col in [e_form_col, each_true_col]:
assert col in df_wbm, f"{col=} not in {list(df_wbm)=}"


df_metrics = pd.DataFrame()
for model in models:
df_metrics[model] = stable_metrics(
df_wbm[each_true_col],
df_wbm[each_true_col] + df_wbm[e_form_col] - df_wbm[model],
)

assert df_metrics.T.MAE.between(0, 0.2).all(), "MAE not in range"
assert df_metrics.T.R2.between(0.1, 1).all(), "R2 not in range"
assert df_metrics.T.RMSE.between(0, 0.25).all(), "RMSE not in range"
assert df_metrics.isna().sum().sum() == 0, "NaNs in metrics"
37 changes: 37 additions & 0 deletions matbench_discovery/preds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations

import pandas as pd

from matbench_discovery.data import load_df_wbm_preds
from matbench_discovery.metrics import stable_metrics

"""Centralize data-loading and computing metrics for plotting scripts"""

__author__ = "Janosh Riebesell"
__date__ = "2023-02-04"

models = sorted(
"Wrenformer, CGCNN, Voronoi Random Forest, MEGNet, M3GNet + MEGNet, "
"BOWSR + MEGNet".split(", ")
)
e_form_col = "e_form_per_atom_mp2020_corrected"
each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
each_pred_col = "e_above_hull_pred"

df_wbm = load_df_wbm_preds(models).round(3)

for col in [e_form_col, each_true_col]:
assert col in df_wbm, f"{col=} not in {list(df_wbm)=}"


df_metrics = pd.DataFrame()
for model in models:
df_metrics[model] = stable_metrics(
df_wbm[each_true_col],
df_wbm[each_true_col] + df_wbm[e_form_col] - df_wbm[model],
)

assert df_metrics.T.MAE.between(0, 0.2).all(), "MAE not in range"
assert df_metrics.T.R2.between(0.1, 1).all(), "R2 not in range"
assert df_metrics.T.RMSE.between(0, 0.25).all(), "RMSE not in range"
assert df_metrics.isna().sum().sum() == 0, "NaNs in metrics"
2 changes: 1 addition & 1 deletion scripts/cumulative_clf_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from pymatviz.utils import save_fig

from matbench_discovery import FIGS, STATIC, today
from matbench_discovery.metrics import df_wbm, e_form_col, each_true_col, models
from matbench_discovery.plots import cumulative_precision_recall
from matbench_discovery.preds import df_wbm, e_form_col, each_true_col, models

__author__ = "Janosh Riebesell, Rhys Goodall"
__date__ = "2022-12-04"
Expand Down
9 changes: 2 additions & 7 deletions scripts/hist_classified_stable_vs_hull_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,9 @@
from pymatviz.utils import save_fig

from matbench_discovery import FIGS, today
from matbench_discovery.metrics import (
df_wbm,
e_form_col,
each_pred_col,
each_true_col,
stable_metrics,
)
from matbench_discovery.metrics import stable_metrics
from matbench_discovery.plots import WhichEnergy, hist_classified_stable_vs_hull_dist
from matbench_discovery.preds import df_wbm, e_form_col, each_pred_col, each_true_col

__author__ = "Rhys Goodall, Janosh Riebesell"
__date__ = "2022-06-18"
Expand Down
9 changes: 2 additions & 7 deletions scripts/hist_classified_stable_vs_hull_dist_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,14 @@
from pymatviz.utils import save_fig

from matbench_discovery import FIGS, today
from matbench_discovery.metrics import (
df_wbm,
e_form_col,
each_pred_col,
each_true_col,
stable_metrics,
)
from matbench_discovery.metrics import stable_metrics
from matbench_discovery.plots import (
Backend,
WhichEnergy,
hist_classified_stable_vs_hull_dist,
plt,
)
from matbench_discovery.preds import df_wbm, e_form_col, each_pred_col, each_true_col

__author__ = "Rhys Goodall, Janosh Riebesell"
__date__ = "2022-08-25"
Expand Down
4 changes: 2 additions & 2 deletions scripts/hist_classified_stable_vs_hull_dist_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
from pymatviz.utils import save_fig

from matbench_discovery import STATIC, today
from matbench_discovery.metrics import (
from matbench_discovery.plots import Backend, hist_classified_stable_vs_hull_dist, plt
from matbench_discovery.preds import (
df_metrics,
df_wbm,
e_form_col,
each_true_col,
models,
)
from matbench_discovery.plots import Backend, hist_classified_stable_vs_hull_dist, plt

__author__ = "Janosh Riebesell"
__date__ = "2022-12-01"
Expand Down
6 changes: 3 additions & 3 deletions scripts/prc_roc_curves_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from tqdm import tqdm

from matbench_discovery import FIGS, today
from matbench_discovery.metrics import (
from matbench_discovery.metrics import stable_metrics
from matbench_discovery.plots import pio
from matbench_discovery.preds import (
df_wbm,
e_form_col,
each_pred_col,
each_true_col,
models,
stable_metrics,
)
from matbench_discovery.plots import pio

__author__ = "Janosh Riebesell"
__date__ = "2023-01-30"
Expand Down
2 changes: 1 addition & 1 deletion scripts/rolling_mae_vs_hull_dist.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# %%
from matbench_discovery import FIGS, today
from matbench_discovery.metrics import df_metrics, df_wbm, e_form_col, each_true_col
from matbench_discovery.plots import rolling_mae_vs_hull_dist
from matbench_discovery.preds import df_metrics, df_wbm, e_form_col, each_true_col

__author__ = "Rhys Goodall, Janosh Riebesell"
__date__ = "2022-06-18"
Expand Down
2 changes: 1 addition & 1 deletion scripts/rolling_mae_vs_hull_dist_all_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from pymatviz.utils import save_fig

from matbench_discovery import FIGS, STATIC, today
from matbench_discovery.metrics import df_metrics, df_wbm, e_form_col, each_true_col
from matbench_discovery.plots import Backend, rolling_mae_vs_hull_dist
from matbench_discovery.preds import df_metrics, df_wbm, e_form_col, each_true_col

__author__ = "Rhys Goodall, Janosh Riebesell"
__date__ = "2022-06-18"
Expand Down
2 changes: 1 addition & 1 deletion scripts/rolling_mae_vs_hull_dist_wbm_batches.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# %%
from matbench_discovery import FIGS, today
from matbench_discovery.metrics import df_wbm, e_form_col, each_true_col
from matbench_discovery.plots import plt, rolling_mae_vs_hull_dist
from matbench_discovery.preds import df_wbm, e_form_col, each_true_col

__author__ = "Rhys Goodall, Janosh Riebesell"
__date__ = "2022-06-18"
Expand Down
7 changes: 3 additions & 4 deletions scripts/scatter_e_above_hull_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
from pymatviz.utils import add_identity_line, save_fig

from matbench_discovery import FIGS, STATIC, today
from matbench_discovery.metrics import (
classify_stable,
from matbench_discovery.metrics import classify_stable, stable_metrics
from matbench_discovery.plots import clf_color_map, clf_labels, px
from matbench_discovery.preds import (
df_wbm,
e_form_col,
each_pred_col,
each_true_col,
models,
stable_metrics,
)
from matbench_discovery.plots import clf_color_map, clf_labels, px

__author__ = "Janosh Riebesell"
__date__ = "2022-11-28"
Expand Down
8 changes: 4 additions & 4 deletions site/src/routes/how-to-contribute/+page.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,17 @@ and place the above-listed files there. The file structure should look like this
```txt
matbench-discovery-root
└── models
└── <model name>
└── <model_name>
├── metadata.yml
├── <yyyy-mm-dd>-<model_name>-preds.(json|csv).gz
├── test_<model_name>.py
├── readme.md # optional
└── train_<model_name>.py # optional
├── readme.md # optional
└── train_<model_name>.py # optional
```

You can include arbitrary other supporting files like metadata and model features (below 10MB to keep `git clone` time low) if they are needed to run the model or help others reproduce your results. For larger files, please upload to [Figshare](https://figshare.com) or similar and link them somewhere in your files.

### Step 3: Create a PR to the [Matbench Discovery repo](https://github.com/janosh/matbench-discovery)
### Step 3: Open a PR to the [Matbench Discovery repo](https://github.com/janosh/matbench-discovery)

Commit your files to the repo on a branch called `<model_name>` and create a pull request (PR) to the Matbench repository.

Expand Down
16 changes: 15 additions & 1 deletion tests/test_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
from pymatgen.analysis.phase_diagram import PDEntry
from pymatgen.core import Lattice, Structure
from pymatgen.entries.computed_entries import ComputedEntry, Entry
from pytest import approx

from matbench_discovery.energy import get_e_form_per_atom, get_elemental_ref_entries
from matbench_discovery.energy import (
get_e_form_per_atom,
get_elemental_ref_entries,
mp_elem_reference_entries,
mp_elemental_ref_energies,
)

dummy_struct = Structure(
lattice=Lattice.cubic(5),
Expand Down Expand Up @@ -49,3 +55,11 @@ def test_get_elemental_ref_entries(
expected = {"Fe": constructor(*entries[2]), "O": constructor(*entries[3])}

assert elemental_ref_entries == expected


def test_mp_ref_energies() -> None:
"""Test MP elemental reference energies are in sync with PDEntries saved to disk."""
for key, val in mp_elemental_ref_energies.items():
actual = mp_elem_reference_entries[key].energy_per_atom
assert actual == approx(val, abs=1e-3), f"{key=}"
assert actual == approx(val, abs=1e-3), f"{key=}"
12 changes: 0 additions & 12 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
import pytest
from pytest import approx

from matbench_discovery.energy import (
mp_elem_reference_entries,
mp_elemental_ref_energies,
)
from matbench_discovery.metrics import classify_stable, stable_metrics


Expand Down Expand Up @@ -83,11 +79,3 @@ def test_stable_metrics() -> None:
# test stable_metrics docstring is up to date, all returned metrics should be listed
assert stable_metrics.__doc__ # for mypy
assert all(key in stable_metrics.__doc__ for key in metrics)


def test_mp_ref_energies() -> None:
"""Test MP elemental reference energies are in sync with PDEntries saved to disk."""
for key, val in mp_elemental_ref_energies.items():
actual = mp_elem_reference_entries[key].energy_per_atom
assert actual == approx(val, abs=1e-3), f"{key=}"
assert actual == approx(val, abs=1e-3), f"{key=}"

0 comments on commit 0e9e5dc

Please sign in to comment.