Skip to content

Commit

Permalink
train_alignn.py add wandb tracking and avoid POSCARs on disk (#45)
Browse files Browse the repository at this point in the history
* train_alignn.py add wandb tracking and refactor to not need POSCARs on disk

* remove debug line

* bump elementari

* NERSC tweaks

* fix test_slurm

* git rm paper (git submodule not actually needed)
  • Loading branch information
janosh committed Jul 11, 2023
1 parent c03d741 commit ce6a077
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 90 deletions.
18 changes: 7 additions & 11 deletions matbench_discovery/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def slurm_submit(
job_name: str,
out_dir: str,
time: str,
partition: str,
account: str,
partition: str | None = None,
py_file_path: str | None = None,
slurm_flags: str | Sequence[str] = (),
array: str | None = None,
Expand All @@ -46,10 +46,10 @@ def slurm_submit(
out_dir (str): Directory to write slurm logs. Log file will include slurm job
ID and array task ID.
time (str): 'HH:MM:SS' time limit for the job.
py_file_path (str, optional): Path to the python script to be submitted.
Defaults to the path of the file calling slurm_submit().
account (str): Account to charge for this job.
partition (str, optional): Slurm partition.
account (str, optional): Account to charge for this job.
py_file_path (str, optional): Path to the python script to be submitted.
slurm_flags (str | list[str], optional): Extra slurm CLI flags. Defaults to ().
Examples: ('--nodes 1', '--gpus-per-node 1') or ('--mem', '16G').
array (str, optional): Slurm array specifier. Defaults to None. Example:
Expand All @@ -67,23 +67,19 @@ def slurm_submit(
dict[str, str]: Slurm variables like job ID, array task ID, compute nodes IDs,
submission node ID and total job memory.
"""
if py_file_path is None:
py_file_path = _get_calling_file_path(frame=2)

if "GPU" in partition:
# on Ampere GPU partition, source module CLI and load default Ampere env
# before actual job command
pre_cmd += ". /etc/profile.d/modules.sh; module load rhel8/default-amp;"
py_file_path = py_file_path or _get_calling_file_path(frame=2)

os.makedirs(out_dir, exist_ok=True) # slurm fails if out_dir is missing

cmd = [
*f"sbatch --{partition=} --{account=} --{time=}".replace("'", "").split(),
*f"sbatch --{account=} --{time=}".replace("'", "").split(),
*("--job-name", job_name),
*("--output", f"{out_dir}/slurm-%A{'-%a' if array else ''}.log"),
*(slurm_flags.split() if isinstance(slurm_flags, str) else slurm_flags),
*("--wrap", f"{pre_cmd} python {py_file_path}".strip()),
]
if partition:
cmd += ["--partition", partition]
if array:
cmd += ["--array", array]

Expand Down
60 changes: 0 additions & 60 deletions models/alignn/make_train_data.py

This file was deleted.

174 changes: 163 additions & 11 deletions models/alignn/train_alignn.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,171 @@
"""Train a ALIGNN on target_col of data_path."""
# %%
from __future__ import annotations

import json
import os
from importlib.metadata import version
from typing import Any

import pandas as pd
import torch
import wandb
from alignn.config import TrainingConfig
from alignn.data import StructureDataset, load_graphs
from alignn.train import train_dgl
from pymatgen.core import Structure
from pymatgen.io.jarvis import JarvisAtomsAdaptor
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm import tqdm

from matbench_discovery import DEBUG, today
from matbench_discovery.data import DATA_FILES
from matbench_discovery.slurm import slurm_submit

__author__ = "Philipp Benner, Janosh Riebesell"
__date__ = "2023-06-03"

module_dir = os.path.dirname(__file__)


# %%
model_name = "alignn-mp-e_form"
target_col = "formation_energy_per_atom"
struct_col = "structure"
input_col = "atoms"
id_col = "material_id"
device = "cuda" if torch.cuda.is_available() else "cpu"
job_name = f"train-{model_name}{'-debug' if DEBUG else ''}"


pred_col = "e_form_per_atom_alignn"
with open(f"{module_dir}/alignn-config.json") as file:
config = TrainingConfig(**json.load(file))

config.output_dir = out_dir = os.getenv(
"SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}"
)

slurm_vars = slurm_submit(
job_name=job_name,
# partition="perlmutter",
account="matgen_g",
time="4:0:0",
out_dir=out_dir,
slurm_flags="--qos regular --constraint gpu --gpus 1",
pre_cmd="module load pytorch/2.0.1;",
)


# %% Load data
# Read the MP computed structure entries, index them by material ID and
# deserialize the structure stored under struct_col of every entry dict.
df_cse = pd.read_json(DATA_FILES.mp_computed_structure_entries).set_index(id_col)
df_cse[struct_col] = [
    Structure.from_dict(cse[struct_col])
    for cse in tqdm(df_cse.entry, desc="Structures from dict")
]

# load energies
df_in = pd.read_csv(DATA_FILES.mp_energies).set_index(id_col)
# column assignment aligns on the shared material ID index
df_in[struct_col] = df_cse[struct_col]
assert target_col in df_in

# Convert every pymatgen structure to JARVIS atoms and store them in input_col.
# (The column is assigned once here; a preceding placeholder copy of struct_col
# was immediately overwritten and has been dropped.)
df_in[input_col] = [
    JarvisAtomsAdaptor.get_atoms(struct)
    for struct in tqdm(df_in[struct_col], desc="Converting to JARVIS atoms")
]


# %%
from alignn.train_folder import train_for_folder
run_params = dict(
data_path=DATA_FILES.mp_energies,
**{f"{dep}_version": version(dep) for dep in ("alignn", "numpy", "torch", "dgl")},
model_name=model_name,
target_col=target_col,
df=dict(shape=str(df_in.shape), columns=", ".join(df_in)),
slurm_vars=slurm_vars,
alignn_config=config.dict(),
)

wandb.init(project="matbench-discovery", name=job_name, config=run_params)


# %%
# Split the (id, input, target) columns into 95 % training and 5 % validation rows
# with a fixed seed so the split is reproducible.
# NOTE(review): .head(1000) restricts the split to the first 1000 rows of df_in —
# looks like a debugging limit; confirm whether it should apply outside DEBUG runs.
df_train, df_val = train_test_split(
    df_in.head(1000).reset_index()[[id_col, input_col, target_col]],
    test_size=0.05,
    random_state=42,
)


def df_to_loader(
    df: pd.DataFrame,
    batch_size: int = 128,
    line_graph: bool = True,
    pin_memory: bool = False,
    shuffle: bool = True,
    **kwargs: Any,
) -> DataLoader:
    """Wrap a dataframe in a PyTorch DataLoader for train/val/test.

    Besides its arguments, this function reads the module-level ``config``
    (neighbor_strategy, use_canonize, atom_features) as well as the module-level
    ``target_col`` and ``id_col`` column names.

    Args:
        df (pd.DataFrame): With id, input and target columns.
        batch_size (int, optional): Number of samples per batch. Defaults to 128.
        line_graph (bool, optional): Whether to train line (True) or atom (False)
            graph version of ALIGNN. Also selects the dataset's matching collate
            function. Defaults to True.
        pin_memory (bool, optional): Whether torch DataLoader should pin memory.
            Defaults to False.
        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to True.
        **kwargs: Additional keyword arguments forwarded to StructureDataset.

    Returns:
        DataLoader: Loader over a StructureDataset built from df, batched with the
            dataset's collate_line_graph or collate_graph method.
    """
    graph_list = load_graphs(
        df,
        neighbor_strategy=config.neighbor_strategy,
        use_canonize=config.use_canonize,
    )
    struct_dataset = StructureDataset(
        df.reset_index(drop=True),
        graph_list,
        target=target_col,
        line_graph=line_graph,
        atom_features=config.atom_features,
        id_tag=id_col,
        **kwargs,
    )

    # the dataset exposes one collate method per graph flavor
    suffix = "_line" if line_graph else ""
    batch_collator = getattr(struct_dataset, f"collate{suffix}_graph")

    loader = DataLoader(
        struct_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=batch_collator,
        pin_memory=pin_memory,
    )
    return loader


train_loader, val_loader = df_to_loader(df_train), df_to_loader(df_val, shuffle=False)


# %%
# batch preparation function provided by the training dataset; passed to train_dgl
# as the last element of train_val_test_loaders below
prepare_batch = train_loader.dataset.prepare_batch
# Disable writing predictions. NOTE(review): presumably leaving it enabled is what
# triggers the error in alignn/train.py line 1059 in train_dgl() — confirm:
# f.write("%s, %6f, %6f\n" % (id, target, out_data))
# TypeError: must be real number, not list
config.write_predictions = False

# val_loader is passed twice: in the validation slot and again in the third slot
# (per the keyword name, the test loader) since no separate test split is created
train_hist = train_dgl(
    config,
    train_val_test_loaders=[train_loader, val_loader, val_loader, prepare_batch],
)

__author__ = "Philipp Benner"
__date__ = "2023-06-02"
wandb.log(train_hist)
wandb.save(f"{out_dir}/*")


# %%
train_for_folder(
root_dir="data_train",
config_name="alignn-config.json",
keep_data_order=False,
output_dir="data-train-result",
epochs=1000,
file_format="poscar",
df_hist = pd.concat(
{key: pd.DataFrame(train_hist[key]) for key in ("train", "validation")}
)
table = wandb.Table(dataframe=df_hist)
wandb.log({"table": table})
1 change: 0 additions & 1 deletion paper
Submodule paper deleted from d7c7bf
2 changes: 1 addition & 1 deletion site/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"@sveltejs/vite-plugin-svelte": "^2.4.2",
"@typescript-eslint/eslint-plugin": "^5.61.0",
"@typescript-eslint/parser": "^5.61.0",
"elementari": "^0.1.8",
"elementari": "^0.2.2",
"eslint": "^8.44.0",
"eslint-plugin-svelte": "^2.32.2",
"hastscript": "^7.2.0",
Expand Down
2 changes: 1 addition & 1 deletion site/src/lib/ModelCard.svelte
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<script lang="ts">
import { repository } from '$site/package.json'
import Icon from '@iconify/svelte'
import { pretty_num } from 'elementari/labels'
import { pretty_num } from 'elementari'
import { fade, slide } from 'svelte/transition'
import type { ModelData, ModelStatLabel } from '.'
Expand Down
2 changes: 1 addition & 1 deletion site/src/lib/PtableInset.svelte
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<script lang="ts">
import type { ChemicalElement } from 'elementari'
import { pretty_num } from 'elementari/labels'
import { pretty_num } from 'elementari'
export let element: ChemicalElement
export let elem_counts: number[] | Record<string, number>
Expand Down
2 changes: 1 addition & 1 deletion site/src/routes/preprint/+layout.svelte
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<script lang="ts">
import { References } from '$lib'
import cite from '$root/citation.cff'
import { pretty_num } from 'elementari/labels'
import { pretty_num } from 'elementari'
import { references } from './references.yaml'
export let data
Expand Down
10 changes: 7 additions & 3 deletions tests/test_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@

@patch.dict(os.environ, {"SLURM_JOB_ID": "1234"}, clear=True)
@pytest.mark.parametrize("py_file_path", [None, "path/to/file.py"])
def test_slurm_submit(capsys: CaptureFixture[str], py_file_path: str | None) -> None:
@pytest.mark.parametrize("partition", [None, "fake-partition"])
def test_slurm_submit(
capsys: CaptureFixture[str], py_file_path: str | None, partition: str | None
) -> None:
job_name = "test_job"
out_dir = "tmp"
time = "0:0:1"
partition = "fake-partition"
account = "fake-account"

func_call = lambda: slurm_submit(
Expand Down Expand Up @@ -47,10 +49,12 @@ def test_slurm_submit(capsys: CaptureFixture[str], py_file_path: str | None) ->
assert mock_subprocess_run.call_count == 1

sbatch_cmd = (
f"sbatch --partition={partition} --account={account} --time={time} "
f"sbatch --account={account} --time={time} "
f"--job-name {job_name} --output {out_dir}/slurm-%A.log --foo "
f"--wrap python {py_file_path or __file__}"
).replace(" --", "\n --")
if partition:
sbatch_cmd += f"\n --partition {partition}"
stdout, stderr = capsys.readouterr()
assert sbatch_cmd in stdout
assert stderr == ""
Expand Down

0 comments on commit ce6a077

Please sign in to comment.