Skip to content

Commit

Permalink
split model run times into train and test contribs
Browse files Browse the repository at this point in the history
plot as table, pie, sunburst, bar charts
rename matbench_discovery.preds.df_wbm to df_preds
  • Loading branch information
janosh committed Jun 20, 2023
1 parent fa1a439 commit b8a18d8
Show file tree
Hide file tree
Showing 27 changed files with 267 additions and 228 deletions.
29 changes: 11 additions & 18 deletions data/mp/get_mp_energies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os

import pandas as pd
from aviary.utils import as_dict_handler
from aviary.wren.utils import get_aflow_label_from_spglib
from mp_api.client import MPRester
from pymatviz.utils import annotate_mae_r2
Expand All @@ -12,66 +11,60 @@
from matbench_discovery.data import DATA_FILES

"""
Download all MP formation and above hull energies on 2022-08-13.
Download all MP formation and above hull energies on 2023-01-10.
Related EDA of MP formation energies:
https://github.com/janosh/pymatviz/blob/main/examples/mp_bimodal_e_form.ipynb
"""

__author__ = "Janosh Riebesell"
__date__ = "2022-08-13"
__date__ = "2023-01-10"

module_dir = os.path.dirname(__file__)


# %% query all MP formation energies on 2022-08-13
fields = [
# %%
fields = {
"material_id",
"task_ids",
"formula_pretty",
"formation_energy_per_atom",
"energy_per_atom",
"structure",
"symmetry",
"energy_above_hull",
"decomposition_enthalpy",
"energy_type",
]
}

with MPRester(use_document_model=False) as mpr:
docs = mpr.thermo.search(fields=fields, thermo_types=["GGA_GGA+U"])

assert fields == set(docs[0]), f"missing fields: {fields - set(docs[0])}"
print(f"{today}: {len(docs) = :,}")
# 2022-08-13: len(docs) = 146,323
# 2023-01-10: len(docs) = 154,718


# %%
df = pd.DataFrame(docs).set_index("material_id")
df.pop("_id")

df.energy_type.value_counts().plot.pie(backend="matplotlib", autopct="%1.1f%%")
df.energy_type.value_counts().plot.pie(backend="plotly", autopct="%1.1f%%")
# GGA: 72.2%, GGA+U: 27.8%


# %%
df["spacegroup_number"] = df.pop("symmetry").map(lambda x: x["number"])
df["spacegroup_number"] = [x["number"] for x in df.pop("symmetry")]

df["wyckoff_spglib"] = [get_aflow_label_from_spglib(x) for x in tqdm(df.structure)]

df.reset_index().to_json(
f"{module_dir}/mp-energies.json.gz", default_handler=as_dict_handler
)

# read stored data back from disk
df = pd.read_json(DATA_FILES.mp_energies)
df.to_csv(DATA_FILES.mp_energies)
# df = pd.read_csv(DATA_FILES.mp_energies)


# %% reproduce fig. 1b from https://arxiv.org/abs/2001.10591 (as a data consistency check)
ax = df.plot.scatter(
x="formation_energy_per_atom",
y="decomposition_enthalpy",
alpha=0.1,
backend="matplotlib",
xlim=[-5, 1],
ylim=[-1, 1],
color=(df.decomposition_enthalpy > 0).map({True: "red", False: "blue"}),
Expand Down
2 changes: 1 addition & 1 deletion matbench_discovery/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class DataFiles(Files):
"mp/2023-02-07-mp-computed-structure-entries.json.gz"
)
mp_elemental_ref_entries = "mp/2022-09-19-mp-elemental-reference-entries.json"
mp_energies = "mp/2023-01-10-mp-energies.json.gz"
mp_energies = "mp/2023-01-10-mp-energies.csv"
mp_patched_phase_diagram = "mp/2023-02-07-ppd-mp.pkl.gz"
wbm_computed_structure_entries = (
"wbm/2022-10-19-wbm-computed-structure-entries.json.bz2"
Expand Down
18 changes: 11 additions & 7 deletions matbench_discovery/preds.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class PredFiles(Files):
PRED_FILES = PredFiles()


def load_df_wbm_preds(
def load_df_wbm_with_preds(
models: Sequence[str] = (*PRED_FILES,),
pbar: bool = True,
id_col: str = "material_id",
Expand Down Expand Up @@ -111,15 +111,17 @@ def load_df_wbm_preds(
return df_out


df_wbm = load_df_wbm_preds().round(3)
df_preds = load_df_wbm_with_preds().round(3)
for combo in [["CHGNet", "M3GNet"]]:
df_preds[" + ".join(combo)] = df_preds[combo].mean(axis=1)


df_metrics = pd.DataFrame()
df_metrics.index.name = "model"
for model in list(PRED_FILES):
for model in [*PRED_FILES, "CHGNet + M3GNet"]:
df_metrics[model] = stable_metrics(
df_wbm[each_true_col],
df_wbm[each_true_col] + df_wbm[model] - df_wbm[e_form_col],
df_preds[each_true_col],
df_preds[each_true_col] + df_preds[model] - df_preds[e_form_col],
)

# pick F1 as primary metric to sort by
Expand All @@ -128,10 +130,12 @@ def load_df_wbm_preds(
# dataframe of all models' energy above convex hull (EACH) predictions (eV/atom)
df_each_pred = pd.DataFrame()
for model in df_metrics.T.MAE.sort_values().index:
df_each_pred[model] = df_wbm[each_true_col] + df_wbm[model] - df_wbm[e_form_col]
df_each_pred[model] = (
df_preds[each_true_col] + df_preds[model] - df_preds[e_form_col]
)


# dataframe of all models' errors in their EACH predictions (eV/atom)
df_each_err = pd.DataFrame()
for model in df_metrics.T.MAE.sort_values().index:
df_each_err[model] = df_wbm[model] - df_wbm[e_form_col]
df_each_err[model] = df_preds[model] - df_preds[e_form_col]
8 changes: 4 additions & 4 deletions models/chgnet/join_chgnet_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from matbench_discovery import today
from matbench_discovery.data import as_dict_handler
from matbench_discovery.energy import get_e_form_per_atom
from matbench_discovery.preds import df_wbm, e_form_col
from matbench_discovery.preds import df_preds, e_form_col

__author__ = "Janosh Riebesell"
__date__ = "2023-03-01"
Expand Down Expand Up @@ -49,18 +49,18 @@

# %% compute corrected formation energies
e_form_chgnet_col = "e_form_per_atom_chgnet"
df_chgnet["formula"] = df_wbm.formula
df_chgnet["formula"] = df_preds.formula
df_chgnet[e_form_chgnet_col] = [
get_e_form_per_atom(dict(energy=ene, composition=formula))
for formula, ene in tqdm(
df_chgnet.set_index("formula").chgnet_energy.items(), total=len(df_chgnet)
)
]
df_wbm[e_form_chgnet_col] = df_chgnet[e_form_chgnet_col]
df_preds[e_form_chgnet_col] = df_chgnet[e_form_chgnet_col]


# %%
ax = density_scatter(df=df_wbm, x=e_form_col, y=e_form_chgnet_col)
ax = density_scatter(df=df_preds, x=e_form_col, y=e_form_chgnet_col)


# %%
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ Package = "https://pypi.org/project/matbench-discovery"
[project.optional-dependencies]
test = ["pytest", "pytest-cov", "pytest-markdown-docs"]
running-models = ["aviary", "m3gnet", "maml", "megnet"]
3d-structures = ["crystaltoolkit"]

[tool.setuptools.packages]
find = { include = ["matbench_discovery"] }
Expand Down
Loading

0 comments on commit b8a18d8

Please sign in to comment.