Skip to content

Commit

Permalink
Diatomics task page (#211)
Browse files Browse the repository at this point in the history
* add /tasks/diatomics page with d3-powered diatomic energy curves

- Add new diatomic task to modeling-tasks.yml
- Create DiatomicCurve Svelte component for interactive potential energy curves
- Add diatomic metrics support in model schema and Figshare article IDs
- /tasks/diatomics page has dynamic model selection and data loading
- a few (too few) unit tests for DiatomicCurve.svelte and /tasks/diatomics

* dynamic model selection on tasks/diatomics page

- error handling and logging in diatomic curve generation
- refactor diatomic data loading and caching in frontend
- model YAML files add diatomic pred_file paths and URLs

* fix site build from mismatched grace model name

* fix tests

* move diatomics page data fetching to server side

- simplify client-side logic with pre-fetched data
- add error state visualization to model selection UI
- add type DiatomicsCurves

* refactor diatomic/metrics.py

- new dataclasses DiatomicCurve(s) for auto-complete and type checking
- use full 3d force arrays in metrics/diatomics/force.py, not just x components at each atomic distance
- rename calc_energy_mae_vs_ref → calc_energy_mae and force_mae_vs_ref → calc_force_mae
- add new metrics like force_conservation and force_mae
- update tests to cover new metrics

* add unit tests for DiatomicCurve and DiatomicCurves classes

* add scripts/evals/diatomic_metrics.py and diatomics.write_metrics_to_yaml

- Rename `calc_diatomic_curve_metrics` to `calc_diatomic_metrics`
- Add `write_metrics_to_yaml` function to save diatomic metrics to model YAML
- Update `DiatomicCurves.from_dict` to filter out curves zero length
- unit tests for write_metrics_to_yaml
  • Loading branch information
janosh authored Feb 22, 2025
1 parent e1d57cc commit 26d3a19
Show file tree
Hide file tree
Showing 33 changed files with 1,505 additions and 418 deletions.
4 changes: 3 additions & 1 deletion matbench_discovery/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,9 @@ def update_yaml_at_path(
# Update the data at the final level
if last not in current:
current[last] = {}
# Replace the entire section to preserve comments
for key, val in current[last].items():
data.setdefault(key, val)
# Replace the entire current[last] section to preserve comments
current[last] = data

# Write back to file
Expand Down
19 changes: 12 additions & 7 deletions matbench_discovery/diatomics.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,25 @@ def calc_diatomic_curve(
elem1 = atom_num_symbol_map.get(z1, z1)
elem2 = atom_num_symbol_map.get(z2, z2)
formula = f"{elem1}-{elem2}"
ef_dict = results.setdefault(formula, {"energies": [], "forces": []})

# Initialize formula dict if not present
if formula not in results:
results[formula] = {"energies": [], "forces": []}
elif len(results[formula].get("energies", [])) == len(distances):
len_e, len_f = len(ef_dict.get("energies", [])), len(ef_dict.get("forces", []))
# skip if we have results for this formula and they match expected length
if len_e == len_f == len(distances):
continue

pbar.set_description(
f"{idx}/{len(pairs)} {formula} diatomic curve with {model_name}"
)

# reset ef_dict in case we had prior results
results[formula] |= {"energies": [], "forces": []}
for atoms in generate_diatomics(elem1, elem2, distances):
results[formula]["energies"] += [calculator.get_potential_energy(atoms)]
results[formula]["forces"] += [calculator.get_forces(atoms).tolist()]
try:
for atoms in generate_diatomics(elem1, elem2, distances):
results[formula]["energies"] += [calculator.get_potential_energy(atoms)]
results[formula]["forces"] += [calculator.get_forces(atoms).tolist()]
except Exception as exc:
print(f"{idx}/{len(pairs)} {formula} failed: {exc}")
continue

return results
11 changes: 6 additions & 5 deletions matbench_discovery/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@ class MbdKey(LabelEnum):
energy_grad_norm_max = "energy_grad_norm_max", "Energy Grad Norm Max (eV/Å)"
force_total_variation = "force_total_variation", "Force Total Variation (eV/Å)"
force_jump = "force_jump", "Force Jump (eV/Å)"
energy_mae_vs_ref = "energy_mae_vs_ref", "Energy MAE vs Reference (eV)"
force_mae_vs_ref = "force_mae_vs_ref", "Force MAE vs Reference (eV/Å)"
energy_mae = "energy_mae", "Energy MAE vs Reference (eV)"
force_mae = "force_mae", "Force MAE (eV/Å)"
force_conservation = "force_conservation", "Force Conservation (eV/Å)"


@unique
Expand Down Expand Up @@ -330,9 +331,9 @@ class Model(Files, base_dir=f"{ROOT}/models"):
eqv2_s_dens = auto(), "eqV2/eqV2-s-dens-mp.yml"
eqv2_m = auto(), "eqV2/eqV2-m-omat-salex-mp.yml"

grace_2l_mptrj = auto(), "grace/grace-2L-mptrj.yml"
grace_2l_oam = auto(), "grace/grace-2L-oam.yml"
grace_1l_oam = auto(), "grace/grace-1L-oam.yml"
grace_2l_mptrj = auto(), "grace/grace-2l-mptrj.yml"
grace_2l_oam = auto(), "grace/grace-2l-oam.yml"
grace_1l_oam = auto(), "grace/grace-1l-oam.yml"

# --- Model Combos
# # CHGNet-relaxed structures fed into MEGNet for formation energy prediction
Expand Down
258 changes: 176 additions & 82 deletions matbench_discovery/metrics/diatomics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,129 +6,223 @@
https://huggingface.co/spaces/atomind/mlip-arena, respectively.
"""

import inspect
from collections.abc import Callable, Mapping, Sequence
from typing import Any
from dataclasses import dataclass, field
from typing import Any, Self

# ruff: noqa: F401 (don't flag convenience imports above)
from matbench_discovery.enums import MbdKey
from matbench_discovery.metrics.diatomics import energy, force
import numpy as np

from matbench_discovery.data import update_yaml_at_path
from matbench_discovery.enums import MbdKey, Model
from matbench_discovery.metrics.diatomics import energy, force # noqa: F401
from matbench_discovery.metrics.diatomics.energy import (
calc_conservation_deviation,
calc_curve_diff_auc,
calc_energy_diff_flips,
calc_energy_grad_norm_max,
calc_energy_jump,
calc_energy_mae_vs_ref,
calc_energy_mae,
calc_second_deriv_smoothness,
calc_tortuosity,
)
from matbench_discovery.metrics.diatomics.force import (
calc_conservation_deviation,
calc_force_flips,
calc_force_jump,
calc_force_mae_vs_ref,
calc_force_mae,
calc_force_total_variation,
)

# Type alias for a curve represented as a tuple of x and y values
DiatomicCurve = tuple[Sequence[float], Sequence[float]]
# Type alias for a dictionary mapping element symbols to curves
DiatomicCurves = Mapping[str, DiatomicCurve]

@dataclass
class DiatomicCurve:
"""Energies and forces for a single diatomic molecule at multiple distances."""

distances: np.ndarray # shape (n_distances,)
energies: np.ndarray # shape (n_distances,)
forces: np.ndarray # shape (n_distances, n_atoms, 3)

def __post_init__(self) -> None:
"""Convert inputs to numpy arrays."""
self.energies = np.asarray(self.energies)
self.forces = np.asarray(self.forces)
self.distances = np.asarray(self.distances)


@dataclass
class DiatomicCurves:
"""Container for diatomic potential energy curves and forces of multiple
element pairs.
Attributes:
distances (np.ndarray): Interatomic distances in Å.
homo_nuclear (dict[str, DiatomicCurve]): Map of element pairs
(e.g. "H-H") to their DiatomicCurve (energies and forces).
hetero_nuclear (dict[str, DiatomicCurve] | None): Optional map of element pairs
(e.g. "H-He") to their DiatomicCurve.
"""

def calc_diatomic_curve_metrics(
ref_curves: DiatomicCurves,
distances: np.ndarray # shape (n_distances,)
homo_nuclear: dict[str, DiatomicCurve]
hetero_nuclear: dict[str, DiatomicCurve] = field(default_factory=dict)

@classmethod
def from_dict(cls, data: dict[str, Any]) -> Self:
"""Create DiatomicCurves from a dictionary loaded from JSON."""
dists = data["distances"] = np.asarray(data["distances"])
for key in {"homo-nuclear", "hetero-nuclear"} & set(data):
data[key.replace("-", "_")] = {
formula: DiatomicCurve(**dct, distances=dists)
for formula, dct in data.pop(key).items()
if len(dct["energies"]) > 0
}
return cls(**data)


def calc_diatomic_metrics(
ref_curves: DiatomicCurves | None,
pred_curves: DiatomicCurves,
pred_force_curves: DiatomicCurves | None = None,
metrics: dict[str, dict[str, Any]] | None = None,
) -> dict[str, dict[str, float]]:
"""Calculate diatomic curve metrics comparing predicted curves to reference curves.
Args:
ref_curves (DiatomicCurves): Reference energy curves for each element.
ref_curves (DiatomicCurves | None): Reference energy curves for each element.
If None, only metrics that don't require reference data will be calculated.
pred_curves (DiatomicCurves): Predicted energy curves for each element.
pred_force_curves (DiatomicCurves | None): Predicted force curves for each
element. Required for force-based metrics.
metrics (dict[str, dict[str, Any]] | None): Dictionary mapping metric names to
metrics (dict[str, dict[str, Any]] | None): Map of metric names to
dictionaries of keyword arguments for each metric function. If None, uses
all metrics with default parameters. To use a subset of metrics, provide
a dictionary with those metric names as keys and their keyword arguments
as values. Empty dictionaries will use default parameters.
Returns:
dict[str, dict[str, float]]: Dictionary mapping element symbols to metrics dict
with keys being the metric names and values being the metric values.
dict[str, dict[str, float]]: Map of element symbols to metric dicts with keys
being the metric names and values being the metric values.
"""
results: dict[str, dict[str, float]] = {}

# Map metric keys to their functions
metric_functions: dict[str, Callable[..., float]] = {
# Energy metrics that need both curves
MbdKey.norm_auc: energy.calc_curve_diff_auc,
MbdKey.energy_mae_vs_ref: energy.calc_energy_mae_vs_ref,
# Energy metrics that need only predicted curve
MbdKey.smoothness: energy.calc_second_deriv_smoothness,
MbdKey.tortuosity: energy.calc_tortuosity,
MbdKey.conservation: energy.calc_conservation_deviation,
MbdKey.energy_diff_flips: energy.calc_energy_diff_flips,
MbdKey.energy_grad_norm_max: energy.calc_energy_grad_norm_max,
MbdKey.energy_jump: energy.calc_energy_jump,
# Force metrics that need both curves
MbdKey.force_mae_vs_ref: force.calc_force_mae_vs_ref,
# Force metrics that need only predicted curve
MbdKey.force_flips: force.calc_force_flips,
MbdKey.force_total_variation: force.calc_force_total_variation,
MbdKey.force_jump: force.calc_force_jump,
}
if unknown_metrics := set(metrics or {}) - set(MbdKey):
raise ValueError(f"{unknown_metrics=}. Valid metrics=")

# If no metrics specified, use all metrics with default parameters
results: dict[str, dict[str, float]] = {}
metrics = (metrics or {}).copy()

if unknown_metrics := set(metrics) - set(metric_functions):
raise ValueError(
f"{unknown_metrics=}. Valid metrics={', '.join(metric_functions)}"
)

for key in metric_functions:
# Initialize empty kwargs for each metric if not provided
for key in MbdKey:
metrics.setdefault(key, {})

# Remove force-based metrics if no force curves provided
if pred_force_curves is None:
metrics = {
name: kwargs
for name, kwargs in metrics.items()
if not name.startswith("force_")
}
for elem_symbol, pred_data in pred_curves.homo_nuclear.items():
elem_metrics: dict[str, float] = {}
distances = pred_curves.distances

# Skip reference-requiring metrics if no reference curves provided
if ref_curves and (ref_data := ref_curves.homo_nuclear.get(elem_symbol)):
if not np.array_equal(distances, ref_curves.distances):
raise ValueError(
"Reference and predicted distances must be the same. If goal is "
"to interpolate predicted curves to reference distances, do so "
"before passing to calc_diatomic_metrics."
)

for elem_symbol, ref_curve in ref_curves.items():
if elem_symbol not in pred_curves:
continue
# Energy metrics that need both curves
if MbdKey.norm_auc in metrics:
elem_metrics[MbdKey.norm_auc] = calc_curve_diff_auc(
distances,
ref_data.energies,
distances,
pred_data.energies,
**metrics[MbdKey.norm_auc],
)

pred_curve = pred_curves[elem_symbol]
elem_metrics: dict[str, float] = {}
if MbdKey.energy_mae in metrics:
elem_metrics[MbdKey.energy_mae] = calc_energy_mae(
distances,
ref_data.energies,
distances,
pred_data.energies,
**metrics[MbdKey.energy_mae],
)

for name, func_kwargs in metrics.items():
metric_func = metric_functions[name]
param_set = set(inspect.signature(metric_func).parameters)
needs_ref_curve = any("_ref" in param for param in param_set)
is_force_metric = name.startswith("force_")

# Handle force metrics
if is_force_metric:
if pred_force_curves is None or elem_symbol not in pred_force_curves:
continue
curve_to_use = pred_force_curves[elem_symbol]
else: # energy metrics
curve_to_use = pred_curve

# Call metric function with appropriate arguments
if needs_ref_curve:
elem_metrics[name] = metric_func(
*ref_curve, *curve_to_use, **func_kwargs
if MbdKey.force_mae in metrics:
elem_metrics[MbdKey.force_mae] = calc_force_mae(
distances,
ref_data.forces,
distances,
pred_data.forces,
**metrics[MbdKey.force_mae],
)
else:
elem_metrics[name] = metric_func(*curve_to_use, **func_kwargs)

# Energy metrics that need only predicted curve
if MbdKey.smoothness in metrics:
elem_metrics[MbdKey.smoothness] = calc_second_deriv_smoothness(
distances, pred_data.energies, **metrics[MbdKey.smoothness]
)

if MbdKey.tortuosity in metrics:
elem_metrics[MbdKey.tortuosity] = calc_tortuosity(
distances, pred_data.energies, **metrics[MbdKey.tortuosity]
)

if MbdKey.energy_diff_flips in metrics:
elem_metrics[MbdKey.energy_diff_flips] = calc_energy_diff_flips(
distances, pred_data.energies, **metrics[MbdKey.energy_diff_flips]
)

if MbdKey.energy_grad_norm_max in metrics:
elem_metrics[MbdKey.energy_grad_norm_max] = calc_energy_grad_norm_max(
distances, pred_data.energies, **metrics[MbdKey.energy_grad_norm_max]
)

if MbdKey.energy_jump in metrics:
elem_metrics[MbdKey.energy_jump] = calc_energy_jump(
distances, pred_data.energies, **metrics[MbdKey.energy_jump]
)

# Force metrics that need only predicted curve
if MbdKey.conservation in metrics:
elem_metrics[MbdKey.conservation] = calc_conservation_deviation(
distances,
pred_data.energies,
pred_data.forces,
**metrics[MbdKey.conservation],
)

if MbdKey.force_flips in metrics:
elem_metrics[MbdKey.force_flips] = calc_force_flips(
distances, pred_data.forces, **metrics[MbdKey.force_flips]
)

if MbdKey.force_total_variation in metrics:
elem_metrics[MbdKey.force_total_variation] = calc_force_total_variation(
distances, pred_data.forces, **metrics[MbdKey.force_total_variation]
)

if MbdKey.force_jump in metrics:
elem_metrics[MbdKey.force_jump] = calc_force_jump(
distances, pred_data.forces, **metrics[MbdKey.force_jump]
)

results[elem_symbol] = elem_metrics

return results


def write_metrics_to_yaml(model: Model, metrics: dict[str, dict[str, float]]) -> None:
"""Write diatomic metrics to model YAML file.
Args:
model (Model): Model to write metrics for.
metrics (dict[str, dict[str, float]]): Map of element symbols to dicts of
metric values.
"""
if not metrics:
print(f"No valid metrics for {model.name}, skipping")
return

# Calculate mean metrics across all elements
mean_metrics = {
str(metric): float(
f"{np.mean([elem_metrics[metric] for elem_metrics in metrics.values()]):.4}"
)
for metric in next(iter(metrics.values()))
}

update_yaml_at_path(model.yaml_path, "metrics.diatomics", mean_metrics)
print(f"Wrote metrics to {model.yaml_path}")
Loading

0 comments on commit 26d3a19

Please sign in to comment.