Skip to content

Commit

Permalink
add plot_scripts/all_models_scatter.py
Browse files Browse the repository at this point in the history
centralize loading of model predictions into matbench_discovery/plot_scripts/__init__.py, which most plot scripts now use
add tests/test_plots_scripts.py
  • Loading branch information
janosh committed Jun 20, 2023
1 parent bac8551 commit 0236c36
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 143 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# %%
import pandas as pd
import pymatviz

from matbench_discovery import ROOT, today
from matbench_discovery.plot_scripts import df_wbm
from matbench_discovery.plot_scripts import load_df_wbm_with_preds
from matbench_discovery.plots import (
StabilityCriterion,
WhichEnergy,
Expand All @@ -27,59 +24,25 @@


# %%
dfs = {}
dfs["wren"] = pd.read_csv(
f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
).set_index("material_id")
dfs["m3gnet"] = pd.read_json(
f"{ROOT}/models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE.json.gz"
).set_index("material_id")
dfs["wrenformer"] = pd.read_csv(
f"{ROOT}/models/wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
).set_index("material_id")
dfs["bowsr_megnet"] = pd.read_json(
f"{ROOT}/models/bowsr/2022-11-22-bowsr-megnet-wbm-IS2RE.json.gz"
).set_index("material_id")


# %%
pred_col = "e_form_per_atom_pred"
target_col = "e_form_per_atom"
if "wren" in dfs:
df = dfs["wren"]
pred_cols = df.filter(regex=r"_pred_\d").columns
# make sure we average the expected number of ensemble member predictions
assert len(pred_cols) == 10
df[pred_col] = df[pred_cols].mean(axis=1)
if "m3gnet" in dfs:
df = dfs["m3gnet"]
df[pred_col] = df.e_form_per_atom_m3gnet
if "bowsr_megnet" in dfs:
df = dfs["bowsr_megnet"]
df[pred_col] = df.e_form_per_atom_bowsr_megnet
if "wrenformer" in dfs:
pred_col = "e_form_per_atom_mp2020_corrected_pred_ens"
df_wbm = load_df_wbm_with_preds(models="Wren Wrenformer".split()).round(3)
target_col = "e_form_per_atom_mp2020_corrected"
e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"


# %%
which_energy: WhichEnergy = "true"
stability_crit: StabilityCriterion = "energy"
fig, axs = plt.subplots(2, 3, figsize=(18, 9))

model_name = "wrenformer"
df = dfs[model_name]

df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp
df[target_col] = df_wbm.e_form_per_atom_mp2020_corrected # e_form targets

model_name = "Wrenformer"

for batch_idx, ax in zip(range(1, 6), axs.flat):
batch_df = df[df.index.str.startswith(f"wbm-step-{batch_idx}-")]
batch_df = df_wbm[df_wbm.index.str.startswith(f"wbm-step-{batch_idx}-")]
# Sanity check that the batch slice holds between 10k and 100k rows.
# The message is given as a string: the previous `print(...)` call returned None,
# so the raised AssertionError itself carried no message.
assert 1e4 < len(batch_df) < 1e5, f"{len(batch_df) = :,}"

ax, metrics = hist_classified_stable_vs_hull_dist(
e_above_hull_pred=batch_df[pred_col] - batch_df.e_form_per_atom,
e_above_hull_true=batch_df.e_above_hull_mp,
e_above_hull_pred=batch_df[model_name] - batch_df[target_col],
e_above_hull_true=batch_df[e_above_hull_col],
which_energy=which_energy,
stability_crit=stability_crit,
ax=ax,
Expand All @@ -93,8 +56,8 @@


ax, metrics = hist_classified_stable_vs_hull_dist(
e_above_hull_pred=df[pred_col] - df.e_form_per_atom,
e_above_hull_true=df.e_above_hull_mp,
e_above_hull_pred=df_wbm[model_name] - df_wbm[target_col],
e_above_hull_true=df_wbm[e_above_hull_col],
which_energy=which_energy,
stability_crit=stability_crit,
ax=axs.flat[-1],
Expand All @@ -103,7 +66,7 @@
text = f"Enrichment\nFactor = {metrics['enrichment']:.3}"
ax.text(0.02, 0.3, text, fontsize=16, transform=ax.transAxes)

axs.flat[-1].set(title=f"All batches ({len(df.filter(like='e_').dropna()):,})")
axs.flat[-1].set(title=f"All batches ({len(df_wbm[model_name].dropna()):,})")
axs.flat[0].legend(frameon=False, loc="upper left")

fig.suptitle(f"{today} {model_name}", y=1.07, fontsize=16)
Expand All @@ -112,9 +75,3 @@
# %%
img_name = f"{today}-{model_name}-wbm-hull-dist-hist-batches"
ax.figure.savefig(f"{ROOT}/figures/{img_name}.pdf")


# %%
pymatviz.density_scatter(
df=dfs[model_name].query(f"{target_col} < 5"), x=target_col, y=pred_col
)
86 changes: 19 additions & 67 deletions matbench_discovery/plot_scripts/precision_recall.py
Original file line number Diff line number Diff line change
@@ -1,93 +1,45 @@
# %%
import pandas as pd
from sklearn.metrics import f1_score

from matbench_discovery import ROOT, today
from matbench_discovery.plot_scripts import df_wbm
from matbench_discovery.plot_scripts import load_df_wbm_with_preds
from matbench_discovery.plots import StabilityCriterion, cumulative_clf_metric, plt

__author__ = "Rhys Goodall, Janosh Riebesell"


# %%
dfs: dict[str, pd.DataFrame] = {}
for model_name in ("wren", "cgcnn", "voronoi"):
csv_path = (
f"{ROOT}/data/2022-06-11-from-rhys/{model_name}-mp-initial-structures.csv"
)
df = pd.read_csv(csv_path).set_index("material_id")
dfs[model_name] = df

dfs["m3gnet"] = pd.read_json(
f"{ROOT}/models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE.json.gz"
).set_index("material_id")

dfs["wrenformer"] = pd.read_csv(
f"{ROOT}/models/wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
).set_index("material_id")
models = (
"Wren, CGCNN IS2RE, CGCNN RS2RE, Voronoi IS2RE, Voronoi RS2RE, "
"Wrenformer, MEGNet"
).split(", ")

dfs["bowsr_megnet"] = pd.read_json(
f"{ROOT}/models/bowsr/2022-11-22-bowsr-megnet-wbm-IS2RE.json.gz"
).set_index("material_id")
df_wbm = load_df_wbm_with_preds(models=models).round(3)

print(f"loaded models: {list(dfs)}")
target_col = "e_form_per_atom_mp2020_corrected"
e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"


# %%
stability_crit: StabilityCriterion = "energy"
colors = "tab:blue tab:orange teal tab:pink black red turquoise tab:purple".split()
F1s: dict[str, float] = {}

for model_name, df in sorted(dfs.items()):
if "std" in stability_crit:
# TODO column names to compute standard deviation from are currently hardcoded
# needs to be updated when adding non-aviary models with uncertainty estimation
var_aleatoric = (df.filter(regex=r"_ale_\d") ** 2).mean(axis=1)
var_epistemic = df.filter(regex=r"_pred_\d").var(axis=1, ddof=0)
std_total = (var_epistemic + var_aleatoric) ** 0.5
else:
std_total = None

try:
if model_name == "m3gnet":
model_preds = df.e_form_m3gnet
elif "wrenformer" in model_name:
model_preds = df.e_form_per_atom_pred_ens
elif len(pred_cols := df.filter(like="e_form_pred").columns) >= 1:
# Voronoi+RF has single prediction column, Wren and CGCNN each have 10
# other cases are unexpected
assert len(pred_cols) in (1, 10), f"{model_name=} has {len(pred_cols)=}"
model_preds = df[pred_cols].mean(axis=1)
elif model_name == "bowsr_megnet":
model_preds = df.e_form_per_atom_bowsr_megnet
else:
raise ValueError(f"Unhandled {model_name = }")
except AttributeError as exc:
raise KeyError(f"{model_name = }") from exc

df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp
df["e_form_per_atom"] = df_wbm.e_form_per_atom_mp2020_corrected
df["e_above_hull_pred"] = model_preds - df.e_form_per_atom
if n_nans := df.isna().values.sum() > 0:
assert n_nans < 10, f"{model_name=} has {n_nans=}"
df = df.dropna()

F1 = f1_score(df.e_above_hull_mp < 0, df.e_above_hull_pred < 0)
F1s[model_name] = F1


# %%
fig, (ax_prec, ax_recall) = plt.subplots(1, 2, figsize=(15, 7), sharey=True)

for (model_name, F1), color in zip(sorted(F1s.items(), key=lambda x: x[1]), colors):
df = dfs[model_name]
e_above_hull_error = df.e_above_hull_pred + df.e_above_hull_mp
e_above_hull_true = df.e_above_hull_mp
for model_name, color in zip(models, colors):

e_above_hull_pred = df_wbm[model_name] - df_wbm[target_col]

F1 = f1_score(df_wbm[e_above_hull_col] < 0, e_above_hull_pred < 0)

e_above_hull_error = e_above_hull_pred + df_wbm[e_above_hull_col]
cumulative_clf_metric(
e_above_hull_error,
e_above_hull_true,
df_wbm[e_above_hull_col],
color=color,
label=f"{model_name}\n{F1=:.2}",
label=f"{model_name}\n{F1=:.3}",
project_end_point="xy",
stability_crit=stability_crit,
ax=ax_prec,
Expand All @@ -96,9 +48,9 @@

cumulative_clf_metric(
e_above_hull_error,
e_above_hull_true,
df_wbm[e_above_hull_col],
color=color,
label=f"{model_name}\n{F1=:.2}",
label=f"{model_name}\n{F1=:.3}",
project_end_point="xy",
stability_crit=stability_crit,
ax=ax_recall,
Expand Down
11 changes: 0 additions & 11 deletions matbench_discovery/plot_scripts/rolling_mae_vs_hull_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@


# %%
markers = ["o", "v", "^", "H", "D", ""]

data_path = (
f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
# f"{ROOT}/models/wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
Expand All @@ -21,15 +19,6 @@


# %%
# rare = "all"
# from pymatgen.core import Composition
# rare = "no-lanthanides"
# df["contains_rare_earths"] = df.composition.map(
# lambda x: any(el.is_rare_earth_metal for el in Composition(x))
# )
# df = df.query("~contains_rare_earths")


df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp

# Parenthesize the walrus so n_nans holds the NaN counts returned by
# df.isna().sum() (presumably one count per column -- df is built outside this hunk).
# Unparenthesized, `:=` binds looser than `==`, so n_nans captured the boolean
# comparison result and the failure message reported True/False instead of counts.
assert all((n_nans := df.isna().sum()) == 0), f"Found {n_nans} NaNs"
Expand Down
13 changes: 5 additions & 8 deletions matbench_discovery/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@
e_above_hull_mp="Energy above MP convex hull (eV/atom)",
e_above_hull_error="Error in energy above convex hull (eV/atom)",
vol_diff="Volume difference (A^3)",
e_form_per_atom_mp2020_corrected="Formation energy (eV/atom)",
e_form_per_atom_pred="Predicted formation energy (eV/atom)",
material_id="Material ID",
band_gap="Band gap (eV)",
formula="Formula",
)
model_labels = dict(
wren="Wren",
Expand Down Expand Up @@ -254,10 +259,6 @@ def rolling_mae_vs_hull_dist(
"""
ax = ax or plt.gca()

for series in (e_above_hull_pred, e_above_hull_true):
n_nans = series.isna().sum()
assert n_nans == 0, f"{n_nans:,} NaNs in {series.name}"

is_fresh_ax = len(ax.lines) == 0

bins = np.arange(*x_lim, bin_width)
Expand Down Expand Up @@ -387,10 +388,6 @@ def cumulative_clf_metric(
"""
ax = ax or plt.gca()

for series in (e_above_hull_error, e_above_hull_true):
n_nans = series.isna().sum()
assert n_nans == 0, f"{n_nans:,} NaNs in {series.name}"

e_above_hull_error = e_above_hull_error.sort_values()
e_above_hull_true = e_above_hull_true.loc[e_above_hull_error.index]

Expand Down
3 changes: 1 addition & 2 deletions models/m3gnet/eda_wbm_pre_vs_post_m3gnet_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,7 @@
df["m3gnet_vol_diff"] = df.m3gnet_volume - df.final_wbm_volume
df["dft_vol_diff"] = df.initial_wbm_volume - df.final_wbm_volume
fig = px.histogram(
pd.melt(
df,
df.melt(
value_vars=["m3gnet", "dft"],
value_name="vol_diff",
var_name="method",
Expand Down
2 changes: 1 addition & 1 deletion models/megnet/test_megnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

"""
To slurm submit this file: python path/to/file.py slurm-submit
Requires Megnet installation: pip install megnet
Requires MEGNet installation: pip install megnet
https://github.com/materialsvirtuallab/megnet
"""

Expand Down

0 comments on commit 0236c36

Please sign in to comment.