split analyze_model_failure_cases.py into two scripts, new one is analyze_elements.py

update a bunch of figures
add bar-element-counts-mp+wbm with/without normalization to /about-the-data/tmi
janosh committed Jun 20, 2023
1 parent 0f2410d commit 0fad3bd
Showing 21 changed files with 439 additions and 302 deletions.
1 change: 1 addition & 0 deletions matbench_discovery/preds.py
@@ -19,6 +19,7 @@
e_form_col = "e_form_per_atom_mp2020_corrected"
each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
each_pred_col = "e_above_hull_pred"
model_mean_err_col = "Mean over models"


class PredFiles(Files):
4 changes: 3 additions & 1 deletion readme.md
@@ -23,4 +23,6 @@ In version 1 of this benchmark, we explore 8 models covering multiple methodolog

We welcome contributions that add new models to the leaderboard through [GitHub PRs](https://github.com/janosh/matbench-discovery/pulls). See the [usage and contributing guide](https://janosh.github.io/matbench-discovery/contribute) for details.

For a version 2 release of this benchmark, we plan to merge the current training and test sets into the new training set and acquire a much larger test set compared to the v1 test set of 257k structures.
For a version 2 release of this benchmark, we plan to merge the current training and test sets into the new training set and acquire a much larger test set (potentially at meta-GGA level of theory) compared to the v1 test set of 257k structures. Anyone interested in joining this effort, please [open a GitHub discussion](https://github.com/janosh/matbench-discovery/discussions) or [reach out privately](mailto:janosh@lbl.gov?subject=Matbench%20Discovery).

For detailed results and analysis, check out the [paper](https://matbench-discovery.janosh.dev/paper) and [supplementary material](https://matbench-discovery.janosh.dev/si).
246 changes: 246 additions & 0 deletions scripts/analyze_element_errors.py
@@ -0,0 +1,246 @@
"""Analyze structures and composition with largest mean error across all models.
Maybe there's some chemistry/region of materials space that all models struggle with?
Might point to deficiencies in the data or models architecture.
"""


# %%
import pandas as pd
import plotly.express as px
from pymatgen.core import Composition, Element
from pymatviz import count_elements, ptable_heatmap_plotly
from pymatviz.utils import bin_df_cols, save_fig
from sklearn.metrics import r2_score
from tqdm import tqdm

from matbench_discovery import FIGS, MODELS, ROOT
from matbench_discovery.data import DATA_FILES, df_wbm
from matbench_discovery.preds import (
    df_each_err,
    df_metrics,
    df_preds,
    each_pred_col,
    each_true_col,
    model_mean_err_col,
)

__author__ = "Janosh Riebesell"
__date__ = "2023-02-15"

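# mean absolute error over all models for each structure, stored in both dataframes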
df_each_err[model_mean_err_col] = df_preds[model_mean_err_col] = df_each_err.abs().mean(
    axis=1
)


# %%
df_mp = pd.read_csv(DATA_FILES.mp_energies, na_filter=False).set_index("material_id")
# compute number of samples per element in training set
# counting element occurrences not weighted by composition, assuming models don't learn
# much more about iron and oxygen from Fe2O3 than from FeO

train_count_col = "MP Occurrences"
df_elem_err = count_elements(df_mp.formula_pretty, count_mode="occurrence").to_frame(
    name=train_count_col
)


# %%
fig = ptable_heatmap_plotly(df_elem_err[train_count_col], font_size=10)
title = "Number of MP structures containing each element"
fig.layout.title.update(text=title, x=0.4, y=0.9)
fig.show()


# %% map average model error onto elements
frac_comp_col = "fractional composition"
df_wbm[frac_comp_col] = [
    Composition(comp).fractional_composition for comp in tqdm(df_wbm.formula)
]

df_frac_comp = pd.DataFrame(comp.as_dict() for comp in df_wbm[frac_comp_col]).set_index(
    df_wbm.index
)
assert all(
    df_frac_comp.sum(axis=1).round(6) == 1
), "composition fractions don't sum to 1"

# df_frac_comp = df_frac_comp.dropna(axis=1, thresh=100) # remove Xe with only 1 entry


# %%
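# bar plots of per-element structure counts in MP and WBM
# df_frac_comp.where(pd.isna, 1).sum() keeps NaN where an element is absent and sets 1
# where present, so the column sum counts structures containing each element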
for label, srs in (
    ("MP", df_elem_err[train_count_col]),
    ("WBM", df_frac_comp.where(pd.isna, 1).sum()),
):
    title = f"Number of {label} structures containing each element"
    srs = srs.sort_values().copy()
    srs.index = [f"{len(srs) - idx} {el}" for idx, el in enumerate(srs.index)]
    fig = srs.plot.bar(backend="plotly", title=title)
    fig.layout.update(showlegend=False)
    fig.show()


# %% plot structure counts for each element in MP and WBM in a grouped bar chart
df_struct_counts = pd.DataFrame(index=df_elem_err.index)
df_struct_counts["MP"] = df_elem_err[train_count_col]
df_struct_counts["WBM"] = df_frac_comp.where(pd.isna, 1).sum()
min_count = 10  # only show elements present in more than 10 structures
df_struct_counts = df_struct_counts[df_struct_counts.sum(axis=1) > min_count]
normalized = False
if normalized:
    df_struct_counts["MP"] /= len(df_mp) / 100
    df_struct_counts["WBM"] /= len(df_wbm) / 100
y_col = "percent" if normalized else "count"
fig = (
    df_struct_counts.reset_index()
    .melt(var_name="dataset", value_name=y_col, id_vars="symbol")
    .sort_values([y_col, "symbol"])
    .plot.bar(
        x="symbol",
        y=y_col,
        backend="plotly",
        title="Number of structures containing each element",
        color="dataset",
        barmode="group",
    )
)

fig.layout.update(bargap=0.1)
fig.layout.legend.update(x=0.02, y=0.98, font_size=16)
fig.show()
save_fig(fig, f"{FIGS}/bar-element-counts-mp+wbm-{normalized=}.svelte")


# %%
test_set_std_col = "Test set standard deviation (eV/atom)"
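# per-element standard deviation of the true E above hull across all test set
# structures containing that element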
df_elem_err[test_set_std_col] = (
    df_frac_comp.where(pd.isna, 1) * df_wbm[each_true_col].values[:, None]
).std()


# %%
fig = ptable_heatmap_plotly(
    df_elem_err[test_set_std_col], precision=".2f", colorscale="Inferno"
)
fig.show()


# %%
normalized = True
cs_range = (0, 0.5) # same range for all plots
# cs_range = (None, None) # different range for each plot
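# composition-weighted mean absolute error per element for each model (plus the
# mean-over-models column), optionally normalized by the per-element test set std dev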
for model in (*df_metrics, model_mean_err_col):
    df_elem_err[model] = (
        df_frac_comp * df_each_err[model].abs().values[:, None]
    ).mean()
    # don't change series values in place, would change the df
    per_elem_err = df_elem_err[model].copy(deep=True)
    per_elem_err.name = f"{model} (eV/atom)"
    if normalized:
        per_elem_err /= df_elem_err[test_set_std_col]
        per_elem_err.name = f"{model} (normalized by test set std)"
    fig = ptable_heatmap_plotly(
        per_elem_err, precision=".2f", colorscale="Inferno", cscale_range=cs_range
    )
    fig.show()


# %%
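# each model column should have errors for all but a handful of elements (< 35 NaNs)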
assert (df_elem_err.isna().sum() < 35).all()
df_elem_err.round(4).to_json(f"{MODELS}/per-element-model-each-errors.json")


# %% scatter plot error by element against prevalence in training set
# for checking correlation and R2 of elemental prevalence in MP training data vs.
# model error
df_elem_err["elem_name"] = [Element(el).long_name for el in df_elem_err.index]
R2 = r2_score(*df_elem_err[[train_count_col, model_mean_err_col]].dropna().values.T)
r_P = df_elem_err[model_mean_err_col].corr(df_elem_err[train_count_col])

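# only label elements with large mean error or high MP prevalence to avoid clutter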
fig = df_elem_err.plot.scatter(
    x=train_count_col,
    y=model_mean_err_col,
    backend="plotly",
    hover_name="elem_name",
    text=df_elem_err.index.where(
        (df_elem_err[model_mean_err_col] > 0.04)
        | (df_elem_err[train_count_col] > 6_000)
    ),
    title="Per-element error vs element-occurrence in MP training "
    f"set: r<sub>Pearson</sub>={r_P:.2f}, R<sup>2</sup>={R2:.2f}",
    hover_data={model_mean_err_col: ":.2f", train_count_col: ":,.0f"},
)
fig.update_traces(textposition="top center") # place text above scatter points
fig.layout.title.update(xanchor="center", x=0.5)
fig.show()

# save_fig(fig, f"{FIGS}/element-prevalence-vs-error.svelte")
save_fig(fig, f"{ROOT}/tmp/figures/element-prevalence-vs-error.pdf")


# %% plot EACH errors against least prevalent element in structure (by occurrence in
# MP training set). This seems to correlate more strongly with model error.
n_examp_for_rarest_elem_col = "Examples for rarest element in structure"
df_wbm["composition"] = df_wbm.get("composition", df_wbm.formula.map(Composition))
df_elem_err.loc[list(map(str, df_wbm.composition[0]))][train_count_col].min()
df_wbm[n_examp_for_rarest_elem_col] = [
    df_elem_err.loc[list(map(str, Composition(formula)))][train_count_col].min()
    for formula in tqdm(df_wbm.formula)
]


# %%
df_melt = (
    df_each_err.abs()
    .reset_index()
    .melt(var_name="Model", value_name=each_pred_col, id_vars="material_id")
    .set_index("material_id")
)
df_melt[n_examp_for_rarest_elem_col] = df_wbm[n_examp_for_rarest_elem_col]

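# bin (x, y) values per model to reduce the number of scatter points plotted below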
df_bin = bin_df_cols(df_melt, [n_examp_for_rarest_elem_col, each_pred_col], ["Model"])
df_bin = df_bin.reset_index().set_index("material_id")
df_bin["formula"] = df_wbm.formula


# %%
fig = px.scatter(
    df_bin.reset_index(),
    x=n_examp_for_rarest_elem_col,
    y=each_pred_col,
    color="Model",
    facet_col="Model",
    facet_col_wrap=3,
    hover_data=dict(material_id=True, formula=True, Model=False),
    title="Absolute errors in model-predicted E<sub>above hull</sub> vs. occurrence "
    "count in MP training set<br>of least prevalent element in structure",
)
fig.layout.update(showlegend=False)
fig.layout.title.update(x=0.5, xanchor="center", y=0.95)
fig.layout.margin.update(t=100)
# remove axis labels
fig.update_xaxes(title="")
fig.update_yaxes(title="")
for anno in fig.layout.annotations:
    anno.text = anno.text.split("=")[1]

fig.add_annotation(
    text="MP occurrence count of least prevalent element in structure",
    x=0.5,
    y=-0.18,
    xref="paper",
    yref="paper",
    showarrow=False,
)
fig.add_annotation(
    text="Absolute error in E<sub>above hull</sub>",
    x=-0.07,
    y=0.5,
    xref="paper",
    yref="paper",
    showarrow=False,
    textangle=-90,
)

fig.show()
save_fig(fig, f"{FIGS}/each-error-vs-least-prevalent-element-in-struct.svelte")
