Skip to content

Commit

Permalink
add toggle to switch between formation energy and convex hull distanc…
Browse files Browse the repository at this point in the history
…e in scripts/scatter_energy_models.py

use log2-spaced sampling in cumulative metrics plots to achieve higher fidelity at equal file size in initial part of the discovery campaign
add site/src/figs/e-form-scatter-models-5x2.svelte
  • Loading branch information
janosh committed Nov 29, 2023
1 parent 13cbb90 commit 82d07f1
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 53 deletions.
6 changes: 3 additions & 3 deletions matbench_discovery/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ def cumulative_metrics(
optimal_recall: str | None = "Optimal Recall",
show_n_stable: bool = True,
backend: Backend = "plotly",
n_points: int = 50,
n_points: int = 100,
**kwargs: Any,
) -> tuple[plt.Figure | go.Figure, pd.DataFrame]:
"""Create 2 subplots side-by-side with cumulative precision and recall curves for
Expand Down Expand Up @@ -595,7 +595,7 @@ def cumulative_metrics(
backend ('matplotlib' | 'plotly'], optional): Which plotting engine to use.
Changes the return type. Defaults to 'plotly'.
n_points (int, optional): Number of points to use for interpolation of the
metric curves. Defaults to 80.
metric curves. Defaults to 100.
**kwargs: Keyword arguments passed to df.plot().
Returns:
Expand All @@ -606,7 +606,7 @@ def cumulative_metrics(

# largest number of materials predicted stable by any model, determines x-axis range
n_max_pred_stable = (df_preds < stability_threshold).sum().max()
longest_xs = np.linspace(0, n_max_pred_stable - 1, n_points)
longest_xs = np.logspace(0, np.log2(n_max_pred_stable - 1), n_points, base=2)
for metric in metrics:
dfs[metric].index = longest_xs

Expand Down
6 changes: 2 additions & 4 deletions scripts/model_figs/cumulative_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,9 @@
fig.layout[key].range = range_y

fig.layout.margin.update(l=0, r=0, t=30, b=50)
# use annotation for x-axis label
fig.add_annotation(
x=0.5,
y=-0.15,
xref="paper",
yref="paper",
**dict(x=0.5, y=-0.15, xref="paper", yref="paper"),
text=x_label,
showarrow=False,
font=dict(size=14),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# %%
import math
from typing import Literal

import numpy as np
import plotly.express as px
Expand All @@ -28,6 +29,15 @@
e_form_pred_col = "e_form_per_atom_pred"
legend = dict(x=1, y=0, xanchor="right", yanchor="bottom", title=None)

# toggle between formation energy and energy above convex hull
which_energy: Literal["e-form", "each"] = "each"
if which_energy == "each":
e_pred_col = each_pred_col
e_true_col = each_true_col
if which_energy == "e-form":
e_true_col = e_form_col
e_pred_col = e_form_pred_col


# %%
facet_col = "Model"
Expand All @@ -45,9 +55,10 @@
df_melt[each_pred_col] = (
df_melt[each_true_col] + df_melt[e_form_pred_col] - df_melt[e_form_col]
)

df_bin = bin_df_cols(
df_melt,
bin_by_cols=[each_true_col, each_pred_col],
bin_by_cols=[e_true_col, e_pred_col],
group_by_cols=[facet_col],
n_bins=200,
bin_counts_col=(bin_cnt_col := "bin counts"),
Expand All @@ -61,7 +72,7 @@
# determine each point's classification to color them by
# now unused, can be used to color points by TP/FP/TN/FN
# true_pos, false_neg, false_pos, true_neg = classify_stable(
# df_bin[each_true_col], df_bin[each_pred_col]
# df_bin[e_true_col], df_bin[e_pred_col]
# )
# clf_col = "classified"
# df_bin[clf_col] = np.array(clf_labels)[
Expand Down Expand Up @@ -100,8 +111,8 @@
# %% scatter plot of actual vs predicted e_above_hull
fig = px.scatter(
df_bin,
x=each_true_col,
y=each_pred_col,
x=e_true_col,
y=e_pred_col,
color=facet_col,
hover_data=hover_cols,
hover_name=df_preds.index.name,
Expand Down Expand Up @@ -133,8 +144,8 @@

fig = px.scatter(
df_bin,
x=each_true_col,
y=each_pred_col,
x=e_true_col,
y=e_pred_col,
facet_col=facet_col,
facet_col_wrap=n_cols,
color=log_bin_cnt_col,
Expand All @@ -145,7 +156,7 @@
# color=clf_col,
# color_discrete_map=clf_color_map,
# opacity=0.4,
range_x=(domain := (-4, 7)),
range_x=(domain := (-4, 7) if which_energy == "each" else (None, None)),
range_y=domain,
category_orders={facet_col: legend_order},
# pick from https://plotly.com/python/builtin-colorscales
Expand Down Expand Up @@ -179,41 +190,52 @@
fig.layout[f"yaxis{idx}"].title.text = ""

# add transparent rectangle with TN, TP, FN, FP labels in each quadrant
for sign_x, sign_y, color, label in zip(
[-1, -1, 1, 1], [-1, 1, -1, 1], clf_colors, ("TP", "FN", "FP", "TN")
):
# instead of coloring points in each quadrant, we can add a transparent
# background to each quadrant (looks worse maybe than coloring points)
# fig.add_shape(
# type="rect",
# x0=0,
# y0=0,
# x1=sign_x * 100,
# y1=sign_y * 100,
# fillcolor=color,
# opacity=0.2,
# layer="below",
# row="all",
# col="all",
# )
fig.add_annotation(
x=(domain[0] if sign_x < 0 else domain[1]),
y=(domain[0] if sign_y < 0 else domain[1]),
xshift=-20 * sign_x,
yshift=-20 * sign_y,
text=label,
showarrow=False,
font=dict(size=16, color=color),
row="all",
col="all",
)

# add dashed quadrant separators
fig.add_vline(x=0, line=dict(width=0.5, dash="dash"))
fig.add_hline(y=0, line=dict(width=0.5, dash="dash"))

fig.update_xaxes(nticks=5)
fig.update_yaxes(nticks=5)
if e_true_col == each_true_col:
# add dashed quadrant separators
fig.add_vline(x=0, line=dict(width=0.5, dash="dash"))
fig.add_hline(y=0, line=dict(width=0.5, dash="dash"))

for sign_x, sign_y, label, color in (
(-1, -1, "TP", clf_colors[0]),
(-1, 1, "FN", clf_colors[1]),
(1, -1, "FP", clf_colors[2]),
(1, 1, "TN", clf_colors[3]),
):
# instead of coloring points in each quadrant, we can add a transparent
# background to each quadrant (looks worse maybe than coloring points)
# fig.add_shape(
# type="rect",
# x0=0,
# y0=0,
# x1=sign_x * 100,
# y1=sign_y * 100,
# fillcolor=color,
# opacity=0.2,
# layer="below",
# row="all",
# col="all",
# )
fig.add_annotation(
x=(domain[0] if sign_x < 0 else domain[1]),
y=(domain[0] if sign_y < 0 else domain[1]),
xshift=-20 * sign_x,
yshift=-15 * sign_y,
text=label,
showarrow=False,
font=dict(size=16, color=color),
row="all",
col="all",
)

# enable grid
fig.update_layout(
xaxis=dict(showgrid=True),
yaxis=dict(showgrid=True),
)

fig.update_xaxes(nticks=8)
fig.update_yaxes(nticks=8)
add_identity_line(fig)

# remove legend title and place legend centered above subplots, increase marker size
fig.layout.legend.update(
Expand All @@ -236,6 +258,7 @@
textangle=-90,
**axis_titles,
)

fig.layout.height = 230 * n_rows
fig.layout.coloraxis.colorbar.update(orientation="h", thickness=9, len=0.5, y=1.05)
# fig.layout.width = 1100
Expand All @@ -246,6 +269,6 @@


# %%
fig_name = f"each-scatter-models-{n_rows}x{n_cols}"
fig_name = f"{which_energy}-scatter-models-{n_rows}x{n_cols}"
save_fig(fig, f"{SITE_FIGS}/{fig_name}.svelte")
save_fig(fig, f"{PDF_FIGS}/{fig_name}.pdf")
2 changes: 1 addition & 1 deletion site/src/figs/cumulative-mae.svelte

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion site/src/figs/cumulative-precision-recall.svelte

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions site/src/figs/e-form-scatter-models-5x2.svelte

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion site/src/figs/each-scatter-models-5x2.svelte

Large diffs are not rendered by default.

0 comments on commit 82d07f1

Please sign in to comment.