add module doc string with installation instructions to train_mace.py
fix pyright possibly unbound variable errors
janosh committed Mar 21, 2024
1 parent 71a5edc commit 29eecf2
Showing 11 changed files with 58 additions and 50 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg]

 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.3.0
+    rev: v0.3.3
     hooks:
       - id: ruff
         args: [--fix]
@@ -30,7 +30,7 @@ repos:
       - id: trailing-whitespace

   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.8.0
+    rev: v1.9.0
     hooks:
       - id: mypy
         additional_dependencies: [types-pyyaml, types-requests]
@@ -56,7 +56,7 @@ repos:
         exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/(routes|figs).+\.(yaml|json)|changelog.md)$

   - repo: https://github.com/pre-commit/mirrors-eslint
-    rev: v9.0.0-beta.1
+    rev: v9.0.0-beta.2
    hooks:
      - id: eslint
        types: [file]
@@ -80,7 +80,7 @@ repos:
      - id: check-github-actions

  - repo: https://github.com/RobertCraigie/pyright-python
-    rev: v1.1.352
+    rev: v1.1.355
    hooks:
      - id: pyright
        args: [--level, error]
5 changes: 3 additions & 2 deletions data/mp/build_phase_diagram.py
@@ -105,13 +105,14 @@


 # %%
-df_mp["our_mp_e_form"] = [
+e_form_us = "e_form_us"
+df_mp[e_form_us] = [
     get_e_form_per_atom(mp_computed_entries[mp_id]) for mp_id in df_mp.index
 ]


 # make sure get_form_energy_per_atom() reproduces MP formation energies
-ax = pymatviz.density_scatter(df_mp[Key.form_energy], df_mp["our_mp_e_form"])
+ax = pymatviz.density_scatter(df_mp[Key.form_energy], df_mp[e_form_us])
 ax.set(
     title="MP Formation Energy Comparison",
     xlabel="MP Formation Energy (eV/atom)",
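For context on what the renamed column stores: `get_e_form_per_atom` computes a formation energy per atom by subtracting elemental reference energies from the total energy. A minimal sketch of that arithmetic, with made-up reference energies (not MP's fitted values):

```py
# Hedged sketch of the formation-energy arithmetic, not the repo's helper.
# el_refs values below are placeholders, not MP's fitted elemental references.
def e_form_per_atom(
    energy: float, composition: dict[str, int], el_refs: dict[str, float]
) -> float:
    """Formation energy per atom: (E_total - sum_i n_i * mu_i) / N."""
    e_form = energy - sum(n * el_refs[el] for el, n in composition.items())
    return e_form / sum(composition.values())

# e.g. a hypothetical MgO cell with E_total = -11.9 eV
print(e_form_per_atom(-11.9, {"Mg": 1, "O": 1}, {"Mg": -1.6, "O": -4.9}))  # -2.7
```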
2 changes: 1 addition & 1 deletion data/wbm/compare_cse_vs_ce_mp_2020_corrections.py
@@ -18,7 +18,7 @@

 """
 NOTE MaterialsProject2020Compatibility takes structural information into account when
-correcting energies (only applies to certain oxides and sulfides). Always use
+correcting energies (for certain oxides and sulfides). Always use
 ComputedStructureEntry, not ComputedEntry when applying corrections.
 """
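A minimal sketch of the point this NOTE makes, applying MP2020 corrections to a `ComputedStructureEntry` so the compatibility scheme can inspect the structure. The MgO cell, energy, and run parameters are invented for illustration:

```py
# Hedged sketch: MaterialsProject2020Compatibility needs structural info
# (e.g. to classify oxides), hence ComputedStructureEntry, not ComputedEntry.
# The MgO cell and energy below are toy values for illustration only.
from pymatgen.core import Lattice, Structure
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
from pymatgen.entries.computed_entries import ComputedStructureEntry

struct = Structure(Lattice.cubic(4.2), ["Mg", "O"], [[0, 0, 0], [0.5, 0.5, 0.5]])
cse = ComputedStructureEntry(struct, energy=-11.9, parameters={"run_type": "GGA"})

processed = MaterialsProject2020Compatibility().process_entries([cse])
print(processed[0].correction)  # anion correction inferred from the structure
```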
23 changes: 11 additions & 12 deletions data/wbm/compile_wbm_test_set.py
@@ -161,8 +161,8 @@ def increment_wbm_material_id(wbm_id: str) -> str:
 assert df_wbm.index[-1] == "wbm-5-23308"

 df_wbm[Key.init_struct] = df_wbm.pop("org")
-df_wbm["final_structure"] = df_wbm.pop("opt")
-assert list(df_wbm.columns) == [Key.init_struct, "final_structure"]
+df_wbm[Key.final_struct] = df_wbm.pop("opt")
+assert list(df_wbm.columns) == [Key.init_struct, Key.final_struct]


 # %% download WBM ComputedStructureEntries from
@@ -247,7 +247,7 @@ def increment_wbm_material_id(wbm_id: str) -> str:
 ]

 df_wbm["composition_from_final_struct"] = [
-    Structure.from_dict(struct).composition for struct in tqdm(df_wbm.final_structure)
+    Structure.from_dict(struct).composition for struct in tqdm(df_wbm[Key.final_struct])
 ]

 # all but 1 composition matches between CSE and final structure
@@ -499,7 +499,9 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
 for mat_id, cse in df_wbm[Key.cse].items():
     assert mat_id == cse["entry_id"], f"{mat_id} != {cse['entry_id']}"

-df_wbm["cse"] = [ComputedStructureEntry.from_dict(dct) for dct in tqdm(df_wbm[Key.cse])]
+df_wbm[Key.cse] = [
+    ComputedStructureEntry.from_dict(dct) for dct in tqdm(df_wbm[Key.cse])
+]
 # raw WBM ComputedStructureEntries have no energy corrections applied:
 assert all(cse.uncorrected_energy == cse.energy for cse in df_wbm.cse)
 # summary and CSE n_sites match
@@ -548,7 +550,7 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
 # takes ~20 min at 200 it/s for 250k entries in WBM
 assert Key.each_true not in df_summary

-for mat_id, cse in tqdm(df_wbm.cse.items(), total=len(df_wbm)):
+for mat_id, cse in tqdm(df_wbm[Key.cse].items(), total=len(df_wbm)):
     assert mat_id == cse.entry_id, f"{mat_id=} != {cse.entry_id=}"
     assert cse.entry_id in df_summary.index, f"{cse.entry_id=} not in df_summary"

@@ -562,7 +564,7 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
 assert sum(df_wbm.index != df_summary.index) == 0

 for row in tqdm(df_wbm.itertuples(), total=len(df_wbm), desc="ML energies to CSEs"):
-    mat_id, cse, formula = row.Index, row.cse, row.formula_from_cse
+    mat_id, cse, formula = row.Index, row[Key.cse], row.formula_from_cse
     assert mat_id == cse.entry_id, f"{mat_id=} != {cse.entry_id=}"
     assert mat_id in df_summary.index, f"{mat_id=} not in df_summary"

@@ -665,12 +667,9 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:

 # %% only here to load data for later inspection
 if False:
-    wbm_summary_path = f"{WBM_DIR}/2022-10-19-wbm-summary.csv.gz"
-    df_summary = pd.read_csv(wbm_summary_path).set_index(Key.mat_id)
-    df_wbm = pd.read_json(
-        f"{WBM_DIR}/2022-10-19-wbm-computed-structure-entries+init-structs.json.bz2"
-    ).set_index(Key.mat_id)
+    df_summary = pd.read_csv(DATA_FILES.wbm_summary).set_index(Key.mat_id)
+    df_wbm = pd.read_json(DATA_FILES.wbm_cses_plus_init_structs).set_index(Key.mat_id)

-    df_wbm["cse"] = [
+    df_wbm[Key.cse] = [
         ComputedStructureEntry.from_dict(dct) for dct in tqdm(df_wbm[Key.cse])
     ]
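This script repeatedly round-trips entries through JSON dicts. A small standalone sketch of that pattern (toy structure and invented `entry_id`) shows why the `entry_id` asserts and the `uncorrected_energy == energy` check hold right after deserialization:

```py
# Hedged sketch of the dict round-trip the script relies on; toy values only.
from pymatgen.core import Lattice, Structure
from pymatgen.entries.computed_entries import ComputedStructureEntry

struct = Structure(Lattice.cubic(3.0), ["Fe"], [[0, 0, 0]])
cse = ComputedStructureEntry(struct, energy=-8.4, entry_id="wbm-1-1")

revived = ComputedStructureEntry.from_dict(cse.as_dict())
assert revived.entry_id == "wbm-1-1"  # entry_id survives the round trip
assert revived.uncorrected_energy == revived.energy  # no corrections applied yet
```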
1 change: 1 addition & 0 deletions matbench_discovery/plots.py
@@ -106,6 +106,7 @@ def hist_classified_stable_vs_hull_dist(
     clf_col, value_name = "classified", "count"

     df_plot = pd.DataFrame()
+    each_true_pos = each_true_neg = each_false_neg = each_false_pos = None

     for facet, df_group in (
         df.groupby(kwargs["facet_col"]) if "facet_col" in kwargs else [(None, df)]
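The added `= None` line is the commit's fix for pyright's possibly-unbound diagnostic: variables first assigned inside a conditional branch trip the checker at later use sites. An illustrative standalone example of the pattern (names invented, not repo code):

```py
# Hedged illustration of the possibly-unbound fix pattern, not repo code.
def first_signs(values: list[float]) -> tuple[float | None, float | None]:
    first_pos = first_neg = None  # pre-bind, else pyright flags the return below
    for val in values:
        if val > 0 and first_pos is None:
            first_pos = val
        elif val < 0 and first_neg is None:
            first_neg = val
    return first_pos, first_neg  # both names definitely bound on every path
```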
2 changes: 1 addition & 1 deletion models/alignn_ff/readme.md
@@ -8,7 +8,7 @@ This effort was aborted for the following reasons:
 1. **Training difficulties**: ALIGNN-FF proved to be very resource-hungry. [12 GB of MPtrj training data](https://figshare.com/articles/dataset/23713842) turned into 600 GB of ALIGNN graph data. This forces small batch size even on nodes with large GPU memory, which slowed down training.
 1. **Ineffectiveness of fine-tuning**: Efforts to fine-tune the ALIGNN-FF WT10 model on the CHGNet data suffered high initial loss, even worse than the untrained model, indicating significant dataset incompatibility.

-The decision to abort adding ALIGNN FF to Matbench Discovery v1 was made after weeks of work due to ongoing technical challenges and resource limitations. See the [PR discussion](https://github.com/janosh/matbench-discovery/pull/47) for further details.
+The decision to abort testing ALIGNN FF was made after weeks of work due to ongoing technical challenges and resource limitations. See the [PR discussion](https://github.com/janosh/matbench-discovery/pull/47) for further details.

 ## Fine-tuning
13 changes: 4 additions & 9 deletions models/mace/readme.md
@@ -1,7 +1,7 @@
 ## MACE formation energy predictions on WBM test set

 The original MACE submission used the 2M parameter checkpoint [`2023-08-14-mace-yuan-trained-mptrj-04.model`](https://figshare.com/ndownloader/files/42374049) trained by Yuan Chiang on the [MPtrj dataset](https://figshare.com/articles/dataset/23713842).
-We initially tested the `2023-07-14-mace-universal-2-big-128-6.model` checkpoint trained on the much smaller [original M3GNet training set](https://figshare.com/articles/dataset/MPF_2021_2_8/19470599) which we received directly from Ilyes Batatia. MPtrj-trained MACE performed better and was used for the Matbench Discovery v1 submission.
+We initially tested the `2023-07-14-mace-universal-2-big-128-6.model` checkpoint trained on the much smaller [original M3GNet training set](https://figshare.com/articles/dataset/MPF_2021_2_8/19470599) which we received directly from Ilyes Batatia. MPtrj-trained MACE performed better and was used for the Matbench Discovery submission.

 In late October (received 2023-10-29), Philipp Benner trained a much larger 16M parameter MACE for over 100 epochs on MPtrj which achieved an (at the time SOTA) F1 score of 0.64 and DAF of 3.13.

@@ -21,12 +21,7 @@ MACE relaxed each test set structure until the maximum force in the training set

 #### Training

-- `loss="uip"`
-- `energy_weight=1`
-- `forces_weight=1`
-- `stress_weight=0.01`
-- `r_max=6.0`
-- `lr=0.005`
-- `batch_size=10`
+See the module doc string in `train_mace.py` for how to install MACE for multi-GPU training.
+A single-GPU training script that works with the current [MACE PyPI release](https://pypi.org/project/mace-torch) (v0.3.4 as of 2024-03-21) could be provided if there's interest.

-We used conditional loss weighting. We did _not_ use MACE's newest attention block feature which in our testing performed significantly worse than `RealAgnosticResidualInteractionBlock`.
+Our training used conditional loss weighting. We did _not_ use MACE's newest attention block feature which in our testing performed significantly worse than `RealAgnosticResidualInteractionBlock`.
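For readers wanting to reproduce the relaxation step the readme describes, here is a hedged sketch of driving a MACE checkpoint through ASE. The checkpoint path is a placeholder and the exact `MACECalculator` arguments may differ between mace-torch versions:

```py
# Hedged sketch, not the repo's test script; model path is a placeholder.
from ase.build import bulk
from ase.optimize import FIRE
from mace.calculators import MACECalculator

atoms = bulk("Cu", "fcc", a=3.6)
atoms.calc = MACECalculator(model_paths="path/to/mace.model", device="cpu")
FIRE(atoms).run(fmax=0.05, steps=500)  # relax until max force < 0.05 eV/Å
```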
35 changes: 23 additions & 12 deletions models/mace/train_mace.py
@@ -1,3 +1,17 @@
+"""
+This script requires installing the as-yet unmerged multi-GPU branch
+in the MACE repo.
+pip install git+https://github.com/ACEsuit/mace@multi-GPU
+Plan is to merge it into main and then release to PyPI. At that point,
+the install command will be:
+pip install mace-torch
+If you want to fine-tune an existing MACE checkpoint rather than train a
+model from scratch, install the foundations branch instead which has an interface
+just for that.
+pip install git+https://github.com/ACEsuit/mace@foundations
+"""
+
 from __future__ import annotations

 import ast
@@ -36,9 +50,6 @@
 __date__ = "2023-09-18"


-# This script requires installing MACE.
-# pip install git+https://github.com/ACEsuit/mace
-
 module_dir = os.path.dirname(__file__)

 slurm_vars = slurm_submit(
@@ -77,8 +88,8 @@ def main(**kwargs: Any) -> None:
     if args.distributed:
         try:
             distr_env = DistributedEnvironment()
-        except Exception as e:
-            print(f"Error specifying environment for distributed training: {e}")
+        except Exception as exc:
+            print(f"Error specifying environment for distributed training: {exc}")
             return
         world_size = distr_env.world_size
         local_rank = distr_env.local_rank
@@ -122,10 +133,10 @@ def main(**kwargs: Any) -> None:

     # Data preparation
     if args.train_file.endswith(".xyz"):
-        if args.valid_file is not None:
-            assert args.valid_file.endswith(
-                ".xyz"
-            ), "valid_file if given must be same format as train_file"
+        if args.valid_file is not None and not args.valid_file.endswith(".xyz"):
+            raise RuntimeError(
+                f"valid_file must be .xyz if train_file is .xyz, got {args.valid_file}"
+            )
         config_type_weights = get_config_type_weights(args.config_type_weights)
         collections, atomic_energies_dict = get_dataset_from_xyz(
             train_path=args.train_file,
@@ -150,7 +161,7 @@ def main(**kwargs: Any) -> None:
             f"{len(collections.valid)}, tests=[{test_config_lens}]"
         )
     elif args.train_file.endswith(".h5"):
-        atomic_energies_dict = None
+        atomic_energies_dict = collections = None
     else:
         raise RuntimeError(
             f"train_file must be either .xyz or .h5, got {args.train_file}"
@@ -485,8 +496,8 @@ def main(**kwargs: Any) -> None:
             f"{args.swa_forces_weight}, learning rate : {args.swa_lr}"
         )
         if args.loss == "forces_only":
-            print("Can not select swa with forces only loss.")
-        elif args.loss == "virials":
+            raise RuntimeError("Can not select SWA with forces-only loss.")
+        if args.loss == "virials":
             loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss(
                 energy_weight=args.swa_energy_weight,
                 forces_weight=args.swa_forces_weight,
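The `assert` to `raise RuntimeError` swap above matters because asserts are stripped when Python runs with `-O`, so validation that must always fire should raise. A standalone sketch of that check (function name invented, not from the repo):

```py
# Hedged sketch of the validation pattern; not a function in train_mace.py.
def check_file_formats(train_file: str, valid_file: str | None) -> None:
    """Raise (never assert) so the check survives `python -O`."""
    if train_file.endswith(".xyz") and valid_file is not None:
        if not valid_file.endswith(".xyz"):
            raise RuntimeError(
                f"valid_file must be .xyz if train_file is .xyz, got {valid_file}"
            )

check_file_formats("train.xyz", "valid.xyz")  # passes silently
```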
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -55,7 +55,7 @@ running-models = [
     # when attempting PyPI publish
     # "aviary@git+https://github.com/CompRhys/aviary",
     "alignn",
-    "chgnet",
+    "chgnet>=0.3.0",
     "jarvis-tools",
     "m3gnet",
     "mace-torch",
@@ -93,6 +93,7 @@ ignore = [
     "D205", # blank-line-after-summary
     "DTZ005",
     "E731", # lambda-assignment
+    "EM101",
     "EM102",
     "FBT001",
     "FBT002",
12 changes: 6 additions & 6 deletions scripts/update_wandb_runs.py
@@ -22,12 +22,12 @@


 # %%
-df = pd.DataFrame([run.config | dict(run.summary) for run in runs])
-df[["display_name", "id"]] = [(run.display_name, run.id) for run in runs]
+df_runs = pd.DataFrame([run.config | dict(run.summary) for run in runs])
+df_runs[["display_name", "id"]] = [(run.display_name, run.id) for run in runs]


 # %%
-df.isna().sum()
+df_runs.isna().sum()


 # %% --- Update run metadata ---
@@ -41,9 +41,9 @@
         "mace-wbm-IS2RE-debug-", "mace-wbm-IS2RE-"
     )

-    for x in (Task.IS2RE, "ES2RE"):
-        if x in run.display_name:
-            new_config["task_type"] = x
+    for key in (Task.IS2RE, Task.RS2RE):
+        if key in run.display_name:
+            new_config["task_type"] = key

     if "SLURM_JOB_ID" in new_config:
         new_config["slurm_job_id"] = new_config.pop("SLURM_JOB_ID")
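A hedged sketch of this script's overall shape: pull runs via the wandb public API, mutate a copy of each config, then persist it. The entity/project path is a placeholder:

```py
# Hedged sketch of the wandb metadata-update pattern; path is a placeholder.
import wandb

for run in wandb.Api().runs("my-entity/my-project"):
    new_config = dict(run.config)
    if "SLURM_JOB_ID" in new_config:  # normalize key casing
        new_config["slurm_job_id"] = new_config.pop("SLURM_JOB_ID")
    run.config.update(new_config)
    run.update()  # persist the modified config to the wandb backend
```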
4 changes: 2 additions & 2 deletions site/src/routes/preprint/references.yaml
@@ -1426,7 +1426,7 @@ references:
     URL: https://www.nature.com/articles/s41467-021-23339-x
     volume: '12'

-  - id: mok_direction-based_2022
+  - id: mok_directionbased_2022
     accessed:
       - year: 2022
         month: 10
@@ -1438,7 +1438,7 @@ references:
       given: Jongseung
     - family: Back
       given: Seoin
-    citation-key: mok_direction-based_2022
+    citation-key: mok_directionbased_2022
     DOI: 10.26434/chemrxiv-2022-dp58c
     genre: preprint
     issued:
