Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rough solution to recreating plots for unique proto set #88

Merged
merged 14 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ repos:
- prettier
- prettier-plugin-svelte
- svelte
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/.+\.(yaml|json)|changelog.md)$
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/(routes|figs).+\.(yaml|json)|changelog.md)$

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v9.0.0-alpha.2
Expand Down
29 changes: 19 additions & 10 deletions matbench_discovery/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,16 @@ class Open(LabelEnum):
CSCD = "CSCD", "closed source, closed data"


ev_per_atom = styled_html_tag(
@unique
class TestSubset(LabelEnum):
    """Named subsets of the test set that analyses can be restricted to.

    Each member is declared as a pair of strings: the enum value followed by a
    human-readable label. NOTE(review): how LabelEnum consumes the second
    string is not visible in this diff -- confirm against its definition.
    """

    # Selected by the figure scripts to keep only rows flagged as unique
    # structure prototypes (they filter their prediction frames on this member).
    uniq_protos = "uniq_protos", "Unique Structure Prototypes"
    # The value starts with a digit and so cannot double as a Python
    # identifier, hence the spelled-out member name.
    ten_k_most_stable = "10k_most_stable", "10k Most Stable"
    # Used by the figure scripts as the default when no subset is requested;
    # presumably means "apply no filtering" -- TODO confirm.
    full = "full", "Full Test Set"


eV_per_atom = styled_html_tag( # noqa: N816
"(eV/atom)", tag="span", style="font-size: 0.8em; font-weight: lighter;"
)

Expand All @@ -149,16 +158,16 @@ class Quantity(LabelEnum):
spg_num = "Space group"
n_wyckoff = "Number of Wyckoff positions"
n_sites = "Number of atoms"
energy_per_atom = f"Energy {ev_per_atom}"
e_form = f"DFT E<sub>form</sub> {ev_per_atom}"
e_above_hull = f"E<sub>hull dist</sub> {ev_per_atom}"
e_above_hull_mp2020_corrected_ppd_mp = f"DFT E<sub>hull dist</sub> {ev_per_atom}"
e_above_hull_pred = f"Predicted E<sub>hull dist</sub> {ev_per_atom}"
e_above_hull_mp = f"E<sub>above MP hull</sub> {ev_per_atom}"
e_above_hull_error = f"Error in E<sub>hull dist</sub> {ev_per_atom}"
energy_per_atom = f"Energy {eV_per_atom}"
e_form = f"DFT E<sub>form</sub> {eV_per_atom}"
e_above_hull = f"E<sub>hull dist</sub> {eV_per_atom}"
e_above_hull_mp2020_corrected_ppd_mp = f"DFT E<sub>hull dist</sub> {eV_per_atom}"
e_above_hull_pred = f"Predicted E<sub>hull dist</sub> {eV_per_atom}"
e_above_hull_mp = f"E<sub>above MP hull</sub> {eV_per_atom}"
e_above_hull_error = f"Error in E<sub>hull dist</sub> {eV_per_atom}"
vol_diff = "Volume difference (A^3)"
e_form_per_atom_mp2020_corrected = f"DFT E<sub>form</sub> {ev_per_atom}"
e_form_per_atom_pred = f"Predicted E<sub>form</sub> {ev_per_atom}"
e_form_per_atom_mp2020_corrected = f"DFT E<sub>form</sub> {eV_per_atom}"
e_form_per_atom_pred = f"Predicted E<sub>form</sub> {eV_per_atom}"
material_id = "Material ID"
band_gap = "Band gap (eV)"
formula = "Formula"
Expand Down
2 changes: 1 addition & 1 deletion matbench_discovery/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def rolling_mae_vs_hull_dist(
xanchor="right",
xref="x",
)
fig.add_shape(type="rect", x0=x0, y0=y0, x1=x0 - window, y1=y0 + window / 5)
fig.add_shape(type="rect", x0=x0, y0=y0, x1=x0 - window, y1=y0 + 0.006)

from matbench_discovery.preds import model_styles

Expand Down
2 changes: 1 addition & 1 deletion matbench_discovery/preds.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class PredFiles(Files):
# alignn_pretrained = "alignn/2023-06-03-mp-e-form-alignn-wbm-IS2RE.csv.gz"
# alignn_ff = "alignn_ff/2023-07-11-alignn-ff-wbm-IS2RE.csv.gz"

gnome = "gnome/2023-11-01-gnome-preds-50076332.csv.gz"
# gnome = "gnome/2023-11-01-gnome-preds-50076332.csv.gz"


# key_map maps model keys to pretty labels
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies = [
"pymatviz[export-figs,df-pdf-export]",
"scikit-learn",
"scipy",
"seaborn",
"tqdm",
"wandb",
]
Expand Down
16 changes: 11 additions & 5 deletions scripts/model_figs/analyze_model_disagreement.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,22 @@
import sys

import pandas as pd
from crystal_toolkit.helpers.utils import hook_up_fig_with_struct_viewer
from pymatviz.io import save_fig
from pymatviz.utils import add_identity_line

from matbench_discovery import PDF_FIGS, SITE_FIGS, Key
from matbench_discovery import PDF_FIGS, SITE_FIGS
from matbench_discovery.data import DATA_FILES
from matbench_discovery.enums import Key, TestSubset
from matbench_discovery.preds import df_preds

__author__ = "Janosh Riebesell"
__date__ = "2023-02-15"

test_subset = globals().get("test_subset", TestSubset.full)

if test_subset == TestSubset.uniq_protos:
df_preds = df_preds.query(Key.uniq_proto)


# %% scatter plot of largest model errors vs. DFT hull distance
# while some points lie on a horizontal line of constant error, more follow the identity
Expand Down Expand Up @@ -86,10 +91,11 @@
df_cse = pd.read_json(DATA_FILES.wbm_cses_plus_init_structs).set_index(Key.mat_id)


# %% struct viewer
# only run this in Jupyter Notebook
# %% CTK structure viewer
is_jupyter = "ipykernel" in sys.modules
if is_jupyter:
if is_jupyter: # only run this in Jupyter Notebook
from crystal_toolkit.helpers.utils import hook_up_fig_with_struct_viewer

app = hook_up_fig_with_struct_viewer(
fig,
df_cse,
Expand Down
4 changes: 2 additions & 2 deletions scripts/model_figs/compile_model_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
df_preds,
model_styles,
)
from matbench_discovery.preds import models as all_models

__author__ = "Janosh Riebesell"
__date__ = "2022-11-28"
Expand Down Expand Up @@ -128,8 +129,7 @@
df_tmp[time_col] = df_tmp.filter(like=time_col).sum(axis="columns")

# write model metrics to json for website use
in_both = [*set(df_metrics) & set(df_preds)]
df_tmp["missing_preds"] = df_preds[in_both].isna().sum()
df_tmp["missing_preds"] = df_preds[all_models].isna().sum()
df_tmp["missing_percent"] = [
f"{x / len(df_preds):.2%}" for x in df_tmp.missing_preds
]
Expand Down
10 changes: 9 additions & 1 deletion scripts/model_figs/cumulative_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,22 @@
import pandas as pd
from pymatviz.io import save_fig

from matbench_discovery import PDF_FIGS, SITE_FIGS, Key
from matbench_discovery import PDF_FIGS, SITE_FIGS
from matbench_discovery.enums import Key, TestSubset
from matbench_discovery.plots import cumulative_metrics
from matbench_discovery.preds import df_each_pred, df_preds, model_styles, models

__author__ = "Janosh Riebesell, Rhys Goodall"
__date__ = "2022-12-04"


test_subset = globals().get("test_subset", TestSubset.full)

if test_subset == TestSubset.uniq_protos:
df_preds = df_preds.query(Key.uniq_proto)
df_each_pred = df_each_pred.loc[df_preds.index]


# %%
metrics: tuple[str, ...] = globals().get("metrics", ("Precision", "Recall"))
# metrics = ("MAE",)
Expand Down
24 changes: 14 additions & 10 deletions scripts/model_figs/hist_classified_stable_vs_hull_dist_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,28 @@

from pymatviz.io import save_fig

from matbench_discovery import PDF_FIGS, SITE_FIGS, Key, today
from matbench_discovery import PDF_FIGS, SITE_FIGS, today
from matbench_discovery.enums import Key, TestSubset
from matbench_discovery.plots import hist_classified_stable_vs_hull_dist, plt
from matbench_discovery.preds import df_metrics, df_preds
from matbench_discovery.preds import df_metrics, df_metrics_uniq_protos, df_preds

__author__ = "Janosh Riebesell"
__date__ = "2022-12-01"


test_subset = globals().get("test_subset", TestSubset.full)

if test_subset == TestSubset.uniq_protos:
df_preds = df_preds.query(Key.uniq_proto)
df_metrics = df_metrics_uniq_protos


# %%
hover_cols = (
df_preds.index.name,
Key.e_form,
Key.each_true,
Key.formula,
)
hover_cols = (df_preds.index.name, Key.e_form, Key.each_true, Key.formula)
facet_col = "Model"
# sort facet plots by model's F1 scores (optionally only show top n=6)
models = list(df_metrics.T.F1.sort_values().index)[::-1]
# sort models by F1 scores so that facet plots are ordered by model performance
# (optionally only show top n=6)
models = list(df_preds.filter([*df_metrics.sort_values("F1", axis=1)]))[::-1]

df_melt = df_preds.melt(
id_vars=hover_cols,
Expand Down
12 changes: 10 additions & 2 deletions scripts/model_figs/make_hull_dist_box_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,21 @@
import seaborn as sns
from pymatviz.io import save_fig

from matbench_discovery import PDF_FIGS, SITE_FIGS, Quantity
from matbench_discovery.preds import df_each_err, models
from matbench_discovery import PDF_FIGS, SITE_FIGS
from matbench_discovery.enums import Key, Quantity, TestSubset
from matbench_discovery.preds import df_each_err, df_preds, models

__author__ = "Janosh Riebesell"
__date__ = "2023-05-25"


test_subset = globals().get("test_subset", TestSubset.full)

if test_subset == TestSubset.uniq_protos:
df_preds = df_preds.query(Key.uniq_proto)
df_each_err = df_each_err.loc[df_preds.index]


# %%
ax = df_each_err[models].plot.box(
showfliers=False,
Expand Down
68 changes: 32 additions & 36 deletions scripts/model_figs/make_metrics_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,33 +35,33 @@
"M3GNet→MEGNet": "M3GNet",
"CHGNet→MEGNet": "CHGNet",
}
df_met = df_metrics_uniq_protos
df_met.loc[Key.train_size.label] = ""

for model in df_metrics:
model_name = name_map.get(model, model)
if not (model_data := MODEL_METADATA.get(model_name)):
continue
n_structs = model_data["training_set"]["n_structures"]
n_structs_str = si_fmt(n_structs)
train_size_str = si_fmt(n_structs)

if n_materials := model_data["training_set"].get("n_materials"):
n_structs_str += f" <small>({si_fmt(n_materials)})</small>"

for df_m in (df_metrics, df_metrics_10k, df_metrics_uniq_protos):
if Key.train_size not in df_m.index:
df_m.loc[Key.train_size.label] = ""
df_m.loc[Key.train_size.label, model] = n_structs_str
model_params = model_data.get(Key.model_params)
df_m.loc[Key.model_params.label, model] = (
si_fmt(model_params) if isinstance(model_params, int) else model_params
)
for key in (
Key.openness,
Key.model_type,
Key.train_task,
Key.test_task,
Key.targets,
):
default = {Key.openness: Open.OSOD}.get(key, pd.NA)
df_m.loc[key.label, model] = model_data.get(key, default)
train_size_str += f" <small>({si_fmt(n_materials)})</small>"

df_met.loc[Key.train_size.label, model] = train_size_str
model_params = model_data.get(Key.model_params)
df_met.loc[Key.model_params.label, model] = (
si_fmt(model_params) if isinstance(model_params, int) else model_params
)
for key in (
Key.openness,
Key.model_type,
Key.train_task,
Key.test_task,
Key.targets,
):
default = {Key.openness: Open.OSOD}.get(key, pd.NA)
df_met.loc[key.label, model] = model_data.get(key, default)


# %% add dummy classifier results to df_metrics(_10k, _uniq_protos)
Expand Down Expand Up @@ -112,18 +112,18 @@
Key.model_type.label,
Key.targets.label,
]
show_cols = [
*f"F1,DAF,Precision,Accuracy,TPR,TNR,MAE,RMSE,{R2_col}".split(","),
*meta_cols,
]
show_cols = [*f"F1,DAF,Prec,Acc,TPR,TNR,MAE,RMSE,{R2_col}".split(","), *meta_cols]

for label, df in (
("", df_metrics),
("-uniq-protos", df_metrics_uniq_protos),
("-first-10k", df_metrics_10k),
):
df_table = df.rename(index={"R2": R2_col})
df_table.index.name = "Model"
# abbreviate long column names
df = df.rename(index={"R2": R2_col, "Precision": "Prec", "Accuracy": "Acc"})
df.index.name = "Model"
# only keep columns we want to show
df_table = df.T.filter(show_cols)

if make_uip_megnet_comparison:
df_table = df_table.filter(regex="MEGNet|CHGNet|M3GNet") # |Dummy
Expand All @@ -133,37 +133,33 @@
"hint: for make_uip_megnet_comparison, uncomment the lines "
"chgnet_megnet and m3gnet_megnet in PredFiles"
)
df_filtered = df_table.T[show_cols] # only keep columns we want to show

# abbreviate long column names
df_filtered = df_filtered.rename(columns={"Precision": "Prec", "Accuracy": "Acc"})

if "-first-10k" in label:
# hide redundant metrics for first 10k preds (all TPR = 1, TNR = 0)
df_filtered = df_filtered.drop(["TPR", "TNR"], axis="columns")
df_table = df_table.drop(["TPR", "TNR"], axis="columns")
if label != "-uniq-protos": # only show training size and model type once
df_filtered = df_filtered.drop(meta_cols, axis="columns")
df_table = df_table.drop(meta_cols, axis="columns", errors="ignore")

styler = (
df_filtered.style.format(
df_table.style.format(
# render integers without decimal places
dict.fromkeys("TP FN FP TN".split(), "{:,.0f}"),
precision=2, # render floats with 2 decimals
na_rep="", # render NaNs as empty string
)
.background_gradient(
cmap="viridis", subset=list(higher_is_better & {*df_filtered})
cmap="viridis", subset=list(higher_is_better & {*df_table})
)
.background_gradient( # reverse color map if lower=better
cmap="viridis_r", subset=list(lower_is_better & {*df_filtered})
cmap="viridis_r", subset=list(lower_is_better & {*df_table})
)
)
# add up/down arrows to indicate which metrics are better when higher/lower
arrow_suffix = dict.fromkeys(higher_is_better, " ↑") | dict.fromkeys(
lower_is_better, " ↓"
)
styler.relabel_index(
[f"{col}{arrow_suffix.get(col, '')}" for col in df_filtered],
[f"{col}{arrow_suffix.get(col, '')}" for col in df_table],
axis="columns",
)

Expand Down
12 changes: 10 additions & 2 deletions scripts/model_figs/parity_energy_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from pymatviz.io import save_fig
from pymatviz.utils import add_identity_line, bin_df_cols

from matbench_discovery import PDF_FIGS, SITE_FIGS, Key
from matbench_discovery import PDF_FIGS, SITE_FIGS
from matbench_discovery.enums import Key, TestSubset
from matbench_discovery.plots import clf_colors
from matbench_discovery.preds import df_metrics, df_preds
from matbench_discovery.preds import df_metrics, df_metrics_uniq_protos, df_preds

__author__ = "Janosh Riebesell"
__date__ = "2022-11-28"
Expand All @@ -32,6 +33,13 @@
e_pred_col = Key.e_form_pred


test_subset = globals().get("test_subset", TestSubset.full)

if test_subset == TestSubset.uniq_protos:
df_preds = df_preds.query(Key.uniq_proto)
df_metrics = df_metrics_uniq_protos


# %%
facet_col = "Model"
hover_cols = (Key.each_true, Key.formula)
Expand Down
Loading
Loading