fix WBM e_above_convex_hull values w.r.t. MP PPD
(these were wrong because the PPD was constructed from PDEntries, which lack energy corrections)
drop calculation of MP-legacy-corrected e_form and e_above_hull from fetch_process_wbm_dataset.py
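The gist of the fix, as a minimal sketch with toy compositions and made-up energies (not the repo's actual MP/WBM data): build the phase diagram from entries that retain their energy corrections, then measure hull distances against it.

```python
from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram
from pymatgen.entries.computed_entries import ComputedEntry

# toy reference entries; the real script uses all ~155k corrected MP entries
mp_entries = [
    ComputedEntry("Fe", -8.4),  # total energies in eV, invented for illustration
    ComputedEntry("O2", -9.8),
    ComputedEntry("Fe2O3", -38.0),
]
ppd_mp = PatchedPhaseDiagram(mp_entries)

# hull distance for a hypothetical candidate, negative if it sits below the hull
candidate = ComputedEntry("FeO", -12.5)
print(ppd_mp.get_e_above_hull(candidate, allow_negative=True))
```

The old pipeline built the PPD from entries carrying uncorrected energies, so every hull energy (and hence every WBM e_above_hull) was shifted by the missing MP2020 corrections.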
janosh committed Jun 20, 2023
1 parent 4013bb1 commit 5dae951
Showing 14 changed files with 312 additions and 360 deletions.
31 changes: 19 additions & 12 deletions data/mp/build_phase_diagram.py
@@ -10,39 +10,46 @@
 from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
 from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry
 from pymatgen.ext.matproj import MPRester
+from tqdm import tqdm

 from matbench_discovery import ROOT, today
 from matbench_discovery.energy import get_e_form_per_atom, get_elemental_ref_entries

 module_dir = os.path.dirname(__file__)


-# %%
-all_mp_computed_structure_entries = MPRester().get_entries("")  # run on 2022-09-16
+# %% run on 2022-09-16 and again on 2023-02-07
+all_mp_computed_structure_entries = MPRester().get_entries("")

 # save all ComputedStructureEntries to disk
-pd.Series(
-    {e.entry_id: e for e in all_mp_computed_structure_entries}
-).drop_duplicates().to_json(  # mp-15590 appears twice so we drop_duplicates()
+# mp-15590 appears twice so we drop_duplicates()
+df = pd.DataFrame(all_mp_computed_structure_entries, columns=["entry"])
+df.index.name = "material_id"
+df.index = [e.entry_id for e in df.entry]
+df.reset_index().to_json(
     f"{module_dir}/{today}-mp-computed-structure-entries.json.gz",
     default_handler=lambda x: x.as_dict(),
 )


 # %%
-data_path = f"{module_dir}/2022-09-16-mp-computed-structure-entries.json.gz"
+data_path = f"{module_dir}/2023-02-07-mp-computed-structure-entries.json.gz"
 df = pd.read_json(data_path).set_index("material_id")
-# drop the structure, just load ComputedEntry
-mp_computed_entries = df.entry.map(ComputedEntry.from_dict).to_dict()
-
-print(f"{len(mp_computed_entries) = :,}")
-# len(mp_computed_entries) = 146,323
+# drop the structure, just load ComputedEntry, makes the PPD faster to build and load
+mp_computed_entries = [ComputedEntry.from_dict(x) for x in tqdm(df.entry)]
+
+print(f"{len(mp_computed_entries) = :,} on {today}")
+# len(mp_computed_entries) = 146,323 on 2022-09-16
+# len(mp_computed_entries) = 154,719 on 2023-02-07


 # %% build phase diagram with MP entries only
-ppd_mp = PatchedPhaseDiagram(mp_computed_entries)
+ppd_mp = PatchedPhaseDiagram(mp_computed_entries, verbose=True)
 print(f"{ppd_mp} on {today}")
 # prints:
-# PatchedPhaseDiagram covering 44805 sub-spaces
+# PatchedPhaseDiagram covering 44805 sub-spaces on 2022-09-16
+# PatchedPhaseDiagram covering 46216 sub-spaces on 2023-02-07

 # save MP PPD to disk
 with gzip.open(f"{module_dir}/{today}-ppd-mp.pkl.gz", "wb") as zip_file:
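Worth noting for the script above: entries returned by `MPRester().get_entries()` already include MP2020 energy corrections, which is what makes the resulting PPD "corrected". A quick sanity check, assuming pymatgen's standard `ComputedEntry` attributes:

```python
# for any fetched entry: energy includes corrections, uncorrected_energy does not,
# and the difference is exactly entry.correction
entry = mp_computed_entries[0]
assert abs(entry.energy - entry.uncorrected_energy - entry.correction) < 1e-12
```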
11 changes: 4 additions & 7 deletions data/wbm/analysis.py
@@ -5,7 +5,7 @@
 from pymatviz import count_elements, ptable_heatmap_plotly
 from pymatviz.utils import save_fig

-from matbench_discovery import FIGS, today
+from matbench_discovery import FIGS, ROOT, today
 from matbench_discovery.data import df_wbm
 from matbench_discovery.energy import mp_elem_reference_entries
 from matbench_discovery.plots import pio
@@ -21,9 +21,8 @@
 # %%
 wbm_elem_counts = count_elements(df_wbm.formula).astype(int)

-# wbm_elem_counts.to_json(
-#     f"{ROOT}/site/src/routes/about-the-test-set/{today}-wbm-element-counts.json"
-# )
+out_elem_counts = f"{ROOT}/site/src/routes/about-the-test-set/wbm-element-counts.json"
+# wbm_elem_counts.to_json(out_elem_counts)
@@ -46,9 +45,7 @@
 # %%
-wbm_fig.write_image(
-    f"{module_dir}/figs/{today}-wbm-elements.svg", width=1000, height=500
-)
+wbm_fig.write_image(f"{module_dir}/figs/wbm-elements.svg", width=1000, height=500)
 save_fig(wbm_fig, f"{FIGS}/{today}-wbm-elements.svelte")
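For context, the two pymatviz calls this script revolves around can be exercised standalone (toy formulas rather than the actual WBM data; input types assumed from pymatviz's docs):

```python
from pymatviz import count_elements, ptable_heatmap_plotly

elem_counts = count_elements(["Fe2O3", "FeO", "Li2O"]).astype(int)  # per-element totals
fig = ptable_heatmap_plotly(elem_counts)  # periodic-table heatmap as a plotly figure
fig.show()
```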
89 changes: 22 additions & 67 deletions data/wbm/fetch_process_wbm_dataset.py
@@ -10,12 +10,7 @@
 from aviary.wren.utils import get_aflow_label_from_spglib
 from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram
 from pymatgen.core import Composition, Structure
-from pymatgen.entries.compatibility import (
-    MaterialsProject2020Compatibility as MP2020Compat,
-)
-from pymatgen.entries.compatibility import (
-    MaterialsProjectCompatibility as MPLegacyCompat,
-)
+from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
 from pymatgen.entries.computed_entries import ComputedStructureEntry
 from pymatviz import density_scatter
 from pymatviz.utils import save_fig
@@ -291,7 +286,7 @@ def increment_wbm_material_id(wbm_id: str) -> str:
     "vol": "volume",
     "e": "uncorrected_energy",
     "e_form": "e_form_per_atom_wbm",
-    "e_hull": "e_hull_wbm",
+    "e_hull": "e_above_hull_wbm",
     "gap": "bandgap_pbe",
     "id": "material_id",
 }
@@ -361,7 +356,7 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
 for mat_id, cse in df_wbm.computed_structure_entry.items():
     entry_id = cse["entry_id"]
     if mat_id != entry_id:
-        print(f"{mat_id=} != {entry_id=}")
+        print(f"{mat_id=} != {entry_id=}, updating entry_id to mat_id")
         cse["entry_id"] = mat_id
@@ -499,73 +494,39 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
 # summary and CSE n_sites match
 assert all(df_summary.n_sites == [len(cse.structure) for cse in df_wbm.cse])

-for mp_compat in [MPLegacyCompat(), MP2020Compat()]:
-    compat_out = mp_compat.process_entries(df_wbm.cse, clean=True, verbose=True)
-    assert len(compat_out) == len(df_wbm) == len(df_summary)
-
-    n_corrected = sum(cse.uncorrected_energy != cse.energy for cse in df_wbm.cse)
-    if isinstance(mp_compat, MPLegacyCompat):
-        assert n_corrected == 39591, f"{n_corrected=}"
-    if isinstance(mp_compat, MP2020Compat):
-        assert n_corrected == 100930, f"{n_corrected=}"
-
-    corr_label = "mp2020" if isinstance(mp_compat, MP2020Compat) else "legacy"
-    df_summary[f"e_correction_per_atom_{corr_label}"] = [
-        cse.correction_per_atom for cse in df_wbm.cse
-    ]
-
-assert df_summary.e_correction_per_atom_mp2020.mean().round(4) == -0.1069
-assert df_summary.e_correction_per_atom_legacy.mean().round(4) == -0.0645
-assert (df_summary.filter(like="correction").abs() > 1e-4).sum().to_dict() == {
-    "e_correction_per_atom_mp2020": 100930,
-    "e_correction_per_atom_legacy": 39591,
-}, "unexpected number of materials received non-zero corrections"
-
-ax = density_scatter(
-    df_summary.e_correction_per_atom_legacy,
-    df_summary.e_correction_per_atom_mp2020,
-    xlabel="legacy corrections (eV / atom)",
-    ylabel="MP2020 corrections (eV / atom)",
+compat_out = MaterialsProject2020Compatibility().process_entries(
+    entries=df_wbm.cse, clean=True, verbose=True
 )
-# ax.figure.savefig(f"{ROOT}/tmp/{today}-legacy-vs-mp2020-corrections.webp")
+assert len(compat_out) == len(df_wbm) == len(df_summary)

+n_corrected = sum(cse.uncorrected_energy != cse.energy for cse in df_wbm.cse)
+assert n_corrected == 100_930, f"{n_corrected=} expected 100,930"

-# %% Python crashes with segfault on correcting the energy of wbm-1-24459 due to
-# https://github.com/spglib/spglib/issues/194 when using spglib versions 2.0.0 or 2.0.1
-# left here as a reminder and for future users in case they encounter the same issue
-cse = df_wbm.computed_structure_entry["wbm-1-24459"]
-cse = ComputedStructureEntry.from_dict(cse)
-mp_compat.process_entry(cse)
+e_correction_col = "e_correction_per_atom_mp2020"
+df_summary[e_correction_col] = [cse.correction_per_atom for cse in df_wbm.cse]
+
+assert df_summary.e_correction_per_atom_mp2020.mean().round(4) == -0.1069


 # %%
-with gzip.open(f"{ROOT}/data/mp/2022-09-18-ppd-mp.pkl.gz", "rb") as zip_file:
+with gzip.open(f"{ROOT}/data/mp/2023-02-07-ppd-mp.pkl.gz", "rb") as zip_file:
     ppd_mp: PatchedPhaseDiagram = pickle.load(zip_file)


 # %% calculate e_above_hull for each material
 # this loop needs above warnings.filterwarnings() to not crash Jupyter kernel with logs
-# takes ~20 min at 200 it/s for 250k entries in WBM
-e_above_hull_key = "e_above_hull_uncorrected_ppd_mp"
-assert e_above_hull_key not in df_summary
+each_col = "e_above_hull_mp2020_corrected_ppd_mp"
+assert each_col not in df_summary

-for mat_id, entry in tqdm(df_wbm.cse.items(), total=len(df_wbm)):
-    assert mat_id == entry.entry_id, f"{mat_id=} != {entry.entry_id=}"
-    assert entry.entry_id in df_summary.index, f"{entry.entry_id=} not in df_summary"
-
-    e_per_atom = entry.uncorrected_energy_per_atom
-    e_hull_per_atom = ppd_mp.get_hull_energy_per_atom(entry.composition)
-    e_above_hull = e_per_atom - e_hull_per_atom
-
-    df_summary.at[entry.entry_id, e_above_hull_key] = e_above_hull
-
-# add old + new MP energy corrections to above hull energies
-for corrections in ("mp2020", "legacy"):
-    df_summary[e_above_hull_key.replace("un", f"{corrections}_")] = (
-        df_summary[e_above_hull_key]
-        + df_summary[f"e_correction_per_atom_{corrections}"]
-    )
+for mat_id, cse in tqdm(df_wbm.cse.items(), total=len(df_wbm)):
+    assert mat_id == cse.entry_id, f"{mat_id=} != {cse.entry_id=}"
+    assert cse.entry_id in df_summary.index, f"{cse.entry_id=} not in df_summary"
+
+    e_above_hull = ppd_mp.get_e_above_hull(cse, allow_negative=True)
+
+    df_summary.at[cse.entry_id, each_col] = e_above_hull


 # %% calculate formation energies from CSEs wrt MP elemental reference energies
@@ -588,16 +549,10 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
     # make sure the PPD.get_e_form_per_atom() and standalone get_e_form_per_atom()
     # method of calculating formation energy agree
     assert (
-        abs(e_form - (e_form_ppd - correction)) < 1e-7
-    ), f"{mat_id=}: {e_form=:.3} != {e_form_ppd - correction=:.3}"
+        abs(e_form - (e_form_ppd - correction)) < 1e-4
+    ), f"{mat_id=}: {e_form=:.5} != {e_form_ppd - correction=:.5}"
     df_summary.at[cse.entry_id, e_form_col] = e_form

-# add old + new MP energy corrections to formation energies
-for corrections in ("mp2020", "legacy"):
-    df_summary[e_form_col.replace("un", f"{corrections}_")] = (
-        df_summary[e_form_col] + df_summary[f"e_correction_per_atom_{corrections}"]
-    )
-

 # %%
 df_init_struct = pd.read_json(
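The assert above guards the identity `e_form == e_form_ppd - correction`: the standalone `get_e_form_per_atom()` works on uncorrected energies, while the PPD's value includes MP2020 corrections. A toy numerical illustration (made-up values, not real WBM numbers):

```python
# two routes to the formation energy of one entry, per atom
e_raw = -6.10  # uncorrected DFT energy per atom
elemental_ref = -5.15  # composition-weighted MP elemental reference energy
correction = -0.25  # MP2020 energy correction per atom

e_form = e_raw - elemental_ref  # standalone route: -0.95 eV/atom
e_form_ppd = (e_raw + correction) - elemental_ref  # corrected PPD route: -1.20 eV/atom

assert abs(e_form - (e_form_ppd - correction)) < 1e-4
```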
2 changes: 1 addition & 1 deletion data/wbm/readme.md
@@ -1,6 +1,6 @@
 # WBM Dataset

-The **WBM dataset** was published in [Predicting stable crystalline compounds using chemical similarity][wbm paper] (Nature Computational Materials, Jan 2021, [doi:10.1038/s41524-020-00481-6](http://doi.org/10.1038/s41524-020-00481-6)). The authors generated 257,487 structures through single-element substitutions on Materials Project (MP) source structures. The replacement element was chosen based on chemical similarity determined by a matrix data mined from the [Inorganic Crystal Structure Database (ICSD)](https://icsd.products.fiz-karlsruhe.de).
+The **WBM dataset** was published in [Predicting stable crystalline compounds using chemical similarity][wbm paper] (nat comp mat, Jan 2021). The authors generated 257,487 structures through single-element substitutions on Materials Project (MP) source structures. The replacement element was chosen based on chemical similarity determined by a matrix data-mined from the [Inorganic Crystal Structure Database (ICSD)](https://icsd.products.fiz-karlsruhe.de).

 The resulting novel structures were relaxed using MP-compatible VASP inputs (i.e. using `pymatgen`'s `MPRelaxSet`) and identical POTCARs in an attempt to create a database of Materials Project compatible novel crystals. Any degradation in model performance from training to test set should therefore largely be a result of extrapolation error rather than covariate shift in the underlying data.
4 changes: 2 additions & 2 deletions matbench_discovery/metrics.py
@@ -88,16 +88,16 @@ def stable_metrics(
     each_true, each_pred = np.array(each_true)[~is_nan], np.array(each_pred)[~is_nan]

     return dict(
-        F1=2 * (precision * recall) / (precision + recall),
-        R2=r2_score(each_true, each_pred),
         DAF=precision / prevalence,
         Precision=precision,
         Recall=recall,
         Accuracy=(n_true_pos + n_true_neg) / len(each_true),
+        F1=2 * (precision * recall) / (precision + recall),
         TPR=n_true_pos / n_total_pos,
         FPR=n_false_pos / n_total_neg,
         TNR=n_true_neg / n_total_neg,
         FNR=n_false_neg / n_total_pos,
         MAE=np.abs(each_true - each_pred).mean(),
         RMSE=((each_true - each_pred) ** 2).mean() ** 0.5,
+        R2=r2_score(each_true, each_pred),
     )
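For reference, a hedged usage sketch of `stable_metrics` (argument names and order inferred from the hunk above; toy hull distances in eV/atom, not model output):

```python
import numpy as np

from matbench_discovery.metrics import stable_metrics

each_true = np.array([-0.05, -0.01, 0.02, 0.30])  # DFT distances to the convex hull
each_pred = np.array([-0.02, 0.04, 0.05, 0.25])  # model-predicted distances

metrics = stable_metrics(each_true, each_pred)
print(metrics["F1"], metrics["MAE"], metrics["R2"])
```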
4 changes: 2 additions & 2 deletions scripts/compile_metrics.py
@@ -11,7 +11,7 @@
 from pymatviz.utils import save_fig
 from tqdm import tqdm

-from matbench_discovery import FIGS, WANDB_PATH, today
+from matbench_discovery import FIGS, ROOT, WANDB_PATH, today
 from matbench_discovery.data import PRED_FILENAMES
 from matbench_discovery.plots import px
 from matbench_discovery.preds import df_metrics, df_wbm
@@ -154,7 +154,7 @@
 with open(f"{FIGS}/metrics-table.svelte", "w") as file:
     file.write(html)

-dfi.export(styler, "model-metrics.png", dpi=300)
+dfi.export(styler, f"{ROOT}/tmp/figures/model-metrics.png", dpi=300)


 # %% write model metrics to json for use by the website
8 changes: 4 additions & 4 deletions scripts/cumulative_clf_metrics.py
@@ -2,7 +2,7 @@
 import pandas as pd
 from pymatviz.utils import save_fig

-from matbench_discovery import FIGS, ROOT, STATIC
+from matbench_discovery import FIGS, ROOT
 from matbench_discovery.plots import cumulative_precision_recall
 from matbench_discovery.preds import df_each_pred, df_metrics, df_wbm, each_true_col
@@ -27,7 +27,7 @@
     fig.text(0.5, -0.08, xlabel, ha="center", fontdict={"size": 16})
 if backend == "plotly":
     fig.layout.legend.update(
-        x=0.02, y=0.02, itemsizing="constant", bgcolor="rgba(0,0,0,0)"
+        x=0.98, xanchor="right", y=0.02, itemsizing="constant", bgcolor="rgba(0,0,0,0)"
     )  # , title=title
     # fig.layout.height = 500
     fig.layout.margin = dict(l=0, r=5, t=30, b=60)
@@ -91,5 +91,5 @@
 img_name = "cumulative-clf-metrics"
 save_fig(fig, f"{FIGS}/{img_name}.svelte")
-save_fig(fig, f"{STATIC}/{img_name}.webp", scale=3)
-save_fig(fig, f"{ROOT}/tmp/figures/{img_name}.pdf", width=700, height=350)
+# save_fig(fig, f"{STATIC}/{img_name}.webp", scale=3)
+save_fig(fig, f"{ROOT}/tmp/figures/{img_name}.pdf", width=650, height=350)
6 changes: 3 additions & 3 deletions scripts/rolling_mae_vs_hull_dist_all_models.py
@@ -3,7 +3,7 @@

 from pymatviz.utils import save_fig

-from matbench_discovery import FIGS, ROOT, STATIC
+from matbench_discovery import FIGS, ROOT
 from matbench_discovery.plots import rolling_mae_vs_hull_dist
 from matbench_discovery.preds import df_each_pred, df_metrics, df_wbm, each_true_col
@@ -48,12 +48,12 @@
 # increase legend handle size and reverse order
 fig.layout.margin = dict(l=5, r=5, t=5, b=55)
-fig.layout.legend.update(itemsizing="constant")
+fig.layout.legend.update(itemsizing="constant", bgcolor="rgba(0,0,0,0)")
 fig.show()


 # %%
 img_name = "rolling-mae-vs-hull-dist-models"
 save_fig(fig, f"{FIGS}/{img_name}.svelte")
-save_fig(fig, f"{STATIC}/{img_name}.webp", scale=3)
+# save_fig(fig, f"{STATIC}/{img_name}.webp", scale=3)
 save_fig(fig, f"{ROOT}/tmp/figures/{img_name}.pdf")