add matbench_discovery/metrics.py to centralize computing metrics for plotting scripts

remove np.random.seed(0) from test_stable_metrics for increased randomness
test regression metrics against sklearn in test_stable_metrics
janosh committed Jun 20, 2023
1 parent 248a79b commit 500e670
Showing 17 changed files with 309 additions and 318 deletions.
2 changes: 0 additions & 2 deletions matbench_discovery/__init__.py
@@ -23,8 +23,6 @@
 today = timestamp.split("@")[0]

 # load docs, repo, package URLs from package.json
-print(f"{ROOT=}")
-
 with open(f"{ROOT}/site/package.json") as file:
     pkg = json.load(file)
 pypi_keys_to_npm = dict(Docs="homepage", Repo="repository", Package="package")
5 changes: 3 additions & 2 deletions matbench_discovery/data.py
@@ -35,8 +35,8 @@

 def as_dict_handler(obj: Any) -> dict[str, Any] | None:
     """Pass this to json.dump(default=) or as pandas.to_json(default_handler=) to
-    convert Python classes with a as_dict() method to dictionaries on serialization.
-    Objects without a as_dict() method are replaced with None in the serialized data.
+    serialize Python classes with as_dict(). Warning: objects without an as_dict()
+    method are replaced with None in the serialized data.
     """
     try:
         return obj.as_dict()  # all MSONable objects implement as_dict()
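
For context, a minimal sketch of how this handler is meant to be used (the Structure and the bare object() are illustrative, not from the repo):

import json

from pymatgen.core import Lattice, Structure

from matbench_discovery.data import as_dict_handler

struct = Structure(Lattice.cubic(4.2), ["Na", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]])
# Structure is MSONable, so it's serialized via as_dict(); object() has no
# as_dict() method and comes out as null in the JSON string
json_str = json.dumps({"struct": struct, "opaque": object()}, default=as_dict_handler)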
@@ -144,6 +144,7 @@ def load_train_test(
     "Wrenformer": "wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv",
     "MEGNet": "megnet/2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv",
     "M3GNet": "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv",
+    "M3GNet MEGNet": "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv",
     "BOWSR MEGNet": "bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.csv",
 }

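The new "M3GNet MEGNet" key (presumably M3GNet-relaxed structures re-scored with MEGNet) points at the same CSV as plain "M3GNet". A hedged sketch of loading either prediction set by name (the model list and .filter() call are illustrative):

from matbench_discovery.data import load_df_wbm_preds

# returns the WBM summary DataFrame plus one prediction column per requested model
df_wbm = load_df_wbm_preds(models=["M3GNet", "M3GNet MEGNet"]).round(3)
print(df_wbm.filter(like="M3GNet").describe())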
91 changes: 0 additions & 91 deletions matbench_discovery/energy.py
@@ -1,12 +1,10 @@
 import itertools
-from collections.abc import Sequence

 import numpy as np
 import pandas as pd
 from pymatgen.analysis.phase_diagram import Entry, PDEntry
 from pymatgen.core import Composition
 from pymatgen.util.typing import EntryLike
-from sklearn.metrics import r2_score
 from tqdm import tqdm

 from matbench_discovery import ROOT
@@ -120,92 +118,3 @@ def get_e_form_per_atom(
     form_energy = energy - sum(comp[el] * e_refs[str(el)] for el in comp)

     return form_energy / comp.num_atoms
-
-
-def classify_stable(
-    e_above_hull_true: pd.Series,
-    e_above_hull_pred: pd.Series,
-    stability_threshold: float = 0,
-) -> tuple[pd.Series, pd.Series, pd.Series, pd.Series]:
-    """Classify model stability predictions as true/false positive/negatives (usually
-    w.r.t. DFT ground-truth labels). All energies are assumed to be in eV/atom
-    (but shouldn't really matter as long as they're consistent).
-    Args:
-        e_above_hull_true (pd.Series): Ground truth energy above convex hull values.
-        e_above_hull_pred (pd.Series): Model predicted energy above convex hull values.
-        stability_threshold (float, optional): Maximum energy above convex hull for a
-            material to still be considered stable. Usually 0, 0.05 or 0.1. Defaults to
-            0, meaning a material has to be directly on the hull to be called stable.
-            Negative values mean a material has to pull the known hull down by that
-            amount to count as stable. Few materials lie below the known hull, so only
-            negative values very close to 0 make sense.
-    Returns:
-        tuple[TP, FN, FP, TN]: Indices as pd.Series for true positives,
-            false negatives, false positives and true negatives (in this order).
-    """
-    actual_pos = e_above_hull_true <= stability_threshold
-    actual_neg = e_above_hull_true > stability_threshold
-    model_pos = e_above_hull_pred <= stability_threshold
-    model_neg = e_above_hull_pred > stability_threshold
-
-    true_pos = actual_pos & model_pos
-    false_neg = actual_pos & model_neg
-    false_pos = actual_neg & model_pos
-    true_neg = actual_neg & model_neg
-
-    return true_pos, false_neg, false_pos, true_neg
-
-
-def stable_metrics(
-    true: Sequence[float], pred: Sequence[float], stability_threshold: float = 0
-) -> dict[str, float]:
-    """
-    Get a dictionary of stability prediction metrics. Mostly binary classification
-    metrics, but also MAE, RMSE and R2.
-    Args:
-        true (list[float]): true energy values
-        pred (list[float]): predicted energy values
-        stability_threshold (float): Where to place stability threshold relative to
-            convex hull in eV/atom, usually 0 or 0.1 eV. Defaults to 0.
-    Note: Could be replaced by sklearn.metrics.classification_report() which takes
-        binary labels. I.e. classification_report(true > 0, pred > 0, output_dict=True)
-        should give equivalent results.
-    Returns:
-        dict[str, float]: dictionary of classification metrics with keys DAF, Precision,
-            Recall, Accuracy, F1, TPR, FPR, TNR, FNR, MAE, RMSE, R2.
-    """
-    true_pos, false_neg, false_pos, true_neg = classify_stable(
-        true, pred, stability_threshold
-    )
-
-    n_true_pos, n_false_pos, n_true_neg, n_false_neg = map(
-        sum, (true_pos, false_pos, true_neg, false_neg)
-    )
-
-    n_total_pos = n_true_pos + n_false_neg
-    prevalence = n_total_pos / len(true)  # null rate
-    precision = n_true_pos / (n_true_pos + n_false_pos)
-    recall = n_true_pos / n_total_pos
-
-    is_nan = np.isnan(true) | np.isnan(pred)
-    true, pred = np.array(true)[~is_nan], np.array(pred)[~is_nan]
-
-    return dict(
-        DAF=precision / prevalence,
-        Precision=precision,
-        Recall=recall,
-        Accuracy=(n_true_pos + n_true_neg) / len(true),
-        F1=2 * (precision * recall) / (precision + recall),
-        TPR=n_true_pos / (n_true_pos + n_false_neg),
-        FPR=n_false_pos / (n_true_neg + n_false_pos),
-        TNR=n_true_neg / (n_true_neg + n_false_pos),
-        FNR=n_false_neg / (n_true_pos + n_false_neg),
-        MAE=np.abs(true - pred).mean(),
-        RMSE=((true - pred) ** 2).mean() ** 0.5,
-        R2=r2_score(true, pred),
-    )
125 changes: 125 additions & 0 deletions matbench_discovery/metrics.py
@@ -0,0 +1,125 @@
+"""Centralize data-loading and computing metrics for plotting scripts"""
+
+from collections.abc import Sequence
+
+import numpy as np
+import pandas as pd
+from sklearn.metrics import r2_score
+
+from matbench_discovery.data import load_df_wbm_preds
+
+
+def classify_stable(
+    e_above_hull_true: pd.Series,
+    e_above_hull_pred: pd.Series,
+    stability_threshold: float | None = 0,
+) -> tuple[pd.Series, pd.Series, pd.Series, pd.Series]:
+    """Classify model stability predictions as true/false positive/negatives (usually
+    w.r.t. DFT ground-truth labels). All energies are assumed to be in eV/atom
+    (but shouldn't really matter as long as they're consistent).
+    Args:
+        e_above_hull_true (pd.Series): Ground truth energy above convex hull values.
+        e_above_hull_pred (pd.Series): Model predicted energy above convex hull values.
+        stability_threshold (float | None, optional): Maximum energy above convex hull
+            for a material to still be considered stable. Usually 0, 0.05 or 0.1.
+            Defaults to 0, meaning a material has to be directly on the hull to be
+            called stable. Negative values mean a material has to pull the known hull
+            down by that amount to count as stable. Few materials lie below the known
+            hull, so only negative values very close to 0 make sense.
+    Returns:
+        tuple[TP, FN, FP, TN]: Indices as pd.Series for true positives,
+            false negatives, false positives and true negatives (in this order).
+    """
+    actual_pos = e_above_hull_true <= (stability_threshold or 0)  # guard against None
+    actual_neg = e_above_hull_true > (stability_threshold or 0)
+    model_pos = e_above_hull_pred <= (stability_threshold or 0)
+    model_neg = e_above_hull_pred > (stability_threshold or 0)
+
+    true_pos = actual_pos & model_pos
+    false_neg = actual_pos & model_neg
+    false_pos = actual_neg & model_pos
+    true_neg = actual_neg & model_neg
+
+    return true_pos, false_neg, false_pos, true_neg
+
+
+def stable_metrics(
+    true: Sequence[float], pred: Sequence[float], stability_threshold: float = 0
+) -> dict[str, float]:
+    """
+    Get a dictionary of stability prediction metrics. Mostly binary classification
+    metrics, but also MAE, RMSE and R2.
+    Args:
+        true (list[float]): true energy values
+        pred (list[float]): predicted energy values
+        stability_threshold (float): Where to place stability threshold relative to
+            convex hull in eV/atom, usually 0 or 0.1 eV. Defaults to 0.
+    Note: Could be replaced by sklearn.metrics.classification_report() which takes
+        binary labels. I.e. classification_report(true > 0, pred > 0, output_dict=True)
+        should give equivalent results.
+    Returns:
+        dict[str, float]: dictionary of classification metrics with keys DAF, Precision,
+            Recall, Accuracy, F1, TPR, FPR, TNR, FNR, MAE, RMSE, R2.
+    """
+    true_pos, false_neg, false_pos, true_neg = classify_stable(
+        true, pred, stability_threshold
+    )
+
+    n_true_pos, n_false_pos, n_true_neg, n_false_neg = map(
+        sum, (true_pos, false_pos, true_neg, false_neg)
+    )
+
+    n_total_pos = n_true_pos + n_false_neg
+    prevalence = n_total_pos / len(true)  # null rate
+    precision = n_true_pos / (n_true_pos + n_false_pos)
+    recall = n_true_pos / n_total_pos
+
+    is_nan = np.isnan(true) | np.isnan(pred)
+    true, pred = np.array(true)[~is_nan], np.array(pred)[~is_nan]
+
+    return dict(
+        DAF=precision / prevalence,
+        Precision=precision,
+        Recall=recall,
+        Accuracy=(n_true_pos + n_true_neg) / len(true),
+        F1=2 * (precision * recall) / (precision + recall),
+        TPR=n_true_pos / n_total_pos,
+        FPR=n_false_pos / (n_true_neg + n_false_pos),
+        TNR=n_true_neg / (n_true_neg + n_false_pos),
+        FNR=n_false_neg / n_total_pos,
+        MAE=np.abs(true - pred).mean(),
+        RMSE=((true - pred) ** 2).mean() ** 0.5,
+        R2=r2_score(true, pred),
+    )
+
+
+models = sorted(
+    "Wrenformer, CGCNN, Voronoi Random Forest, MEGNet, M3GNet "
+    "MEGNet, BOWSR MEGNet".split(", ")
+)
+e_form_col = "e_form_per_atom_mp2020_corrected"
+each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
+each_pred_col = "e_above_hull_pred"
+
+df_wbm = load_df_wbm_preds(models).round(3)
+
+for col in [e_form_col, each_true_col]:
+    assert col in df_wbm, f"{col=} not in {list(df_wbm)=}"
+
+
+df_metrics = pd.DataFrame()
+for model in models:
+    df_metrics[model] = stable_metrics(
+        df_wbm[each_true_col],
+        df_wbm[each_true_col] + df_wbm[model] - df_wbm[e_form_col]
+    )
+
+assert df_metrics.T.MAE.between(0, 0.2).all(), "MAE not in range"
+assert df_metrics.T.R2.between(0.1, 1).all(), "R2 not in range"
+assert df_metrics.T.RMSE.between(0, 0.25).all(), "RMSE not in range"
+assert df_metrics.isna().sum().sum() == 0, "NaNs in metrics"
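
A quick sanity check of the relocated API on toy data (all numbers made up, not from WBM):

import pandas as pd

from matbench_discovery.metrics import classify_stable, stable_metrics

e_above_hull_true = pd.Series([-0.1, 0.0, 0.2, 0.3])  # eV/atom
e_above_hull_pred = pd.Series([-0.05, 0.1, 0.15, 0.4])

true_pos, false_neg, false_pos, true_neg = classify_stable(
    e_above_hull_true, e_above_hull_pred, stability_threshold=0
)
assert (true_pos.sum(), false_neg.sum(), false_pos.sum(), true_neg.sum()) == (1, 1, 0, 2)

metrics = stable_metrics(e_above_hull_true, e_above_hull_pred)
# precision = 1/1, recall = 1/2, prevalence = 2/4, so DAF (discovery acceleration
# factor) = precision / prevalence = 2: screening ranked by this model surfaces
# stable materials twice as fast as random screening
assert metrics["Precision"] == 1.0 and metrics["Recall"] == 0.5
assert metrics["DAF"] == 2.0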
4 changes: 2 additions & 2 deletions matbench_discovery/plots.py
@@ -15,7 +15,7 @@
 import wandb
 from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar

-from matbench_discovery.energy import classify_stable
+from matbench_discovery.metrics import classify_stable

 __author__ = "Janosh Riebesell"
 __date__ = "2022-08-05"
@@ -102,7 +102,7 @@ def hist_classified_stable_vs_hull_dist(
     each_pred_col: str,
     ax: plt.Axes = None,
     which_energy: WhichEnergy = "true",
-    stability_threshold: float = 0,
+    stability_threshold: float | None = 0,
     x_lim: tuple[float | None, float | None] = (-0.7, 0.7),
     rolling_acc: float | None = 0.02,
     backend: Backend = "plotly",
2 changes: 1 addition & 1 deletion scripts/compile_metrics.py
@@ -13,7 +13,7 @@

 from matbench_discovery import FIGS, MODELS, WANDB_PATH, today
 from matbench_discovery.data import PRED_FILENAMES, load_df_wbm_preds
-from matbench_discovery.energy import stable_metrics
+from matbench_discovery.metrics import stable_metrics
 from matbench_discovery.plots import px

 __author__ = "Janosh Riebesell"
34 changes: 9 additions & 25 deletions scripts/cumulative_clf_metrics.py
@@ -2,34 +2,22 @@
 import pandas as pd
 from pymatviz.utils import save_fig

-from matbench_discovery import FIGS, today
-from matbench_discovery.data import load_df_wbm_preds
+from matbench_discovery import FIGS, STATIC, today
+from matbench_discovery.metrics import df_wbm, e_form_col, each_true_col, models
 from matbench_discovery.plots import cumulative_precision_recall

 __author__ = "Janosh Riebesell, Rhys Goodall"
 __date__ = "2022-12-04"


 # %%
-models = (
-    "CGCNN, Voronoi Random Forest, Wrenformer, MEGNet, M3GNet, BOWSR MEGNet"
-).split(", ")
-
-df_wbm = load_df_wbm_preds(models).round(3)
-
-# df_wbm.columns = [f"{col}_e_form" if col in models else col for col in df_wbm]
-e_form_col = "e_form_per_atom_mp2020_corrected"
-e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"
-
-
-# %%
 df_e_above_hull_pred = pd.DataFrame()
 for model in models:
-    e_above_hul_pred = df_wbm[e_above_hull_col] + df_wbm[model] - df_wbm[e_form_col]
+    e_above_hul_pred = df_wbm[each_true_col] + df_wbm[model] - df_wbm[e_form_col]
     df_e_above_hull_pred[model] = e_above_hul_pred

 fig, df_metric = cumulative_precision_recall(
-    e_above_hull_true=df_wbm[e_above_hull_col],
+    e_above_hull_true=df_wbm[each_true_col],
     df_preds=df_e_above_hull_pred,
     project_end_point="xy",
     backend=(backend := "plotly"),
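
The loop above converts each model's formation energies to predicted hull distances via the identity e_above_hull_pred = e_above_hull_true + (e_form_pred - e_form_true), i.e. the predicted hull distance is the DFT one shifted by the model's formation-energy error. A worked example with made-up numbers:

each_true, e_form_true, e_form_pred = 0.05, -1.20, -1.15  # eV/atom, illustrative
each_pred = each_true + (e_form_pred - e_form_true)
assert round(each_pred, 2) == 0.10  # 50 meV/atom error in e_form shifts the hull distance by the same amount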
@@ -42,11 +30,7 @@
 # fig.suptitle(title)
 fig.text(0.5, -0.08, xlabel, ha="center", fontdict={"size": 16})
 if backend == "plotly":
-    # place legend in lower right corner
-    fig.update_layout(
-        # title=title,
-        legend=dict(yanchor="bottom", y=0.02, xanchor="right", x=0.9),
-    )
+    fig.layout.legend.update(x=0.01, y=0)  # , title=title
     fig.layout.height = 500
     fig.add_annotation(
         x=0.5,
@@ -69,7 +53,7 @@
     assert isinstance(trace.y[0], float)
     trace.y = [round(y, 3) for y in trace.y]

-img_path = f"{FIGS}/{today}-cumulative-clf-metrics"
-# save_fig(fig, f"{img_path}.pdf")
-save_fig(fig, f"{img_path}.svelte")
-# save_fig(fig, f"{img_path}.webp", scale=3)
+img_path = f"{today}-cumulative-clf-metrics"
+# save_fig(fig, f"{STATIC}/{img_path}.pdf")
+save_fig(fig, f"{FIGS}/{img_path}.svelte")
+save_fig(fig, f"{STATIC}/{img_path}.webp", scale=3)
20 changes: 9 additions & 11 deletions scripts/hist_classified_stable_vs_hull_dist.py
@@ -2,8 +2,13 @@
 from pymatviz.utils import save_fig

 from matbench_discovery import FIGS, today
-from matbench_discovery.data import load_df_wbm_preds
-from matbench_discovery.energy import stable_metrics
+from matbench_discovery.metrics import (
+    df_wbm,
+    e_form_col,
+    each_pred_col,
+    each_true_col,
+    stable_metrics,
+)
 from matbench_discovery.plots import WhichEnergy, hist_classified_stable_vs_hull_dist

 __author__ = "Rhys Goodall, Janosh Riebesell"
@@ -20,13 +25,6 @@

 # %%
 model_name = "Wrenformer"
-df_wbm = load_df_wbm_preds(models=[model_name]).round(3)
-
-
-# %%
-e_form_col = "e_form_per_atom_mp2020_corrected"
-each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
-each_pred_col = "e_above_hull_pred"
 which_energy: WhichEnergy = "true"
 # std_factor=0,+/-1,+/-2,... changes the criterion for material stability to
 # energy+std_factor*std. energy+std means predicted energy plus the model's uncertainty
@@ -68,5 +66,5 @@

 # %%
 img_path = f"{FIGS}/{today}-wren-wbm-hull-dist-hist-{which_energy=}"
-# save_fig(ax, f"{img_path}.pdf")
-save_fig(fig, f"{img_path}.html")
+# save_fig(fig, f"{img_path}.svelte")
+save_fig(fig, f"{img_path}.webp")