Skip to content

Commit

Permalink
add models/bowsr/join_bowsr_results.py
Browse files Browse the repository at this point in the history
and fix file paths to models/m3gnet/*-results.json.gz
  • Loading branch information
janosh committed Jun 20, 2023
1 parent e6bf955 commit e7f4582
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
dfs[model_name] = df

dfs["M3GNet"] = pd.read_json(
f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
).set_index("material_id")

dfs["Wrenformer"] = pd.read_csv(
Expand Down
92 changes: 92 additions & 0 deletions models/bowsr/join_bowsr_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# %%
from __future__ import annotations

import os
from datetime import datetime
from glob import glob

import pandas as pd
from pymatgen.core import Structure
from tqdm import tqdm

from mb_discovery import ROOT, as_dict_handler
from mb_discovery.plots import hist_classified_stable_as_func_of_hull_dist

__author__ = "Janosh Riebesell"
__date__ = "2022-09-22"

today = f"{datetime.now():%Y-%m-%d}"


# %%
module_dir = os.path.dirname(__file__)
task_type = "IS2RE"
date = "2022-09-22"
glob_pattern = f"{date}-bowsr-wbm-{task_type}/*.json.gz"
file_paths = sorted(glob(f"{module_dir}/{glob_pattern}"))
print(f"Found {len(file_paths):,} files for {glob_pattern = }")

dfs: dict[str, pd.DataFrame] = {}


# %%
# 2022-08-16 tried multiprocessing.Pool() to load files in parallel but was somehow
# slower than serial loading
for file_path in tqdm(file_paths):
if file_path in dfs:
continue
# keep whole dataframe in memory
df = pd.read_json(file_path).set_index("material_id")
col_map = dict(
structure_pred="structure_bowsr",
energy_pred="energy_bowsr",
e_form_per_atom_pred="e_form_per_atom_bowsr",
)
df = df.rename(columns=col_map)
df["structure_bowsr"] = df.structure_bowsr.map(Structure.from_dict)
df["formula"] = df.structure_bowsr.map(lambda x: x.formula)
df["volume"] = df.structure_bowsr.map(lambda x: x.volume)
df["n_sites"] = df.structure_bowsr.map(len)
dfs[file_path] = df


# %%
df_bowsr = pd.concat(dfs.values())


# %%
df_wbm = pd.read_csv( # download wbm-steps-summary.csv (23.31 MB)
"https://figshare.com/files/37542841?private_link=ff0ad14505f9624f0c05"
).set_index("material_id")

df_bowsr["e_form_wbm"] = df_wbm.e_form_per_atom


# %%
df_bowsr.hist(bins=200, figsize=(18, 12))
df_bowsr.isna().sum()


# %%
out_path = f"{ROOT}/models/bowsr/{today}-bowsr-wbm-{task_type}.json.gz"
df_bowsr.reset_index().to_json(out_path, default_handler=as_dict_handler)

out_path = f"{ROOT}/models/bowsr/2022-08-16-bowsr-wbm-IS2RE.json.gz"
df_bowsr = pd.read_json(out_path).set_index("material_id")


# %%
df_hull = pd.read_csv(
f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
).set_index("material_id")
df_bowsr["e_above_mp_hull"] = df_hull.e_above_mp_hull
df_bowsr["e_above_hull_pred"] = ( # TODO fix this incorrect e_above_hull_pred
df_bowsr["e_form_per_atom_bowsr"] - df_bowsr["e_above_mp_hull"]
)

ax_hull_dist_hist = hist_classified_stable_as_func_of_hull_dist(
e_above_hull_pred=df_bowsr.e_above_hull_pred,
e_above_hull_true=df_bowsr.e_above_mp_hull,
)

# ax_hull_dist_hist.figure.savefig(f"{ROOT}/plots/{today}-bowsr-wbm-hull-dist-hist.pdf")
8 changes: 4 additions & 4 deletions models/bowsr/slurm_array_bowsr_wbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,12 @@
with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
bayes_optimizer.optimize(**optimize_kwargs)

structure_pred, energy_pred = bayes_optimizer.get_optimized_structure_and_energy()
structure_bowsr, energy_bowsr = bayes_optimizer.get_optimized_structure_and_energy()

results = dict(
e_form_per_atom_pred=model.predict_energy(structure),
structure_pred=structure_pred,
energy_pred=energy_pred,
e_form_per_atom_bowsr=model.predict_energy(structure),
structure_bowsr=structure_bowsr,
energy_bowsr=energy_bowsr,
)

relax_results[material_id] = results
Expand Down
6 changes: 3 additions & 3 deletions models/m3gnet/eda_wbm_pre_vs_post_m3gnet_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@

# %%
df_m3gnet_is2re = pd.read_json(
f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
).set_index("material_id")
df_m3gnet_rs2re = pd.read_json(
f"{ROOT}/data/2022-08-19-m3gnet-wbm-relax-results-RS3RE.json.gz"
f"{ROOT}/models/m3gnet/2022-08-19-m3gnet-wbm-relax-results-RS2RE.json.gz"
).set_index("material_id")


Expand Down Expand Up @@ -226,5 +226,5 @@
# %% write df back to compressed JSON
# filter out columns containing 'rs2re'
# df_m3gnet_is2re.reset_index().filter(regex="^((?!rs2re).)*$").to_json(
# f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE-2.json.gz"
# f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-relax-results-IS2RE-2.json.gz"
# ).set_index("material_id")
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import gzip
import io
import os
import pickle
import urllib.request
from datetime import datetime
Expand All @@ -16,14 +17,18 @@
from mb_discovery import ROOT, as_dict_handler
from mb_discovery.plots import hist_classified_stable_as_func_of_hull_dist

__author__ = "Janosh Riebesell"
__date__ = "2022-08-16"

today = f"{datetime.now():%Y-%m-%d}"


# %%
task_type = "RS3RE"
module_dir = os.path.dirname(__file__)
task_type = "RS2RE"
date = "2022-08-19"
glob_pattern = f"{date}-m3gnet-wbm-relax-{task_type}/*.json.gz"
file_paths = sorted(glob(f"{ROOT}/data/{glob_pattern}"))
file_paths = sorted(glob(f"{module_dir}/{glob_pattern}"))
print(f"Found {len(file_paths):,} files for {glob_pattern = }")

dfs: dict[str, pd.DataFrame] = {}
Expand Down Expand Up @@ -93,7 +98,7 @@
"https://figshare.com/files/37542841?private_link=ff0ad14505f9624f0c05"
).set_index("material_id")

df_m3gnet["e_form_wbm"] = df_wbm.e_form
df_m3gnet["e_form_wbm"] = df_wbm.e_form_per_atom
df_m3gnet["wbm_energy"] = df_wbm.energy

pd_entries_wbm = [
Expand All @@ -111,11 +116,11 @@


# %%
out_path = f"{ROOT}/data/{today}-m3gnet-wbm-relax-{task_type}.json.gz"
out_path = f"{ROOT}/models/m3gnet/{today}-m3gnet-wbm-relax-{task_type}.json.gz"
df_m3gnet.reset_index().to_json(out_path, default_handler=as_dict_handler)

out_path = f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
df_m3gnet = pd.read_json(out_path)
out_path = f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
df_m3gnet = pd.read_json(out_path).set_index("material_id")


# %%
Expand Down

0 comments on commit e7f4582

Please sign in to comment.