Skip to content

Commit

Permalink
add mb_discovery/build_phase_diagram.py
Browse files Browse the repository at this point in the history
add missing MaterialsProject2020Compatibility processing to process_wbm_cleaned.py
  • Loading branch information
janosh committed Jun 20, 2023
1 parent e7f4582 commit 42a7909
Show file tree
Hide file tree
Showing 9 changed files with 141 additions and 73 deletions.
120 changes: 120 additions & 0 deletions mb_discovery/build_phase_diagram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# %%
import gzip
import json
import os
import pickle
from datetime import datetime

import pandas as pd
import pymatviz
from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
from pymatgen.entries.computed_entries import ComputedEntry
from pymatgen.ext.matproj import MPRester

from mb_discovery import ROOT
from mb_discovery.compute_formation_energy import (
get_elemental_ref_entries,
get_form_energy_per_atom,
)

# ISO date stamp (YYYY-MM-DD) used as a prefix for every file this script writes
today = datetime.now().strftime("%Y-%m-%d")
# directory containing this script; pickled/JSON outputs are saved next to it
module_dir = os.path.dirname(__file__)


# %% query every entry from the Materials Project API
all_mp_computed_structure_entries = MPRester().get_entries("")  # run on 2022-09-16

# save all ComputedStructureEntries to disk, keyed by entry ID
# mp-15590 appears twice so we drop_duplicates()
pd.Series(
    {entry.entry_id: entry for entry in all_mp_computed_structure_entries}
).drop_duplicates().to_json(
    f"{ROOT}/data/{today}-all-mp-entries.json.gz",
    # entries are not natively JSON-serializable, fall back to their dict form
    default_handler=lambda entry: entry.as_dict(),
)


# %% load the MP entries saved on 2022-09-16 as a {material_id: ComputedEntry} dict
all_mp_computed_entries = {
    material_id: ComputedEntry.from_dict(entry_dict)  # drop the structure
    for material_id, entry_dict in pd.read_json(
        f"{ROOT}/data/2022-09-16-all-mp-entries.json.gz"
    )
    .set_index("material_id")
    .entry.items()
}


print(f"{len(all_mp_computed_entries) = :,}")
# len(all_mp_computed_entries) = 146,323


# %% build phase diagram with MP entries only
# all_mp_computed_entries is a {material_id: ComputedEntry} dict. Iterating the dict
# itself would hand PatchedPhaseDiagram the ID strings, so pass the entries explicitly.
ppd_mp = PatchedPhaseDiagram(list(all_mp_computed_entries.values()))
# prints:
# PatchedPhaseDiagram
# Covering 44805 Sub-Spaces

# save MP PPD to disk as a gzip-compressed pickle, date-stamped, next to this script
with gzip.open(f"{module_dir}/{today}-ppd-mp.pkl.gz", "wb") as zip_file:
    pickle.dump(ppd_mp, zip_file)


# %% build phase diagram with both MP entries + WBM entries
df_wbm = pd.read_json(
    f"{ROOT}/data/2022-06-26-wbm-cses-and-initial-structures.json.gz"
).set_index("material_id")

# keep only multi-element WBM entries and rebuild them as ComputedEntry objects;
# .tolist() makes the value an actual list, matching the annotation
wbm_computed_entries: list[ComputedEntry] = (
    df_wbm.query("n_elements > 1").cse.map(ComputedEntry.from_dict).tolist()
)
# remember how many entries go into the compatibility scheme so the skip statistic
# below does not also count the single-element rows filtered out by the query
n_wbm_inputs = len(wbm_computed_entries)

# process_entries returns only the entries it could process
wbm_computed_entries = MaterialsProject2020Compatibility().process_entries(
    wbm_computed_entries, verbose=True, clean=True
)

n_skipped = n_wbm_inputs - len(wbm_computed_entries)
# max(..., 1) avoids a ZeroDivisionError should the query select no rows
print(f"{n_skipped:,} ({n_skipped / max(n_wbm_inputs, 1):.1%}) entries not processed")


# %% merge MP and WBM entries into a single PatchedPhaseDiagram
# wbm_computed_entries is a list while all_mp_computed_entries is a
# {material_id: ComputedEntry} dict, so `list + dict` would raise TypeError.
# Concatenate the WBM entries with the dict's values instead.
mp_wbm_ppd = PatchedPhaseDiagram(
    [*wbm_computed_entries, *all_mp_computed_entries.values()], verbose=True
)


# %% compute terminal reference entries across all MP (can be used to compute MP
# compatible formation energies quickly)
# get_elemental_ref_entries iterates its argument and reads entry.composition, so it
# needs the ComputedEntry objects rather than the dict's material_id keys
elemental_ref_entries = get_elemental_ref_entries(
    list(all_mp_computed_entries.values())
)

# save elemental_ref_entries to disk as json; entries are serialized via as_dict()
with open(f"{module_dir}/{today}-elemental-ref-entries.json", "w") as file:
    json.dump(elemental_ref_entries, file, default=lambda x: x.as_dict())


# %% load MP elemental reference entries to compute formation energies
mp_elem_refs_path = f"{ROOT}/data/2022-09-19-mp-elemental-reference-entries.json"
# the JSON file is read as a Series; every value is rebuilt into a ComputedEntry
mp_reference_entries = {
    key: ComputedEntry.from_dict(entry_dict)
    for key, entry_dict in pd.read_json(mp_elem_refs_path, typ="series").items()
}


# MP energies table indexed by material_id
df_mp = pd.read_json(f"{ROOT}/data/2022-08-13-mp-all-energies.json.gz")
df_mp = df_mp.set_index("material_id")


# %% recompute the formation energy per atom for every row of df_mp
df_mp["our_mp_e_form"] = [
    get_form_energy_per_atom(entry, mp_reference_entries)
    for entry in map(all_mp_computed_entries.__getitem__, df_mp.index)
]


# make sure get_form_energy_per_atom() reproduces MP formation energies
ax = pymatviz.density_scatter(
    df_mp["formation_energy_per_atom"], df_mp["our_mp_e_form"]
)
ax.set_title("MP Formation Energy Comparison")
ax.set_xlabel("MP Formation Energy (eV/atom)")
ax.set_ylabel("Our Formation Energy (eV/atom)")
# write the date-stamped comparison figure to the repo's tmp directory
ax.figure.savefig(f"{ROOT}/tmp/{today}-mp-formation-energy-comparison.png", dpi=300)
59 changes: 4 additions & 55 deletions mb_discovery/compute_formation_energy.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,12 @@
# %%
import gzip
import itertools
import json
import os
import pickle
from datetime import datetime

import pandas as pd
from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram, PDEntry
from pymatgen.ext.matproj import MPRester
from pymatgen.analysis.phase_diagram import Entry
from tqdm import tqdm

from mb_discovery import ROOT

today = f"{datetime.now():%Y-%m-%d}"
module_dir = os.path.dirname(__file__)


# %%
def get_elemental_ref_entries(
entries: list[PDEntry], verbose: bool = False
) -> dict[str, PDEntry]:
entries: list[Entry], verbose: bool = False
) -> dict[str, Entry]:

elements = {elems for entry in entries for elems in entry.composition.elements}
dim = len(elements)
Expand Down Expand Up @@ -53,7 +39,7 @@ def get_elemental_ref_entries(


def get_form_energy_per_atom(
entry: PDEntry, elemental_ref_entries: dict[str, PDEntry]
entry: Entry, elemental_ref_entries: dict[str, Entry]
) -> float:
"""Get the formation energy of a composition from a list of entries and elemental
reference energies.
Expand All @@ -65,40 +51,3 @@ def get_form_energy_per_atom(
)

return form_energy / entry.composition.num_atoms


# %%
if __name__ == "__main__":
all_mp_entries = MPRester().get_entries("") # run on 2022-09-16
# mp-15590 appears twice so we drop_duplicates()
df_mp_entries = pd.DataFrame(all_mp_entries, columns=["entry"]).drop_duplicates()
df_mp_entries["material_id"] = [x.entry_id for x in df_mp_entries.entry]
df_mp_entries = df_mp_entries.set_index("material_id")

df_mp_entries.reset_index().to_json(
f"{ROOT}/data/{today}-2-all-mp-entries.json.gz",
default_handler=lambda x: x.as_dict(),
)

df_mp_entries = pd.read_json(
f"{ROOT}/data/2022-09-16-all-mp-entries.json.gz"
).set_index("material_id")
all_mp_entries = [PDEntry.from_dict(x) for x in df_mp_entries.entry]

print(f"{len(df_mp_entries) = :,}")
# len(df_mp_entries) = 146,323

ppd_mp = PatchedPhaseDiagram(all_mp_entries)
# prints:
# PatchedPhaseDiagram
# Covering 44805 Sub-Spaces

# save MP PPD to disk
with gzip.open(f"{module_dir}/{today}-ppd-mp.pkl.gz", "wb") as zip_file:
pickle.dump(ppd_mp, zip_file)

elemental_ref_entries = get_elemental_ref_entries(all_mp_entries)

# save elemental_ref_entries to disk as json
with open(f"{module_dir}/{today}-elemental-ref-entries.json", "w") as f:
json.dump(elemental_ref_entries, f, default=lambda x: x.as_dict())
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

# download wbm-steps-summary.csv (23.31 MB)
df_summary = pd.read_csv(
"https://figshare.com/files/37542841?private_link=ff0ad14505f9624f0c05"
"https://figshare.com/files/37570234?private_link=ff0ad14505f9624f0c05"
).set_index("material_id")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

# download wbm-steps-summary.csv (23.31 MB)
df_summary = pd.read_csv(
"https://figshare.com/files/37542841?private_link=ff0ad14505f9624f0c05"
"https://figshare.com/files/37570234?private_link=ff0ad14505f9624f0c05"
).set_index("material_id")


Expand Down
16 changes: 9 additions & 7 deletions mb_discovery/plot_scripts/precision_recall_vs_calc_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

# %% download wbm-steps-summary.csv (23.31 MB)
df_wbm = pd.read_csv(
"https://figshare.com/files/37542841?private_link=ff0ad14505f9624f0c05"
"https://figshare.com/files/37570234?private_link=ff0ad14505f9624f0c05"
).set_index("material_id")


Expand Down Expand Up @@ -69,23 +69,24 @@
try:
if model_name == "M3GNet":
model_preds = df.e_form_m3gnet
targets = df.e_form_wbm
elif "Wrenformer" in model_name:
model_preds = df.e_form_per_atom_pred_ens
targets = df.e_form_per_atom
elif len(pred_cols := df.filter(like="e_form_pred").columns) >= 1:
# Voronoi+RF has single prediction column, Wren and CGCNN each have 10
# other cases are unexpected
assert len(pred_cols) in (1, 10), f"{model_name=} has {len(pred_cols)=}"
model_preds = df[pred_cols].mean(axis=1)
targets = df.e_form_target
else:
raise ValueError(f"Unhandled {model_name = }")
except AttributeError as exc:
raise KeyError(f"{model_name = }") from exc

df["e_above_mp_hull"] = df_hull.e_above_mp_hull
df["e_above_hull_pred"] = model_preds - targets
df["e_form_per_atom"] = df_wbm.e_form_per_atom
df["e_above_hull_pred"] = model_preds - df.e_form_per_atom
# The walrus operator binds more loosely than `>`, so without parentheses n_nans
# would hold the boolean `sum() > 0` and the `< 10` sanity check could never fail.
# Parenthesize so n_nans is the actual NaN count.
if (n_nans := df.isna().values.sum()) > 0:
    assert n_nans < 10, f"{model_name=} has {n_nans=}"
    df = df.dropna()

ax = precision_recall_vs_calc_count(
e_above_hull_error=df.e_above_hull_pred + df.e_above_mp_hull,
Expand All @@ -97,9 +98,10 @@
std_pred=std_total,
)

ax.legend(frameon=False, loc="lower right")

ax.figure.set_size_inches(10, 9)
ax.set(xlim=(0, None))
# keep this outside loop so all model names appear in legend
ax.legend(frameon=False, loc="lower right")

img_path = f"{ROOT}/figures/{today}-precision-recall-vs-calc-count-{rare=}.pdf"
if False:
Expand Down
9 changes: 3 additions & 6 deletions mb_discovery/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,12 +398,9 @@ def precision_recall_vs_calc_count(
# previous call
return ax

ax.set(
xlabel="Number of compounds sorted by model-predicted hull distance",
ylabel="Precision and Recall (%)",
)

ax.set(ylim=(0, 100))
xlabel = "Number of compounds sorted by model-predicted hull distance"
ylabel = "Precision and Recall (%)"
ax.set(ylim=(0, 100), xlabel=xlabel, ylabel=ylabel)

[precision] = ax.plot((0, 0), (0, 0), "black", linestyle="-")
[recall] = ax.plot((0, 0), (0, 0), "black", linestyle=":")
Expand Down
2 changes: 1 addition & 1 deletion models/bowsr/join_bowsr_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@

# %%
df_wbm = pd.read_csv( # download wbm-steps-summary.csv (23.31 MB)
"https://figshare.com/files/37542841?private_link=ff0ad14505f9624f0c05"
"https://figshare.com/files/37570234?private_link=ff0ad14505f9624f0c05"
).set_index("material_id")

df_bowsr["e_form_wbm"] = df_wbm.e_form_per_atom
Expand Down
2 changes: 1 addition & 1 deletion models/m3gnet/join_m3gnet_relax_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@


df_wbm = pd.read_csv( # download wbm-steps-summary.csv (23.31 MB)
"https://figshare.com/files/37542841?private_link=ff0ad14505f9624f0c05"
"https://figshare.com/files/37570234?private_link=ff0ad14505f9624f0c05"
).set_index("material_id")

df_m3gnet["e_form_wbm"] = df_wbm.e_form_per_atom
Expand Down
2 changes: 1 addition & 1 deletion models/wrenformer/mp/use_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

# %%
# download wbm-steps-summary.csv (23.31 MB)
data_path = "https://figshare.com/files/37542841?private_link=ff0ad14505f9624f0c05"
data_path = "https://figshare.com/files/37570234?private_link=ff0ad14505f9624f0c05"
df = pd.read_csv(data_path).set_index("material_id")


Expand Down

0 comments on commit 42a7909

Please sign in to comment.