Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

full ensemble fix, feature model fix #332

Merged
merged 14 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions apax/config/md_config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import os

# from types import UnionType
from typing import Literal, Union

import yaml
from pydantic import BaseModel, Field, NonNegativeInt, PositiveFloat, PositiveInt
from typing_extensions import Annotated

from apax.utils.helpers import APAX_PROPERTIES


class ConstantTempSchedule(BaseModel, extra="forbid"):
"""Constant temperature schedule.
Expand Down Expand Up @@ -234,6 +234,14 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
extra_capacity : int, default = 0
| JaxMD allocates a maximal number of neighbors. This argument lets you add
| additional capacity to avoid recompilation. The default is usually fine.

dynamics_checks: list[DynamicsCheck]
| List of termination criteria. Currently energy and force uncertainty
| are available
properties: list[str]
| Whitelist of properties to be saved in the trajectory.
| This does not effect what the model will calculate, e.g..
| an ensemble will still calculate uncertainties.
initial_structure : str, required
| Path to the starting structure of the simulation.
sim_dir : str, default = "."
Expand Down Expand Up @@ -266,6 +274,8 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):

dynamics_checks: list[DynamicsCheck] = []

properties: list[str] = APAX_PROPERTIES

initial_structure: str
load_momenta: bool = False
sim_dir: str = "."
Expand Down
10 changes: 10 additions & 0 deletions apax/md/ase_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,18 @@ def make_ensemble(model):
def ensemble(positions, Z, idx, box, offsets):
results = model(positions, Z, idx, box, offsets)
uncertainty = {k + "_uncertainty": jnp.std(v, axis=0) for k, v in results.items()}
ensemble = {k + "_ensemble": v for k, v in results.items()}
results = {k: jnp.mean(v, axis=0) for k, v in results.items()}
if "forces_ensemble" in ensemble.keys():
ensemble["forces_ensemble"] = jnp.transpose(
ensemble["forces_ensemble"], (1, 2, 0)
)
if "forces_ensemble" in ensemble.keys():
ensemble["stress_ensemble"] = jnp.transpose(
ensemble["forces_ensemble"], (1, 2, 0)
)
results.update(uncertainty)
results.update(ensemble)

return results

Expand Down
26 changes: 20 additions & 6 deletions apax/md/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,29 @@
from ase.calculators.singlepoint import SinglePointCalculator

from apax.md.sim_utils import System
from apax.utils.helpers import APAX_PROPERTIES
from apax.utils.jax_md_reduced import space

log = logging.getLogger(__name__)


class TrajHandler:
def __init__(self) -> None:
self.system: System
self.sampling_rate: int
self.buffer_size: int
self.traj_path: Path
self.time_step: float
def __init__(
self,
system: System,
sampling_rate: int,
buffer_size: int,
traj_path: Path,
time_step: float = 0.5,
properties: list[str] = APAX_PROPERTIES,
) -> None:
self.atomic_numbers = system.atomic_numbers
self.box = system.box
self.fractional = np.any(self.box > 1e-6)
self.sampling_rate = sampling_rate
self.traj_path = traj_path
self.time_step = time_step
self.properties = properties

def step(self, state_and_energy, transform=None):
pass
Expand Down Expand Up @@ -53,6 +64,7 @@ def atoms_from_state(self, state, predictions, nbr_kwargs):
atoms.pbc = np.diag(atoms.cell.array) > 1e-6
predictions = {k: np.array(v) for k, v in predictions.items()}
predictions["energy"] = predictions["energy"].item()
predictions = {k: v for k, v in predictions.items() if k in self.properties}
atoms.calc = SinglePointCalculator(atoms, **predictions)
return atoms

Expand All @@ -65,13 +77,15 @@ def __init__(
buffer_size: int,
traj_path: Path,
time_step: float = 0.5,
properties: list[str] = [],
) -> None:
self.atomic_numbers = system.atomic_numbers
self.box = system.box
self.fractional = np.any(self.box > 1e-6)
self.sampling_rate = sampling_rate
self.traj_path = traj_path
self.time_step = time_step
self.properties = properties
self.db = znh5md.IO(
self.traj_path, timestep=self.time_step, store="time", save_units=False
)
Expand Down
5 changes: 3 additions & 2 deletions apax/md/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,10 @@ def run_sim(
n_inner: int,
extra_capacity: int,
rng_key: int,
traj_handler: TrajHandler,
load_momenta: bool = False,
restart: bool = True,
checkpoint_interval: int = 50_000,
traj_handler: TrajHandler = TrajHandler(),
dynamics_checks: list[DynamicsCheckBase] = [],
disable_pbar: bool = False,
):
Expand Down Expand Up @@ -520,6 +520,7 @@ def run_md(model_config: Config, md_config: MDConfig, log_level="error"):
md_config.buffer_size,
traj_path,
md_config.ensemble.dt,
properties=md_config.properties,
)
# TODO implement correct chunking

Expand All @@ -531,10 +532,10 @@ def run_md(model_config: Config, md_config: MDConfig, log_level="error"):
n_inner=md_config.n_inner,
extra_capacity=md_config.extra_capacity,
load_momenta=md_config.load_momenta,
traj_handler=traj_handler,
rng_key=jax.random.PRNGKey(md_config.seed),
restart=md_config.restart,
checkpoint_interval=md_config.checkpoint_interval,
sim_dir=sim_dir,
traj_handler=traj_handler,
dynamics_checks=dynamics_checks,
)
10 changes: 7 additions & 3 deletions apax/nn/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ def __call__(
perturbation,
)

gm = self.descriptor(dr_vec, Z, idx)
features = jax.vmap(self.readout)(gm)
features = self.descriptor(dr_vec, Z, idx)

if self.readout:
features = jax.vmap(self.readout)(features)
PythonFZ marked this conversation as resolved.
Show resolved Hide resolved

if self.mask_atoms:
features = mask_by_atom(features, Z)
Expand Down Expand Up @@ -268,7 +270,9 @@ def __call__(

prediction["forces"] = forces_mean
prediction["forces_uncertainty"] = jnp.sqrt(forces_variance)
prediction["forces_ensemble"] = forces_ens

forces_ens = jnp.transpose(forces_ens, (1, 2, 0))
prediction["forces_ensemble"] = forces_ens # n_atoms x 3 x n_members

else:
forces_mean = -jax.grad(mean_energy_fn)(R, Z, neighbor, box, offsets)
Expand Down
22 changes: 13 additions & 9 deletions apax/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,26 @@

import yaml

APAX_PROPERTIES = [
"energy",
"forces",
"stress",
"forces_uncertainty",
"energy_uncertainty",
"stress_uncertainty",
"energy_ensemble",
"forces_ensemble",
"stress_ensemble",
]


def setup_ase():
"""Add uncertainty keys to ASE all properties.
from https://github.com/zincware/IPSuite/blob/main/ipsuite/utils/helpers.py#L10
"""
from ase.calculators.calculator import all_properties

additional_keys = [
"forces_uncertainty",
"energy_uncertainty",
"stress_uncertainty",
"energy_ensemble",
"forces_ensemble",
]

for val in additional_keys:
for val in APAX_PROPERTIES:
if val not in all_properties:
all_properties.append(val)
PythonFZ marked this conversation as resolved.
Show resolved Hide resolved

Expand Down
13 changes: 9 additions & 4 deletions tests/integration_tests/md/md_config_threshold.yaml
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
ensemble:
name: nvt
dt: 0.1 # fs time step
dt: 0.2 # fs time step
PythonFZ marked this conversation as resolved.
Show resolved Hide resolved
temperature_schedule:
name: piecewise
T0: 5 # K
T0: 50 # K
values: [100, 200, 1000]
steps: [10, 20, 30]

duration: 100 # fs
duration: 500 # fs
n_inner: 1
sampling_rate: 1
checkpoint_interval: 2
restart: True
dynamics_checks:
- name: forces_uncertainty
threshold: 1.0
threshold: 0.01
properties:
- energy
- forces
- energy_uncertainty
- forces_ensemble
12 changes: 11 additions & 1 deletion tests/integration_tests/md/test_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def test_jaxmd_schedule_and_thresold(get_tmp_path, example_dataset):
}
model_config_dict = load_config_and_run_training(model_confg_path, data_config_mods)

md_confg_path = TEST_PATH / "md_config.yaml"
md_confg_path = TEST_PATH / "md_config_threshold.yaml"

with open(md_confg_path.as_posix(), "r") as stream:
md_config_dict = yaml.safe_load(stream)
Expand All @@ -214,3 +214,13 @@ def test_jaxmd_schedule_and_thresold(get_tmp_path, example_dataset):

traj = znh5md.IO(md_config.sim_dir + "/" + md_config.traj_name)[:]
assert len(traj) < 1000 # num steps

results_keys = list(traj[0].calc.results.keys())

assert "energy" in results_keys
assert "forces" in results_keys
assert "energy_uncertainty" in results_keys
assert "forces_ensemble" in results_keys

assert "energy_ensemble" not in results_keys
assert "forces_uncertainty" not in results_keys
Loading