-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add ml_stability/m3gnet/slurm_array_m3gnet_relax_wbm.py
mv ml_stability/m3gnet/{m3gnet_relax_wbm->join_and_plot_m3gnet_relax_results}.py
- Loading branch information
Showing
4 changed files
with
223 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
105 changes: 105 additions & 0 deletions
105
ml_stability/m3gnet/join_and_plot_m3gnet_relax_results.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# %% | ||
from __future__ import annotations | ||
|
||
import gzip | ||
import io | ||
import pickle | ||
from datetime import datetime | ||
from glob import glob | ||
from urllib.request import urlopen | ||
|
||
import pandas as pd | ||
from diel_frontier.patched_phase_diagram import load_ppd | ||
from pymatgen.analysis.phase_diagram import PDEntry | ||
from pymatgen.core import Structure | ||
from tqdm import tqdm | ||
|
||
from ml_stability import ROOT | ||
from ml_stability.plots.plot_funcs import hist_classify_stable_as_func_of_hull_dist | ||
|
||
|
||
today = f"{datetime.now():%Y-%m-%d}" | ||
|
||
|
||
# %% | ||
glob_pattern = "2022-08-16-m3gnet-wbm-relax-results/*.json.gz" | ||
file_paths = glob(f"{ROOT}/data/{glob_pattern}") | ||
print(f"Found {len(file_paths):,} files for {glob_pattern = }") | ||
|
||
|
||
dfs: dict[str, pd.DataFrame] = {} | ||
# 2022-08-16 tried multiprocessing.Pool() to load files in parallel but was somehow | ||
# slower than serial loading | ||
for file_path in tqdm(file_paths): | ||
if file_path in dfs: | ||
continue | ||
try: | ||
dfs[file_path] = pd.read_json(file_path) | ||
except (ValueError, FileNotFoundError): | ||
# pandas v1.5+ correctly raises FileNotFoundError, below raises ValueError | ||
continue | ||
|
||
|
||
# %% | ||
df_m3gnet = pd.concat(dfs.values()) | ||
df_m3gnet.index.name = "material_id" | ||
if any(df_m3gnet.index.str.contains("_")): | ||
df_m3gnet.index = df_m3gnet.index.str.replace("_", "-") | ||
|
||
df_m3gnet = df_m3gnet.rename( | ||
columns=dict(final_structure="m3gnet_structure", trajectory="m3gnet_trajectory") | ||
) | ||
|
||
df_m3gnet["m3gnet_energy"] = df_m3gnet.trajectory.map(lambda x: x["energies"][-1][0]) | ||
|
||
|
||
# %% | ||
out_file = f"{today}-m3gnet-wbm-relax-results.json.gz" | ||
df_m3gnet.reset_index().to_json(f"{ROOT}/data/{out_file}") | ||
|
||
|
||
# %% | ||
# 2022-01-25-ppd-mp+wbm.pkl.gz (235 MB) | ||
ppd_pickle_url = "https://figshare.com/ndownloader/files/36669624" | ||
zipped_file = urlopen(ppd_pickle_url) | ||
ppd_mp_wbm = pickle.load(io.BytesIO(gzip.decompress(zipped_file.read()))) | ||
|
||
ppd_mp_wbm = load_ppd("ppd-mp+wbm-2022-01-25.pkl.gz") | ||
|
||
|
||
df_m3gnet["m3gnet_structure"] = df_m3gnet.m3gnet_structure.map(Structure.from_dict) | ||
df_m3gnet["pd_entry"] = [ | ||
PDEntry(row.m3gnet_structure.composition, row.m3gnet_energy) | ||
for row in df_m3gnet.itertuples() | ||
] | ||
df_m3gnet["e_form_m3gnet"] = df_m3gnet.pd_entry.map(ppd_mp_wbm.get_form_energy_per_atom) | ||
|
||
|
||
# %% | ||
df_hull = pd.read_csv( | ||
f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv" | ||
).set_index("material_id") | ||
|
||
df_m3gnet["e_above_mp_hull"] = df_hull.e_above_mp_hull | ||
|
||
|
||
df_summary = pd.read_csv(f"{ROOT}/data/wbm-steps-summary.csv", comment="#").set_index( | ||
"material_id" | ||
) | ||
|
||
df_m3gnet["e_form_wbm"] = df_summary.e_form | ||
|
||
|
||
# %% | ||
df_m3gnet.hist(bins=200, figsize=(18, 12)) | ||
df_m3gnet.isna().sum() | ||
|
||
|
||
# %% | ||
ax_hull_dist_hist = hist_classify_stable_as_func_of_hull_dist( | ||
formation_energy_targets=df_m3gnet.e_form_wbm, | ||
formation_energy_preds=df_m3gnet.e_form_m3gnet, | ||
e_above_hull_vals=df_m3gnet.e_above_mp_hull, | ||
) | ||
|
||
ax_hull_dist_hist.figure.savefig(f"{ROOT}/data/{today}-m3gnet-wbm-hull-dist-hist.pdf") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# %% | ||
from __future__ import annotations | ||
|
||
import os | ||
import warnings | ||
from datetime import datetime | ||
from typing import Any | ||
|
||
import m3gnet | ||
import numpy as np | ||
import pandas as pd | ||
from m3gnet.models import Relaxer | ||
from pymatgen.core import Structure | ||
|
||
import wandb | ||
from ml_stability import ROOT, as_dict_handler | ||
|
||
|
||
""" | ||
To slurm submit this file, use | ||
```sh | ||
sbatch --partition icelake-himem --account LEE-SL3-CPU --array 1-100 \ | ||
--time 3:0:0 --job-name m3gnet-wbm-relax --mem 12000 \ | ||
--output ml_stability/m3gnet/slurm_logs/slurm-%A-%a.out \ | ||
--wrap "python ml_stability/m3gnet/slurm_array_m3gnet_relax_wbm.py" | ||
``` | ||
--time 2h is probably enough but missing indices are annoying so best be safe. | ||
Requires M3GNet installation: pip install m3gnet | ||
""" | ||
|
||
__author__ = "Janosh Riebesell" | ||
__date__ = "2022-08-15" | ||
|
||
|
||
print(f"Job started running {datetime.now():%Y-%m-%d@%H-%M}") | ||
job_id = os.environ.get("SLURM_JOB_ID", "debug") | ||
print(f"{job_id=}") | ||
m3gnet_version = m3gnet.__version__ | ||
print(f"{m3gnet_version=}") | ||
|
||
job_array_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0)) | ||
# set default job array size to 1000 for fast testing | ||
job_array_size = int(os.environ.get("SLURM_ARRAY_TASK_COUNT", 10_000)) | ||
print(f"{job_array_id=}") | ||
|
||
today = f"{datetime.now():%Y-%m-%d}" | ||
out_dir = f"{ROOT}/data/{today}-m3gnet-wbm-relax-results" | ||
os.makedirs(out_dir, exist_ok=True) | ||
json_out_path = f"{out_dir}/{job_array_id}.json.gz" | ||
|
||
if os.path.isfile(json_out_path): | ||
raise SystemExit(f"{json_out_path = } already exists, exciting early") | ||
|
||
warnings.filterwarnings(action="ignore", category=UserWarning, module="pymatgen") | ||
warnings.filterwarnings(action="ignore", category=UserWarning, module="tensorflow") | ||
|
||
relax_results: dict[str, dict[str, Any]] = {} | ||
|
||
|
||
# %% | ||
data_path = f"{ROOT}/data/2022-06-26-wbm-cses-and-initial-structures.json.gz" | ||
df_wbm = pd.read_json(data_path).set_index("material_id") | ||
|
||
df_to_relax = np.array_split(df_wbm, job_array_size)[job_array_id] | ||
|
||
run_params = dict( | ||
m3gnet_version=m3gnet_version, | ||
job_id=job_id, | ||
job_array_id=job_array_id, | ||
data_path=data_path, | ||
) | ||
if wandb.run is None: | ||
wandb.login() | ||
wandb.init( | ||
project="m3gnet", # run will be added to this project | ||
name=f"m3gnet-relax-wbm-{job_id}-{job_array_id}", | ||
config=run_params, | ||
) | ||
|
||
|
||
# %% | ||
relaxer = Relaxer() # This loads the default pre-trained M3GNet model | ||
|
||
for material_id, init_struct in df_to_relax.initial_structure.items(): | ||
if material_id in relax_results: | ||
continue | ||
pmg_struct = Structure.from_dict(init_struct) | ||
relax_result = relaxer.relax(pmg_struct) | ||
relax_dict = { | ||
"m3gnet_structure": relax_result["final_structure"], | ||
"m3gnet_trajectory": relax_result["trajectory"].__dict__, | ||
} | ||
# remove non-serializable AseAtoms from trajectory | ||
relax_dict["trajectory"].pop("atoms") | ||
relax_results[material_id] = relax_dict | ||
|
||
|
||
# %% | ||
df_m3gnet = pd.DataFrame(relax_results).T | ||
df_m3gnet.index.name = "material_id" | ||
|
||
|
||
df_m3gnet.to_json(json_out_path, default_handler=as_dict_handler) | ||
|
||
|
||
wandb.log_artifact(json_out_path, type="m3gnet-relaxed-wbm-initial-structures") |