Skip to content

Commit

Permalink
add ml_stability/m3gnet/slurm_array_m3gnet_relax_wbm.py
Browse files Browse the repository at this point in the history
mv ml_stability/m3gnet/{m3gnet_relax_wbm->join_and_plot_m3gnet_relax_results}.py
  • Loading branch information
janosh committed Jun 20, 2023
1 parent 1c3cc92 commit 1e96458
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 14 deletions.
21 changes: 7 additions & 14 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,17 @@

# cache
__pycache__
.mypy
.db_cache

# Fireworks config
/FW_config.yaml

# Atomate workflow output
fw_logs
block*
launcher*

# datasets
*.json.gz
*.json.bz2
*.csv.bz2

# data files
*.pkl*

# checkpoint files of trained models
pretrained
pretrained/

# Weights and Biases logs
wandb/

# slurm logs
slurm-*out
2 changes: 2 additions & 0 deletions ml_stability/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
from os.path import dirname
from typing import Any, Generator, Sequence
Expand Down
105 changes: 105 additions & 0 deletions ml_stability/m3gnet/join_and_plot_m3gnet_relax_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# %%
from __future__ import annotations

import gzip
import io
import pickle
from datetime import datetime
from glob import glob
from urllib.request import urlopen

import pandas as pd
from diel_frontier.patched_phase_diagram import load_ppd
from pymatgen.analysis.phase_diagram import PDEntry
from pymatgen.core import Structure
from tqdm import tqdm

from ml_stability import ROOT
from ml_stability.plots.plot_funcs import hist_classify_stable_as_func_of_hull_dist


today = f"{datetime.now():%Y-%m-%d}"  # date stamp used in output file names


# %%
# Collect the per-job result files (presumably the ones written by the slurm array
# relaxation script; the directory date is hard-coded to that run).
glob_pattern = "2022-08-16-m3gnet-wbm-relax-results/*.json.gz"
file_paths = glob(f"{ROOT}/data/{glob_pattern}")
print(f"Found {len(file_paths):,} files for {glob_pattern = }")


# maps file path -> dataframe loaded from that file
dfs: dict[str, pd.DataFrame] = {}
# 2022-08-16 tried multiprocessing.Pool() to load files in parallel but was somehow
# slower than serial loading
for file_path in tqdm(file_paths):
    if file_path in dfs:
        continue
    try:
        dfs[file_path] = pd.read_json(file_path)
    except (ValueError, FileNotFoundError) as exc:
        # pandas v1.5+ correctly raises FileNotFoundError, below raises ValueError
        # Loading stays best-effort, but report the skip so that missing results
        # don't go unnoticed in the joined dataframe.
        print(f"Skipping unreadable {file_path=}: {exc!r}")
        continue


# %%
# Merge the per-file dataframes into a single one indexed by material ID.
df_m3gnet = pd.concat(dfs.values())
df_m3gnet.index.name = "material_id"
# normalize IDs containing underscores to the dash-separated form
if df_m3gnet.index.str.contains("_").any():
    df_m3gnet.index = df_m3gnet.index.str.replace("_", "-")

# rename() ignores missing labels, so this is a no-op for result files whose
# columns already carry the m3gnet_ prefix
df_m3gnet = df_m3gnet.rename(
    columns=dict(final_structure="m3gnet_structure", trajectory="m3gnet_trajectory")
)

# Read from the renamed column: "trajectory" no longer exists after the rename
# above, so accessing df_m3gnet.trajectory raised AttributeError.
# Takes the first element of the last entry in each trajectory's "energies" list.
df_m3gnet["m3gnet_energy"] = df_m3gnet.m3gnet_trajectory.map(
    lambda traj: traj["energies"][-1][0]
)


# %%
# Persist the joined results as a single file, with material_id written as a
# regular column instead of the index.
out_file = f"{today}-m3gnet-wbm-relax-results.json.gz"
out_path = f"{ROOT}/data/{out_file}"
df_m3gnet.reset_index().to_json(out_path)


# %%
# Load the patched phase diagram that serves as formation-energy reference.
# 2022-01-25-ppd-mp+wbm.pkl.gz (235 MB)
ppd_pickle_url = "https://figshare.com/ndownloader/files/36669624"
# This cell used to download and unpickle the file at ppd_pickle_url and then
# immediately overwrite the result with the load_ppd() call below, so the download
# was wasted work (and its HTTP response was never closed). Only the load_ppd()
# result was ever used downstream, so only that load is kept.
# NOTE(review): if the pickle is fetched from the URL again, remember that
# unpickling remote data can execute arbitrary code - only do it for trusted hosts.
ppd_mp_wbm = load_ppd("ppd-mp+wbm-2022-01-25.pkl.gz")


# turn the structure dicts read from JSON back into pymatgen Structure objects
df_m3gnet["m3gnet_structure"] = df_m3gnet.m3gnet_structure.map(Structure.from_dict)
# one phase diagram entry per row, built from the relaxed structure's composition
# and the m3gnet_energy column
df_m3gnet["pd_entry"] = [
    PDEntry(row.m3gnet_structure.composition, row.m3gnet_energy)
    for row in df_m3gnet.itertuples()
]
df_m3gnet["e_form_m3gnet"] = df_m3gnet.pd_entry.map(ppd_mp_wbm.get_form_energy_per_atom)


# %%
# Attach reference columns from two CSV files; assignment aligns on the
# material_id index.
hull_csv_path = f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
df_hull = pd.read_csv(hull_csv_path).set_index("material_id")

df_m3gnet["e_above_mp_hull"] = df_hull["e_above_mp_hull"]


# lines starting with "#" in the summary file are treated as comments
summary_csv_path = f"{ROOT}/data/wbm-steps-summary.csv"
df_summary = pd.read_csv(summary_csv_path, comment="#").set_index("material_id")

df_m3gnet["e_form_wbm"] = df_summary["e_form"]


# %%
# Quick sanity check: per-column histograms and missing-value counts.
df_m3gnet.hist(bins=200, figsize=(18, 12))
df_m3gnet.isna().sum()


# %%
# Plot the stability classification histogram from the WBM targets, the M3GNet
# predictions and the hull distances, then save it as PDF.
ax_hull_dist_hist = hist_classify_stable_as_func_of_hull_dist(
    e_above_hull_vals=df_m3gnet["e_above_mp_hull"],
    formation_energy_preds=df_m3gnet["e_form_m3gnet"],
    formation_energy_targets=df_m3gnet["e_form_wbm"],
)

fig_path = f"{ROOT}/data/{today}-m3gnet-wbm-hull-dist-hist.pdf"
ax_hull_dist_hist.figure.savefig(fig_path)
109 changes: 109 additions & 0 deletions ml_stability/m3gnet/slurm_array_m3gnet_relax_wbm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# %%
from __future__ import annotations

import os
import warnings
from datetime import datetime
from typing import Any

import m3gnet
import numpy as np
import pandas as pd
from m3gnet.models import Relaxer
from pymatgen.core import Structure

import wandb
from ml_stability import ROOT, as_dict_handler


"""
To slurm submit this file, use
```sh
sbatch --partition icelake-himem --account LEE-SL3-CPU --array 1-100 \
--time 3:0:0 --job-name m3gnet-wbm-relax --mem 12000 \
--output ml_stability/m3gnet/slurm_logs/slurm-%A-%a.out \
--wrap "python ml_stability/m3gnet/slurm_array_m3gnet_relax_wbm.py"
```
--time 2h is probably enough but missing indices are annoying so best be safe.
Requires M3GNet installation: pip install m3gnet
"""

__author__ = "Janosh Riebesell"
__date__ = "2022-08-15"


# Log basic job info at startup so it shows up in the slurm log.
print(f"Job started running {datetime.now():%Y-%m-%d@%H-%M}")
job_id = os.environ.get("SLURM_JOB_ID", "debug")  # "debug" when run outside slurm
print(f"{job_id=}")
m3gnet_version = m3gnet.__version__
print(f"{m3gnet_version=}")

job_array_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
# outside of slurm, default to a job array size of 10,000 so that a single chunk
# of the data is small enough for fast testing
job_array_size = int(os.environ.get("SLURM_ARRAY_TASK_COUNT", 10_000))
print(f"{job_array_id=}")

# each array task writes its own file into a shared, date-stamped directory
today = f"{datetime.now():%Y-%m-%d}"
out_dir = f"{ROOT}/data/{today}-m3gnet-wbm-relax-results"
os.makedirs(out_dir, exist_ok=True)
json_out_path = f"{out_dir}/{job_array_id}.json.gz"

# don't redo (and overwrite) a task whose results are already on disk
if os.path.isfile(json_out_path):
    raise SystemExit(f"{json_out_path = } already exists, exiting early")

warnings.filterwarnings(action="ignore", category=UserWarning, module="pymatgen")
warnings.filterwarnings(action="ignore", category=UserWarning, module="tensorflow")

# maps material ID -> dict of relaxation outputs for that material
relax_results: dict[str, dict[str, Any]] = {}


# %%
data_path = f"{ROOT}/data/2022-06-26-wbm-cses-and-initial-structures.json.gz"
df_wbm = pd.read_json(data_path).set_index("material_id")

# Slurm array task IDs start at whatever --array requested (e.g. 1 for
# --array 1-100) whereas the chunks returned by np.array_split are 0-indexed.
# Indexing with the raw task ID skipped chunk 0 and raised IndexError for the last
# task, so shift by the array's minimum task ID (0 when not running under slurm).
job_array_min_id = int(os.environ.get("SLURM_ARRAY_TASK_MIN", 0))
df_to_relax = np.array_split(df_wbm, job_array_size)[job_array_id - job_array_min_id]

# recorded as the config of the Weights and Biases run
run_params = dict(
    m3gnet_version=m3gnet_version,
    job_id=job_id,
    job_array_id=job_array_id,
    data_path=data_path,
)
if wandb.run is None:
    wandb.login()
wandb.init(
    project="m3gnet",  # run will be added to this project
    name=f"m3gnet-relax-wbm-{job_id}-{job_array_id}",
    config=run_params,
)


# %%
relaxer = Relaxer()  # This loads the default pre-trained M3GNet model

# Relax every initial structure assigned to this array task and collect the
# outputs keyed by material ID.
for material_id, init_struct in df_to_relax.initial_structure.items():
    if material_id in relax_results:
        continue  # already relaxed, e.g. when re-running this cell interactively
    pmg_struct = Structure.from_dict(init_struct)
    relax_result = relaxer.relax(pmg_struct)
    relax_dict = {
        "m3gnet_structure": relax_result["final_structure"],
        # shallow copy so that dropping "atoms" below doesn't mutate the
        # trajectory object itself
        "m3gnet_trajectory": dict(relax_result["trajectory"].__dict__),
    }
    # remove non-serializable AseAtoms from trajectory
    # (the key is "m3gnet_trajectory"; looking up "trajectory" raised KeyError)
    relax_dict["m3gnet_trajectory"].pop("atoms")
    relax_results[material_id] = relax_dict


# %%
# relax_results is keyed by material ID, so transpose to get one row per material
# and label the index accordingly.
df_m3gnet = pd.DataFrame(relax_results).T.rename_axis("material_id")


# default_handler is what pandas calls for objects it cannot convert to JSON itself
df_m3gnet.to_json(json_out_path, default_handler=as_dict_handler)


# upload this task's result file to the Weights and Biases run
wandb.log_artifact(json_out_path, type="m3gnet-relaxed-wbm-initial-structures")

0 comments on commit 1e96458

Please sign in to comment.