Skip to content

Commit

Permalink
add test_hist_classified_stable_as_func_of_hull_dist()
Browse files Browse the repository at this point in the history
refactor plot funcs to use e_above_hull_pred and e_above_hull_true as main inputs
 rewrite plot scripts to match new signature
  • Loading branch information
janosh committed Jun 20, 2023
1 parent 17df9d0 commit c855e51
Show file tree
Hide file tree
Showing 10 changed files with 238 additions and 193 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pretrained/

# Weights and Biases logs
wandb/
job-logs/

# slurm logs
slurm-*out
24 changes: 15 additions & 9 deletions mb_discovery/m3gnet/join_and_plot_m3gnet_relax_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import gzip
import io
import pickle
import urllib.request
from datetime import datetime
from glob import glob
from urllib.request import urlopen

import pandas as pd
from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram, PDEntry
Expand All @@ -15,7 +15,7 @@

from mb_discovery import ROOT, as_dict_handler
from mb_discovery.plot_scripts.plot_funcs import (
hist_classify_stable_as_func_of_hull_dist,
hist_classified_stable_as_func_of_hull_dist,
)


Expand All @@ -25,7 +25,7 @@
# %%
task_type = "RS3RE"
date = "2022-08-19"
glob_pattern = f"{date}-m3gnet-relax-wbm-{task_type}/*.json.gz"
glob_pattern = f"{date}-m3gnet-wbm-relax-{task_type}/*.json.gz"
file_paths = sorted(glob(f"{ROOT}/data/{glob_pattern}"))
print(f"Found {len(file_paths):,} files for {glob_pattern = }")

Expand Down Expand Up @@ -68,7 +68,7 @@
# %%
# 2022-01-25-ppd-mp+wbm.pkl.gz (235 MB)
ppd_pickle_url = "https://figshare.com/ndownloader/files/36669624"
zipped_file = urlopen(ppd_pickle_url)
zipped_file = urllib.request.urlopen(ppd_pickle_url)

ppd_mp_wbm: PatchedPhaseDiagram = pickle.load(
io.BytesIO(gzip.decompress(zipped_file.read()))
Expand Down Expand Up @@ -114,15 +114,21 @@


# %%
out_path = f"{ROOT}/data/{today}-m3gnet-relax-wbm-{task_type}.json.gz"
out_path = f"{ROOT}/data/{today}-m3gnet-wbm-relax-{task_type}.json.gz"
df_m3gnet.reset_index().to_json(out_path, default_handler=as_dict_handler)

out_path = f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
df_m3gnet = pd.read_json(out_path)


# %%
ax_hull_dist_hist = hist_classify_stable_as_func_of_hull_dist(
formation_energy_targets=df_m3gnet.e_form_ppd_2022_01_25,
formation_energy_preds=df_m3gnet.e_form_m3gnet_from_ppd,
e_above_hull_vals=df_m3gnet.e_above_mp_hull,
df_m3gnet["e_above_hull_pred"] = ( # TODO fix this incorrect e_above_hull_pred
df_m3gnet["e_form_m3gnet_from_ppd"] - df_m3gnet["e_above_mp_hull"]
)

ax_hull_dist_hist = hist_classified_stable_as_func_of_hull_dist(
e_above_hull_pred=df_m3gnet.e_above_hull_pred,
e_above_hull_true=df_m3gnet.e_above_mp_hull,
)

# ax_hull_dist_hist.figure.savefig(f"{ROOT}/plots/{today}-m3gnet-wbm-hull-dist-hist.pdf")
8 changes: 4 additions & 4 deletions mb_discovery/m3gnet/slurm_array_m3gnet_relax_wbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import m3gnet
import numpy as np
import pandas as pd
import wandb
from m3gnet.models import Relaxer

import wandb
from mb_discovery import ROOT, as_dict_handler


Expand All @@ -20,7 +20,7 @@
```sh
sbatch --partition icelake-himem --account LEE-SL3-CPU --array 1-101 \
--time 3:0:0 --job-name m3gnet-relax-wbm-RS2RE --mem 12000 \
--time 3:0:0 --job-name m3gnet-wbm-relax-RS2RE --mem 12000 \
--output mb_discovery/m3gnet/slurm_logs/slurm-%A-%a.out \
--wrap "python mb_discovery/m3gnet/slurm_array_m3gnet_relax_wbm.py"
```
Expand Down Expand Up @@ -48,7 +48,7 @@
print(f"{job_array_id=}")

today = f"{datetime.now():%Y-%m-%d}"
out_dir = f"{ROOT}/data/{today}-m3gnet-relax-wbm-{task_type}"
out_dir = f"{ROOT}/data/{today}-m3gnet-wbm-relax-{task_type}"
os.makedirs(out_dir, exist_ok=True)
json_out_path = f"{out_dir}/{job_array_id}.json.gz"

Expand Down Expand Up @@ -77,7 +77,7 @@
wandb.login()
wandb.init(
project="m3gnet", # run will be added to this project
name=f"m3gnet-relax-wbm-{task_type}-{job_id}-{job_array_id}",
name=f"m3gnet-wbm-relax-{task_type}-{job_id}-{job_array_id}",
config=run_params,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# %%
from datetime import datetime
from typing import Literal

import matplotlib.pyplot as plt
import pandas as pd

from mb_discovery import ROOT
from mb_discovery.plot_scripts.plot_funcs import (
StabilityCriterion,
WhichEnergy,
hist_classified_stable_as_func_of_hull_dist,
)

Expand Down Expand Up @@ -55,27 +56,34 @@
assert all(nan_counts == 0), f"df should not have missing values: {nan_counts}"

target_col = "e_form_target"
stability_crit: Literal["energy", "energy+std", "energy-std"] = "energy"
energy_type: Literal["true", "pred"] = "true"

stability_crit: StabilityCriterion = "energy"
which_energy: WhichEnergy = "true"

if "std" in stability_crit:
# TODO column names to compute standard deviation from are currently hardcoded
# needs to be updated when adding non-aviary models with uncertainty estimation
var_aleatoric = (df.filter(like="_ale_") ** 2).mean(axis=1)
var_epistemic = df.filter(regex=r"_pred_\d").var(axis=1, ddof=0)
std_total = (var_epistemic + var_aleatoric) ** 0.5
else:
std_total = None

# make sure we average the expected number of ensemble member predictions
pred_cols = df.filter(regex=r"_pred_\d").columns
assert len(pred_cols) == 10

ax = hist_classified_stable_as_func_of_hull_dist(
df,
target_col,
pred_cols,
e_above_hull_col="e_above_mp_hull",
energy_type=energy_type,
e_above_hull_pred=df[pred_cols].mean(axis=1) - df[target_col],
e_above_hull_true=df.e_above_mp_hull,
which_energy=which_energy,
stability_crit=stability_crit,
std_pred=std_total,
)

ax.figure.set_size_inches(10, 9)

ax.legend(loc="upper left", frameon=False)

fig_name = f"wren-wbm-hull-dist-hist-{energy_type=}-{stability_crit=}"
fig_name = f"wren-wbm-hull-dist-hist-{which_energy=}-{stability_crit=}"
img_path = f"{ROOT}/figures/{today}-{fig_name}.pdf"
# plt.savefig(img_path)
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from mb_discovery import ROOT
from mb_discovery.plot_scripts.plot_funcs import (
StabilityCriterion,
WhichEnergy,
hist_classified_stable_as_func_of_hull_dist,
)

Expand Down Expand Up @@ -50,34 +52,39 @@


# %%
energy_type = "true"
stability_crit = "energy"
which_energy: WhichEnergy = "true"
stability_crit: StabilityCriterion = "energy"
df["wbm_batch"] = df.index.str.split("-").str[2]
fig, axs = plt.subplots(2, 3, figsize=(18, 9))

# make sure we average the expected number of ensemble member predictions
pred_cols = df.filter(regex=r"_pred_\d").columns
assert len(pred_cols) == 10

common_kwargs = dict(
target_col="e_form_target",
pred_cols=pred_cols,
energy_type=energy_type,
stability_crit=stability_crit,
e_above_hull_col="e_above_mp_hull",
)

for (batch_idx, batch_df), ax in zip(df.groupby("wbm_batch"), axs.flat):
hist_classified_stable_as_func_of_hull_dist(batch_df, ax=ax, **common_kwargs)
hist_classified_stable_as_func_of_hull_dist(
e_above_hull_pred=batch_df[pred_cols].mean(axis=1) - batch_df.e_form_target,
e_above_hull_true=batch_df.e_above_mp_hull,
which_energy=which_energy,
stability_crit=stability_crit,
ax=ax,
)

title = f"Batch {batch_idx} ({len(df):,})"
ax.set(title=title)


hist_classified_stable_as_func_of_hull_dist(df, ax=axs.flat[-1], **common_kwargs)
hist_classified_stable_as_func_of_hull_dist(
e_above_hull_pred=df[pred_cols].mean(axis=1),
e_above_hull_true=df.e_above_mp_hull,
which_energy=which_energy,
stability_crit=stability_crit,
ax=axs.flat[-1],
)

axs.flat[-1].set(title=f"Combined {batch_idx} ({len(df):,})")
axs.flat[0].legend(frameon=False, loc="upper left")

img_name = f"{today}-wren-wbm-hull-dist-hist-{energy_type=}-{stability_crit=}.pdf"
img_name = f"{today}-wren-wbm-hull-dist-hist-{which_energy=}-{stability_crit=}.pdf"
# plt.savefig(f"{ROOT}/figures/{img_name}")
Loading

0 comments on commit c855e51

Please sign in to comment.