Skip to content

Commit

Permalink
use bin counts directly (no KDE) in hull dist density scatter plot
Browse files Browse the repository at this point in the history
add rolling-mae-vs-hull-dist-wbm-batches-{alignn,bowsr,mace,voronoi-rf}.svelte
fix hull-dist-scatter-wrenformer-failures.pdf (used wrong input dataframe df_preds vs df_each_pred)
add ref wang_framework_2021 (MP2020 correction scheme)
  • Loading branch information
janosh committed Aug 28, 2023
1 parent d9bb043 commit 5df80ef
Show file tree
Hide file tree
Showing 21 changed files with 171 additions and 90 deletions.
2 changes: 1 addition & 1 deletion data/wbm/eda.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
save_fig(ptable, f"{PDF_FIGS}/{dataset}-element-{count_mode}-counts.pdf")


# %% histogram of energy above MP convex hull for WBM
# %% histogram of energy distance to MP convex hull for WBM
col = each_true_col # or e_form_col
mean, std = df_wbm[col].mean(), df_wbm[col].std()

Expand Down
16 changes: 10 additions & 6 deletions matbench_discovery/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,12 +521,14 @@ def rolling_mae_vs_hull_dist(
fig.set(xlim=x_lim, ylim=y_lim)
line_styles = "- -- -. :".split()
markers = "o s ^ v D * p X".split()
combinations = [(ls, mark) for mark in markers for ls in line_styles]

for idx, line in enumerate(fig.lines):
ls, marker = combinations[idx % len(combinations)]
line_label = line.get_label()
if line_label.startswith("_"):
continue
ls, marker = line_styles[idx], markers[idx]
line.set(ls=ls, marker=marker, markeredgewidth=0.5, markeredgecolor="black")
line.set_markevery(4)
line.set_markevery(8)

elif backend == "plotly":
for idx, model in enumerate(df_rolling_err if with_sem else []):
Expand Down Expand Up @@ -614,14 +616,16 @@ def rolling_mae_vs_hull_dist(
)
fig.add_shape(type="rect", x0=x0, y0=y0, x1=x0 - window, y1=y0 + window / 5)

line_styles = "solid dash dot dashdot".split()
markers = "circle square triangle-up triangle-down diamond cross star x".split()
from matbench_discovery.preds import model_styles

for trace in fig.data:
for idx, trace in enumerate(fig.data):
if style := model_styles.get(trace.name):
ls, _marker, color = style
trace.line = dict(color=color, dash=ls, width=2)
else:
trace.line = dict(
color=plotly_colors[idx], dash=plotly_line_styles[idx], width=3
)
# marker_spacing = 2
# trace = go.Scatter(
# x=trace.x[::marker_spacing],
Expand Down
54 changes: 51 additions & 3 deletions models/wrenformer/analyze_wrenformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@


# %%
import numpy as np
import pandas as pd
from aviary.wren.utils import get_isopointal_proto_from_aflow
from pymatviz import spacegroup_hist, spacegroup_sunburst
from pymatviz.ptable import ptable_heatmap_plotly
from pymatviz.utils import save_fig
from pymatviz.utils import add_identity_line, bin_df_cols, save_fig

from matbench_discovery import PDF_FIGS, SITE_FIGS
from matbench_discovery.data import DATA_FILES, df_wbm
Expand Down Expand Up @@ -80,10 +81,14 @@

# %%
fig = spacegroup_sunburst(df_bad[spg_col], width=350, height=350)
fig.layout.title.update(text=f"Spacegroup sunburst for {title}", x=0.5, font_size=14)
# fig.layout.title.update(text=f"Spacegroup sunburst for {title}", x=0.5, font_size=14)
fig.layout.margin.update(l=1, r=1, t=1, b=1)
fig.show()


# %%
save_fig(fig, f"{PDF_FIGS}/spacegroup-sunburst-{model.lower()}-failures.pdf")
# save_fig(fig, f"{FIGS}/spacegroup-sunburst-{model}-failures.svelte")
save_fig(fig, f"{SITE_FIGS}/spacegroup-sunburst-{model}-failures.svelte")


# %%
Expand All @@ -92,3 +97,46 @@
fig.layout.margin = dict(l=0, r=0, t=50, b=0)
fig.show()
save_fig(fig, f"{PDF_FIGS}/elements-{model.lower()}-failures.pdf")


# %%
model = "Wrenformer"
cols = [model, each_true_col]
bin_cnt_col = "bin counts"
df_bin = bin_df_cols(
df_each_pred, [each_true_col, model], n_bins=200, bin_counts_col=bin_cnt_col
)
log_cnt_col = f"log {bin_cnt_col}"
df_bin[log_cnt_col] = np.log1p(df_bin[bin_cnt_col]).round(2)


# %%
fig = df_bin.reset_index().plot.scatter(
x=each_true_col,
y=model,
hover_data=cols,
hover_name=df_preds.index.name,
backend="plotly",
color=log_cnt_col,
color_continuous_scale="turbo",
)

# title = "Analysis of Wrenformer failure cases in the highlighted rectangle"
# fig.layout.title.update(text=title, x=0.5)
fig.layout.margin.update(l=0, r=0, t=0, b=0)
fig.layout.legend.update(title="", x=1, y=0, xanchor="right")
add_identity_line(fig)
fig.layout.coloraxis.colorbar.update(
x=1, y=0.5, xanchor="right", thickness=12, title=""
)
# add shape shaded rectangle at x < 1, y > 1
fig.add_shape(
type="rect", **dict(x0=1, y0=1, x1=-1, y1=6), fillcolor="gray", opacity=0.2
)
fig.show()


# %%
img_name = "hull-dist-scatter-wrenformer-failures"
save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=600, height=300)
77 changes: 19 additions & 58 deletions scripts/model_figs/scatter_hull_dist_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,10 @@

import numpy as np
import plotly.express as px
import scipy.stats
from pymatviz.utils import add_identity_line, bin_df_cols, save_fig
from tqdm import tqdm

from matbench_discovery import PDF_FIGS, SITE_FIGS
from matbench_discovery.metrics import classify_stable
from matbench_discovery.plots import clf_color_map, clf_colors, clf_labels
from matbench_discovery.plots import clf_colors
from matbench_discovery.preds import (
df_metrics,
df_preds,
Expand Down Expand Up @@ -47,22 +44,28 @@
df_melt[each_pred_col] = (
df_melt[each_true_col] + df_melt[e_form_pred_col] - df_melt[e_form_col]
)
df_bin = bin_df_cols(df_melt, [each_true_col, each_pred_col], [facet_col], n_bins=200)
df_bin = bin_df_cols(
df_melt,
bin_by_cols=[each_true_col, each_pred_col],
group_by_cols=[facet_col],
n_bins=200,
bin_counts_col=(bin_cnt_col := "bin counts"),
)
df_bin = df_bin.reset_index()

# sort legend and facet plots by MAE
legend_order = list(df_metrics.T.MAE.sort_values().index)


# determine each point's classification to color them by
true_pos, false_neg, false_pos, true_neg = classify_stable(
df_bin[each_true_col], df_bin[each_pred_col]
)

clf_col = "classified"
df_bin[clf_col] = np.array(clf_labels)[
true_pos * 0 + false_neg * 1 + false_pos * 2 + true_neg * 3
]
# now unused, can be used to color points by TP/FP/TN/FN
# true_pos, false_neg, false_pos, true_neg = classify_stable(
# df_bin[each_true_col], df_bin[each_pred_col]
# )
# clf_col = "classified"
# df_bin[clf_col] = np.array(clf_labels)[
# true_pos * 0 + false_neg * 1 + false_pos * 2 + true_neg * 3
# ]


# %% scatter plot of actual vs predicted e_form_per_atom
Expand Down Expand Up @@ -121,21 +124,8 @@


# %%
clr_col, cnt_col = "density", "counts"
# compute KDE for each model's predictions separately
for model in (pbar := tqdm(models)):
pbar.set_description(f"KDE for {model=}")

xy = df_preds[[each_true_col, model]].dropna().T
model_kde = scipy.stats.gaussian_kde(xy)

model_rows = df_bin[df_bin[facet_col] == model]
xy_binned = model_rows[[each_true_col, each_pred_col]].T
density = model_kde(xy_binned)
n_preds = len(df_preds[model].dropna())
df_bin.loc[model_rows.index, cnt_col] = density / density.sum() * n_preds

df_bin[clr_col] = np.log1p(df_bin[cnt_col]).round(2)
log_bin_cnt_col = f"log {bin_cnt_col}"
df_bin[log_bin_cnt_col] = np.log1p(df_bin[bin_cnt_col]).round(2)


# %% scatter plot of DFT vs predicted hull distance with each model in separate subplot
Expand All @@ -148,7 +138,7 @@
y=each_pred_col,
facet_col=facet_col,
facet_col_wrap=n_cols,
color=clr_col,
color=log_bin_cnt_col,
facet_col_spacing=0.02,
facet_row_spacing=0.04,
hover_data=hover_cols,
Expand Down Expand Up @@ -259,32 +249,3 @@
fig_name = f"each-scatter-models-{n_rows}x{n_cols}"
save_fig(fig, f"{SITE_FIGS}/{fig_name}.svelte")
save_fig(fig, f"{PDF_FIGS}/{fig_name}.pdf")


# %%
model = "Wrenformer"
fig = px.scatter(
df_bin.query(f"{facet_col} == {model!r}"),
x=each_true_col,
y=each_pred_col,
hover_data=hover_cols,
color=clf_col,
color_discrete_map=clf_color_map,
hover_name=df_preds.index.name,
opacity=0.7,
)

title = "Analysis of Wrenformer failure cases in the highlighted rectangle"
fig.layout.title.update(text=title, x=0.5)
fig.layout.legend.update(title="", x=1, y=0, xanchor="right")
add_identity_line(fig)

# add shape shaded rectangle at x < 1, y > 1
fig.add_shape(
type="rect", **dict(x0=1, y0=1, x1=-1, y1=6), fillcolor="gray", opacity=0.2
)
fig.show()

img_name = "hull-dist-scatter-wrenformer-failures"
# save_fig(fig, f"{FIGS}/{img_name}.svelte")
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf")
20 changes: 14 additions & 6 deletions scripts/rolling_mae_vs_hull_dist_wbm_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@

from matbench_discovery import PDF_FIGS, SITE_FIGS, today
from matbench_discovery.plots import plt, rolling_mae_vs_hull_dist
from matbench_discovery.preds import df_each_pred, df_preds, e_form_col, each_true_col
from matbench_discovery.preds import (
df_each_pred,
df_preds,
e_form_col,
each_true_col,
models,
)

__author__ = "Rhys Goodall, Janosh Riebesell"
__date__ = "2022-06-18"
Expand Down Expand Up @@ -39,7 +45,7 @@
markevery=20,
markerfacecolor="white",
markeredgewidth=2.5,
backend="matplotlib",
backend="matplotlib", # don't change, code here not plotly compatible
ax=ax,
just_plot_lines=idx > 1,
pbar=False,
Expand All @@ -54,7 +60,7 @@


# %% plotly
for model in list(df_each_pred)[:-2]:
for model in models:
df_pivot = df_each_pred.pivot(columns=batch_col, values=model)

fig, df_err, df_std = rolling_mae_vs_hull_dist(
Expand All @@ -66,9 +72,11 @@
show_dummy_mae=False,
with_sem=False,
)
fig.layout.legend.update(title=f"<b>{model}</b>", x=0.02, y=0.02)
fig.layout.legend.update(
title=f"<b>{model}</b>", x=0.02, y=0.02, bgcolor="rgba(0,0,0,0)"
)
fig.layout.margin.update(l=10, r=10, b=10, t=10)
fig.update_layout(hovermode="x unified", hoverlabel_bgcolor="black")
fig.layout.update(hovermode="x unified", hoverlabel_bgcolor="black")
fig.update_traces(
hovertemplate="y=%{y:.3f} eV",
selector=lambda trace: trace.name.startswith("Batch"),
Expand All @@ -78,4 +86,4 @@
model_snake_case = model.lower().replace(" + ", "-").replace(" ", "-")
img_path = f"rolling-mae-vs-hull-dist-wbm-batches-{model_snake_case}"
save_fig(fig, f"{SITE_FIGS}/{img_path}.svelte")
save_fig(fig, f"{PDF_FIGS}/{img_path}.pdf")
save_fig(fig, f"{PDF_FIGS}/{img_path}.pdf", width=500, height=330)
2 changes: 1 addition & 1 deletion site/src/figs/each-scatter-models-5x2.svelte

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion site/src/figs/hull-dist-scatter-wrenformer-failures.svelte

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

56 changes: 56 additions & 0 deletions site/src/routes/preprint/references.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2648,6 +2648,62 @@ references:
URL: https://www.nature.com/articles/s41467-020-18556-9
volume: '11'

- id: wang_framework_2021
abstract: >-
In this work, we demonstrate a method to quantify uncertainty in corrections
to density functional theory (DFT) energies based on empirical results. Such
corrections are commonly used to improve the accuracy of computational
enthalpies of formation, phase stability predictions, and other
energy-derived properties, for example. We incorporate this method into a
new DFT energy correction scheme comprising a mixture of oxidation-state and
composition-dependent corrections and show that many chemical systems
contain unstable polymorphs that may actually be predicted stable when
uncertainty is taken into account. We then illustrate how these
uncertainties can be used to estimate the probability that a compound is
stable on a compositional phase diagram, thus enabling better-informed
assessments of compound stability.
accessed:
- year: 2023
month: 8
day: 28
author:
- family: Wang
given: Amanda
- family: Kingsbury
given: Ryan
- family: McDermott
given: Matthew
- family: Horton
given: Matthew
- family: Jain
given: Anubhav
- family: Ong
given: Shyue Ping
- family: Dwaraknath
given: Shyam
- family: Persson
given: Kristin A.
citation-key: wang_framework_2021
container-title: Scientific Reports
container-title-short: Sci Rep
DOI: 10.1038/s41598-021-94550-5
ISSN: 2045-2322
issue: '1'
issued:
- year: 2021
month: 7
day: 29
language: en
license: 2021 The Author(s)
number: '1'
page: '15496'
publisher: Nature Publishing Group
source: www.nature.com
title: A framework for quantifying uncertainty in DFT energy corrections
type: article-journal
URL: https://www.nature.com/articles/s41598-021-94550-5
volume: '11'

- id: wang_predicting_2021
abstract: >-
We propose an efficient high-throughput scheme for the discovery of stable
Expand Down
12 changes: 6 additions & 6 deletions site/src/routes/si/+page.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ The figures below show the rolling MAE as a function of distance to the convex h
{#if mounted}

<div style="display: grid; grid-template-columns: 1fr 1fr; margin: 0 -1em 0 -4em;">
<M3gnetRollingMaeBatches style="margin: -2em 0 0; height: 400px;" />
<CHGNetRollingMaeBatches style="margin: -2em 0 0; height: 400px;" />
<WrenformerRollingMaeBatches style="margin: -2em 0 0; height: 400px;" />
<MegnetRollingMaeBatches style="margin: -2em 0 0; height: 400px;" />
<VoronoiRfRollingMaeBatches style="margin: -2em 0 0; height: 400px;" />
<CgcnnRollingMaeBatches style="margin: -2em 0 0; height: 400px;" />
<M3gnetRollingMaeBatches style="aspect-ratio: 1.2;" />
<CHGNetRollingMaeBatches style="aspect-ratio: 1.2;" />
<WrenformerRollingMaeBatches style="aspect-ratio: 1.2;" />
<MegnetRollingMaeBatches style="aspect-ratio: 1.2;" />
<VoronoiRfRollingMaeBatches style="aspect-ratio: 1.2;" />
<CgcnnRollingMaeBatches style="aspect-ratio: 1.2;" />
</div>
{/if}

Expand Down

0 comments on commit 5df80ef

Please sign in to comment.