From ce6a077b274caddc50e5ef961d165497efe332c7 Mon Sep 17 00:00:00 2001
From: Janosh Riebesell
Date: Mon, 10 Jul 2023 18:14:09 -0700
Subject: [PATCH] `train_alignn.py` add `wandb` tracking and avoid POSCARs on disk (#45)

* train_alignn.py add wandb tracking and refactor to not need POSCARs on disk
* remove debug line
* bump elementari
* NERSC tweaks
* fix test_slurm
* git rm paper (git submodule not actually needed)
---
 matbench_discovery/slurm.py             |  18 +--
 models/alignn/make_train_data.py        |  60 --------
 models/alignn/train_alignn.py           | 174 ++++++++++++++++++++++--
 paper                                   |   1 -
 site/package.json                       |   2 +-
 site/src/lib/ModelCard.svelte           |   2 +-
 site/src/lib/PtableInset.svelte         |   2 +-
 site/src/routes/preprint/+layout.svelte |   2 +-
 tests/test_slurm.py                     |  10 +-
 9 files changed, 181 insertions(+), 90 deletions(-)
 delete mode 100644 models/alignn/make_train_data.py
 delete mode 160000 paper

diff --git a/matbench_discovery/slurm.py b/matbench_discovery/slurm.py
index 4e7b9156..6cfa8dbd 100644
--- a/matbench_discovery/slurm.py
+++ b/matbench_discovery/slurm.py
@@ -28,8 +28,8 @@ def slurm_submit(
     job_name: str,
     out_dir: str,
     time: str,
-    partition: str,
     account: str,
+    partition: str | None = None,
     py_file_path: str | None = None,
     slurm_flags: str | Sequence[str] = (),
     array: str | None = None,
@@ -46,10 +46,10 @@
         out_dir (str): Directory to write slurm logs. Log file will include
             slurm job ID and array task ID.
         time (str): 'HH:MM:SS' time limit for the job.
-        py_file_path (str, optional): Path to the python script to be submitted. Defaults to the path of the file calling slurm_submit().
+        account (str): Account to charge for this job.
         partition (str, optional): Slurm partition.
-        account (str, optional): Account to charge for this job.
+        py_file_path (str, optional): Path to the python script to be submitted.
         slurm_flags (str | list[str], optional): Extra slurm CLI flags. Defaults to ().
             Examples: ('--nodes 1', '--gpus-per-node 1') or ('--mem', '16G').
         array (str, optional): Slurm array specifier. Defaults to None. Example:
@@ -67,23 +67,19 @@
         dict[str, str]: Slurm variables like job ID, array task ID, compute nodes IDs,
             submission node ID and total job memory.
     """
-    if py_file_path is None:
-        py_file_path = _get_calling_file_path(frame=2)
-
-    if "GPU" in partition:
-        # on Ampere GPU partition, source module CLI and load default Ampere env
-        # before actual job command
-        pre_cmd += ". /etc/profile.d/modules.sh; module load rhel8/default-amp;"
+    py_file_path = py_file_path or _get_calling_file_path(frame=2)

     os.makedirs(out_dir, exist_ok=True)  # slurm fails if out_dir is missing

     cmd = [
-        *f"sbatch --{partition=} --{account=} --{time=}".replace("'", "").split(),
+        *f"sbatch --{account=} --{time=}".replace("'", "").split(),
         *("--job-name", job_name),
         *("--output", f"{out_dir}/slurm-%A{'-%a' if array else ''}.log"),
         *(slurm_flags.split() if isinstance(slurm_flags, str) else slurm_flags),
         *("--wrap", f"{pre_cmd} python {py_file_path}".strip()),
     ]
+    if partition:
+        cmd += ["--partition", partition]
     if array:
         cmd += ["--array", array]
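For reference, a minimal sketch of how the reordered `slurm_submit` signature is called after this change (not part of the patch; the job name and log directory below are illustrative, while account, time and flags mirror the ALIGNN training script further down). Since `partition` is now optional, it can simply be omitted and is only appended as a `--partition` flag when given:

    from matbench_discovery.slurm import slurm_submit

    slurm_vars = slurm_submit(
        job_name="train-alignn-mp-e_form",  # illustrative job name
        account="matgen_g",
        time="4:0:0",
        out_dir="slurm-logs",  # illustrative log directory
        slurm_flags="--qos regular --constraint gpu --gpus 1",
        # partition="...",  # optional: only adds --partition <name> if set
    )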
diff --git a/models/alignn/make_train_data.py b/models/alignn/make_train_data.py
deleted file mode 100644
index b52507e1..00000000
--- a/models/alignn/make_train_data.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# %% Imports
-import os
-
-import pandas as pd
-from pymatgen.core import Structure
-from sklearn.model_selection import train_test_split
-from tqdm import tqdm, trange
-
-from matbench_discovery.data import DATA_FILES
-from matbench_discovery.structure import perturb_structure
-
-__author__ = "Philipp Benner"
-__date__ = "2023-06-02"
-
-
-# %%
-target_col = "formation_energy_per_atom"
-input_col = "structure"
-id_col = "material_id"
-n_perturb = 0
-
-
-# %% load structures
-df_cse = pd.read_json(DATA_FILES.mp_computed_structure_entries).set_index(id_col)
-df_cse[input_col] = [
-    Structure.from_dict(cse[input_col]) for cse in tqdm(df_cse.entry, disable=None)
-]
-
-# load energies
-df = pd.read_csv(DATA_FILES.mp_energies).set_index(id_col)
-df[input_col] = df_cse[input_col]
-assert target_col in df
-
-
-# %% augment with randomly perturbed structures
-df_aug = df.copy()
-structs = df_aug.pop(input_col)
-for idx in trange(n_perturb, desc="Generating perturbed structures"):
-    df_aug[input_col] = [perturb_structure(x) for x in structs]
-    df = pd.concat([df, df_aug.set_index(f"{x}-aug={idx+1}" for x in df_aug.index)])
-
-del df_aug
-
-
-# %% export data
-X_train, X_test, y_train, y_test = train_test_split(
-    df[input_col], df[target_col], test_size=0.05, random_state=42
-)
-
-for samples, targets, label in ((X_train, y_train, "train"), (X_test, y_test, "test")):
-    out_dir = f"{label}-data"
-    os.makedirs(out_dir, exist_ok=True)
-
-    targets.to_csv(f"{out_dir}/targets.csv")
-
-    struct: Structure
-    for mat_id, struct in tqdm(
-        samples.items(), desc="Saving structures", total=len(samples)
-    ):
-        struct.to(f"{out_dir}/{mat_id}.poscar", fmt="POSCAR")
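The deleted script above wrote every training structure to a POSCAR file and relied on `train_for_folder` to read them back. The rewritten `train_alignn.py` below keeps the whole pipeline in memory; condensed to its essence (a sketch of the steps in the following diff, not added functionality):

    import pandas as pd
    from pymatgen.core import Structure
    from pymatgen.io.jarvis import JarvisAtomsAdaptor

    from matbench_discovery.data import DATA_FILES

    # load MP structures + energies, then convert to JARVIS atoms without touching disk
    df_cse = pd.read_json(DATA_FILES.mp_computed_structure_entries).set_index("material_id")
    df_cse["structure"] = [Structure.from_dict(cse["structure"]) for cse in df_cse.entry]
    df = pd.read_csv(DATA_FILES.mp_energies).set_index("material_id")
    df["structure"] = df_cse.structure  # aligned on material_id
    df["atoms"] = [JarvisAtomsAdaptor.get_atoms(struct) for struct in df.structure]
    # df now feeds StructureDataset/DataLoader directly -- no POSCAR files needed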
diff --git a/models/alignn/train_alignn.py b/models/alignn/train_alignn.py
index e6243cd2..98874491 100644
--- a/models/alignn/train_alignn.py
+++ b/models/alignn/train_alignn.py
@@ -1,19 +1,171 @@
-"""Train a ALIGNN on target_col of data_path."""
+# %%
+from __future__ import annotations
+
+import json
+import os
+from importlib.metadata import version
+from typing import Any
+
+import pandas as pd
+import torch
+import wandb
+from alignn.config import TrainingConfig
+from alignn.data import StructureDataset, load_graphs
+from alignn.train import train_dgl
+from pymatgen.core import Structure
+from pymatgen.io.jarvis import JarvisAtomsAdaptor
+from sklearn.model_selection import train_test_split
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from matbench_discovery import DEBUG, today
+from matbench_discovery.data import DATA_FILES
+from matbench_discovery.slurm import slurm_submit
+
+__author__ = "Philipp Benner, Janosh Riebesell"
+__date__ = "2023-06-03"
+
+module_dir = os.path.dirname(__file__)
+
+
+# %%
+model_name = "alignn-mp-e_form"
+target_col = "formation_energy_per_atom"
+struct_col = "structure"
+input_col = "atoms"
+id_col = "material_id"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+job_name = f"train-{model_name}{'-debug' if DEBUG else ''}"
+
+
+pred_col = "e_form_per_atom_alignn"
+with open(f"{module_dir}/alignn-config.json") as file:
+    config = TrainingConfig(**json.load(file))
+
+config.output_dir = out_dir = os.getenv(
+    "SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}"
+)
+
+slurm_vars = slurm_submit(
+    job_name=job_name,
+    # partition="perlmutter",
+    account="matgen_g",
+    time="4:0:0",
+    out_dir=out_dir,
+    slurm_flags="--qos regular --constraint gpu --gpus 1",
+    pre_cmd="module load pytorch/2.0.1;",
+)
+
+
+# %% Load data
+df_cse = pd.read_json(DATA_FILES.mp_computed_structure_entries).set_index(id_col)
+df_cse[struct_col] = [
+    Structure.from_dict(cse[struct_col])
+    for cse in tqdm(df_cse.entry, desc="Structures from dict")
+]
+
+# load energies
+df_in = pd.read_csv(DATA_FILES.mp_energies).set_index(id_col)
+df_in[struct_col] = df_cse[struct_col]
+assert target_col in df_in
+
+df_in[input_col] = df_in[struct_col]
+df_in[input_col] = [
+    JarvisAtomsAdaptor.get_atoms(struct)
+    for struct in tqdm(df_in[struct_col], desc="Converting to JARVIS atoms")
+]
 
 # %%
-from alignn.train_folder import train_for_folder
+run_params = dict(
+    data_path=DATA_FILES.mp_energies,
+    **{f"{dep}_version": version(dep) for dep in ("alignn", "numpy", "torch", "dgl")},
+    model_name=model_name,
+    target_col=target_col,
+    df=dict(shape=str(df_in.shape), columns=", ".join(df_in)),
+    slurm_vars=slurm_vars,
+    alignn_config=config.dict(),
+)
+
+wandb.init(project="matbench-discovery", name=job_name, config=run_params)
+
+
+# %%
+df_train, df_val = train_test_split(
+    df_in.head(1000).reset_index()[[id_col, input_col, target_col]],
+    test_size=0.05,
+    random_state=42,
+)
+
+
+def df_to_loader(
+    df: pd.DataFrame,
+    batch_size: int = 128,
+    line_graph: bool = True,
+    pin_memory: bool = False,
+    shuffle: bool = True,
+    **kwargs: Any,
+) -> DataLoader:
+    """Converts a dataframe to a regular PyTorch dataloader for train/val/test.
+
+    Args:
+        df (pd.DataFrame): With id, input and target columns.
+        batch_size (int, optional): Defaults to 128.
+        line_graph (bool, optional): Whether to train line (True) or atom (False) graph
+            version of ALIGNN. Defaults to True.
+        pin_memory (bool, optional): Whether torch DataLoader should pin memory.
+            Defaults to False.
+        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to True.
+        **kwargs: Additional arguments to pass to the StructureDataset.
+
+    Returns:
+        DataLoader: Batches of ALIGNN (line) graphs and their targets.
+    """
+    graphs = load_graphs(
+        df, neighbor_strategy=config.neighbor_strategy, use_canonize=config.use_canonize
+    )
+    dataset = StructureDataset(
+        df.reset_index(drop=True),
+        graphs,
+        target=target_col,
+        line_graph=line_graph,
+        atom_features=config.atom_features,
+        id_tag=id_col,
+        **kwargs,
+    )
+    collate_fn = getattr(dataset, f"collate{'_line' if line_graph else ''}_graph")
+
+    return DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        collate_fn=collate_fn,
+        pin_memory=pin_memory,
+    )
+
+
+train_loader, val_loader = df_to_loader(df_train), df_to_loader(df_val, shuffle=False)
+
+
+# %%
+prepare_batch = train_loader.dataset.prepare_batch
+# triggers error in alignn/train.py line 1059 in train_dgl()
+# f.write("%s, %6f, %6f\n" % (id, target, out_data))
+# TypeError: must be real number, not list
+config.write_predictions = False
+
+train_hist = train_dgl(
+    config,
+    train_val_test_loaders=[train_loader, val_loader, val_loader, prepare_batch],
+)
 
-__author__ = "Philipp Benner"
-__date__ = "2023-06-02"
+wandb.log(train_hist)
+wandb.save(f"{out_dir}/*")
 
 
 # %%
-train_for_folder(
-    root_dir="data_train",
-    config_name="alignn-config.json",
-    keep_data_order=False,
-    output_dir="data-train-result",
-    epochs=1000,
-    file_format="poscar",
+df_hist = pd.concat(
+    {key: pd.DataFrame(train_hist[key]) for key in ("train", "validation")}
 )
+table = wandb.Table(dataframe=df_hist)
+wandb.log({"table": table})
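A hypothetical extension (not part of this PR): since `df_to_loader` works for any id/input/target dataframe, the test slot passed to `train_dgl` could hold a real held-out loader instead of reusing the validation loader. The slice below is purely illustrative and assumes the names defined in `train_alignn.py` above:

    # assumes df_in, id_col, input_col, target_col, config, df_to_loader,
    # train_loader, val_loader and prepare_batch from the script above
    df_test = df_in.tail(100).reset_index()[[id_col, input_col, target_col]]  # illustrative slice
    test_loader = df_to_loader(df_test, shuffle=False)

    train_hist = train_dgl(
        config,
        train_val_test_loaders=[train_loader, val_loader, test_loader, prepare_batch],
    )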
diff --git a/paper b/paper
deleted file mode 160000
index d7c7bf56..00000000
--- a/paper
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit d7c7bf563dfa42e77d308bdb87f1e77d551637bb
diff --git a/site/package.json b/site/package.json
index df7cda54..9fb22f17 100644
--- a/site/package.json
+++ b/site/package.json
@@ -24,7 +24,7 @@
     "@sveltejs/vite-plugin-svelte": "^2.4.2",
     "@typescript-eslint/eslint-plugin": "^5.61.0",
     "@typescript-eslint/parser": "^5.61.0",
-    "elementari": "^0.1.8",
+    "elementari": "^0.2.2",
     "eslint": "^8.44.0",
     "eslint-plugin-svelte": "^2.32.2",
     "hastscript": "^7.2.0",
diff --git a/site/src/lib/ModelCard.svelte b/site/src/lib/ModelCard.svelte
index 84734919..fcbd7bee 100644
--- a/site/src/lib/ModelCard.svelte
+++ b/site/src/lib/ModelCard.svelte
@@ -1,7 +1,7 @@