Merge pull request #297 from apax-hub/dev
v0.5.0
M-R-Schaefer authored Jul 24, 2024
2 parents 4793205 + fa17ee2 commit b268f30
Showing 59 changed files with 3,165 additions and 1,894 deletions.
5 changes: 0 additions & 5 deletions .flake8

This file was deleted.

26 changes: 8 additions & 18 deletions .pre-commit-config.yaml
@@ -10,22 +10,12 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace

- repo: https://github.com/psf/black
rev: 24.4.0
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.5.2
hooks:
- id: black
exclude: ^apax/utils/jax_md_reduced/

- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]
exclude: ^apax/utils/jax_md_reduced/

- repo: https://github.com/pycqa/flake8
rev: 7.0.0
hooks:
- id: flake8
additional_dependencies: [ flake8-isort ]
exclude: ^apax/utils/jax_md_reduced/
# Run the linter.
- id: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format
3 changes: 3 additions & 0 deletions apax/__init__.py
@@ -1,4 +1,5 @@
import os
import warnings

import jax

@@ -8,3 +9,5 @@
from apax.utils.helpers import setup_ase

setup_ase()

warnings.filterwarnings("ignore", message=".*os.fork()*")
2 changes: 1 addition & 1 deletion apax/bal/api.py
@@ -140,7 +140,7 @@ def kernel_selection(
n_train = len(train_atoms)
dataset = OTFInMemoryDataset(
train_atoms + pool_atoms,
cutoff=config.model.r_max,
cutoff=config.model.basis.r_max,
bs=processing_batch_size,
n_epochs=1,
ignore_labels=True,
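The one-line change above follows the new nested model configuration in this release: the radial basis settings, including the cutoff, now live under model.basis rather than directly on the model section (see the updated train_config_full.yaml template below). A minimal pydantic sketch of the nesting, using hypothetical class names rather than the actual apax config classes:

```python
from pydantic import BaseModel


# Hypothetical classes for illustration only; the real apax config models
# are not part of this diff.
class BasisConfig(BaseModel):
    name: str = "gaussian"
    n_basis: int = 7
    r_max: float = 6.0  # cutoff radius, formerly config.model.r_max
    r_min: float = 0.5


class ModelConfig(BaseModel):
    basis: BasisConfig = BasisConfig()


cutoff = ModelConfig().basis.r_max  # accessed as config.model.basis.r_max in apax
```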
79 changes: 70 additions & 9 deletions apax/bal/feature_maps.py
@@ -1,6 +1,7 @@
from typing import Callable, Literal, Tuple, Union

import jax
import jax.ad_checkpoint
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.traverse_util import flatten_dict, unflatten_dict
@@ -39,7 +40,7 @@ class LastLayerGradientFeatures(FeatureTransformation, extra="forbid"):
https://arxiv.org/pdf/2203.09410
"""

name: Literal["ll_grad"]
name: Literal["ll_grad"] = "ll_grad"
layer_name: str = "dense_2"

def apply(self, model: EnergyModel) -> FeatureMap:
@@ -58,19 +59,79 @@ def inner(ll_params):
inputs["box"],
inputs["offsets"],
)
return model.apply(full_params, R, Z, idx, box, offsets)
out = model.apply(full_params, R, Z, idx, box, offsets)
# take mean in case of shallow ensemble
# no effect for single model
out = jnp.mean(out)
return out

g_ll = jax.grad(inner)(ll_params)
g_ll = unflatten_dict(g_ll)

g_ll = jax.tree_map(lambda arr: jnp.mean(arr, axis=-1, keepdims=True), g_ll)
g_flat = jax.tree_map(lambda arr: jnp.reshape(arr, (-1,)), g_ll)
(gw, gb), _ = jax.tree_util.tree_flatten(g_flat)
(gb, gw), _ = jax.tree_util.tree_flatten(g_flat)

g = [gw, gb]
g = jnp.concatenate(g)

return g

return ll_grad


class LastLayerForceFeatures(FeatureTransformation, extra="forbid"):
"""
Model transformation which computes the gradient of the output
w.r.t. the specified layer.
"""

name: Literal["ll_force_feat"] = "ll_force_feat"
layer_name: str = "dense_2"
return_raw: bool = True

def apply(self, model: EnergyModel) -> FeatureMap:
def ll_grad(params, inputs):
ll_params, remaining_params = extract_feature_params(params, self.layer_name)

bias_factor = 0.1
weight_factor = jnp.sqrt(1 / gw.shape[-1])
g_scaled = [weight_factor * gw, bias_factor * gb]
energy_fn = lambda *inputs: jnp.mean(model.apply(*inputs))
force_fn = jax.grad(energy_fn, 1)

g = jnp.concatenate(g_scaled)
def inner(ll_params):
ll_params.update(remaining_params)
full_params = unflatten_dict(ll_params)

R, Z, idx, box, offsets = (
inputs["positions"],
inputs["numbers"],
inputs["idx"],
inputs["box"],
inputs["offsets"],
)
out = force_fn(full_params, R, Z, idx, box, offsets)
return out

ll_params = jax.tree_map(
lambda arr: jnp.mean(arr, axis=-1, keepdims=True), ll_params
)
g_ll = jax.jacfwd(inner)(ll_params)
g_ll = unflatten_dict(g_ll)

# shapes:
# b: n_atoms, 3, 1
# w: n_atoms, 3, n_features, 1

if self.return_raw:
(gb, gw), _ = jax.tree_util.tree_flatten(g_ll)

# g: n_atoms, 3, n_features
g = gw[:, :, :, 0]
else:
g_flat = jax.tree_map(
lambda arr: jnp.reshape(jnp.sum(jnp.sum(arr, 0), 0), (-1,)), g_ll
)
(gb, gw), _ = jax.tree_util.tree_flatten(g_flat)
g = [gw, gb]
g = jnp.concatenate(g)

return g

@@ -87,5 +148,5 @@ def apply(self, model: EnergyModel) -> FeatureMap:


FeatureMapOptions = TypeAdapter(
Union[LastLayerGradientFeatures, IdentityFeatures]
Union[LastLayerGradientFeatures, LastLayerForceFeatures, IdentityFeatures]
).validate_python
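Both gradient-based transformations are selected through the FeatureMapOptions validator above. A minimal usage sketch, assuming only the import path of this file (apax/bal/feature_maps.py) and a model/params pair supplied by the surrounding BAL pipeline:

```python
from apax.bal.feature_maps import FeatureMapOptions

# Validate a plain config dict into the new force-feature transformation.
feat_cfg = FeatureMapOptions({"name": "ll_force_feat", "layer_name": "dense_2"})

# feature_map = feat_cfg.apply(model)  # -> callable(params, inputs)
# g = feature_map(params, inputs)      # (n_atoms, 3, n_features) when return_raw=True,
#                                      # per the shape comments in the diff
```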
2 changes: 1 addition & 1 deletion apax/cli/apax_app.py
@@ -177,7 +177,7 @@ def visualize_model(
"Training configuration file to be visualized. A CO molecule is taken as"
" sample input."
),
)
),
):
"""
Visualize a model based on a configuration file.
39 changes: 25 additions & 14 deletions apax/cli/templates/train_config_full.yaml
@@ -1,9 +1,9 @@
n_epochs: <NUMBER OF EPOCHS>
seed: 1
patience: null
n_models: 1
n_jitted_steps: 1
data_parallel: True
weight_average: null

data:
directory: models/
@@ -15,8 +15,11 @@ data:
#train_data_path: <PATH>
#val_data_path: <PATH>
#test_data_path: <PATH>
dataset:
processing: cached
shuffle_buffer_size: 1000

additional_properties_info: {}
ds_type: cached

n_train: 1000
n_valid: 100
@@ -30,20 +33,27 @@ data:
scale_method: "per_element_force_rms_scale"
scale_options: {}

shuffle_buffer_size: 1000

pos_unit: Ang
energy_unit: eV

model:
n_basis: 7
basis:
name: gaussian
n_basis: 7
r_max: 6.0
r_min: 0.5

ensemble: null
# if you would like to train model ensembles, this can be achieved with
# the following example.
# ensemble:
# kind: full
# n_members: N

n_radial: 5
n_contr: -1
n_contr: 8
nn: [512, 512]

r_max: 6.0
r_min: 0.5

calc_stress: true
use_zbl: false

@@ -73,16 +83,17 @@ metrics:
- mse

optimizer:
opt_name: adam
opt_kwargs: {}
name: adam
kwargs: {}
emb_lr: 0.03
nn_lr: 0.03
scale_lr: 0.001
shift_lr: 0.05
zbl_lr: 0.001
transition_begin: 0
sam_rho: 0.0

schedule:
name: linear
transition_begin: 0
end_value: 1e-6
callbacks:
- name: csv

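For readers migrating existing configs: the optimizer keys are renamed (opt_name -> name, opt_kwargs -> kwargs) and the learning-rate decay settings move into a dedicated schedule block, matching the new schedule classes in the next file. A small before/after sketch as Python dicts; the YAML template above is the authoritative form:

```python
# Optimizer section before this release (keys taken from the removed template lines)
old_optimizer = {"opt_name": "adam", "opt_kwargs": {}, "transition_begin": 0}

# Equivalent section in v0.5.0
new_optimizer = {
    "name": "adam",
    "kwargs": {},
    "schedule": {"name": "linear", "transition_begin": 0, "end_value": 1e-6},
}
```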
45 changes: 45 additions & 0 deletions apax/config/lr_config.py
@@ -0,0 +1,45 @@
from typing import Literal

from pydantic import BaseModel, NonNegativeFloat


class LRSchedule(BaseModel, frozen=True, extra="forbid"):
name: str


class LinearLR(LRSchedule, frozen=True, extra="forbid"):
"""
Configuration of a linear learning rate schedule.
Learning rates of 0 will freeze the respective parameters.
Parameters
----------
transition_begin : int, default = 0
Number of steps after which to start decreasing the learning rate.
end_value : NonNegativeFloat, default = 1e-6
Final LR at the end of training.
"""

name: Literal["linear"]
transition_begin: int = 0
end_value: NonNegativeFloat = 1e-6


class CyclicCosineLR(LRSchedule, frozen=True, extra="forbid"):
"""
Configuration of a cyclic cosine learning rate schedule.
Learning rates of 0 will freeze the respective parameters.
Parameters
----------
period: int = 20
Length of a cycle in epochs.
decay_factor: NonNegativeFloat = 1.0
Factor by which to decrease the LR after each cycle.
1.0 means no decrease.
"""

name: Literal["cyclic_cosine"]
period: int = 20
decay_factor: NonNegativeFloat = 1.0
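These schedule configs feed the new schedule block of the optimizer section shown in the template above. A hedged sketch of how a LinearLR could map onto an optax schedule; the actual wiring inside apax's optimizer builder is not part of this diff, so the optax.linear_schedule call and the step counts are assumptions:

```python
import optax

from apax.config.lr_config import LinearLR

lin = LinearLR(name="linear", transition_begin=0, end_value=1e-6)

# Assumed mapping: decay a learning rate (e.g. nn_lr: 0.03 from the template)
# linearly to end_value over the course of training.
schedule = optax.linear_schedule(
    init_value=0.03,
    end_value=lin.end_value,
    transition_steps=100_000,  # total number of optimizer steps, assumed
    transition_begin=lin.transition_begin,
)
```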
4 changes: 4 additions & 0 deletions apax/config/md_config.py
@@ -50,9 +50,13 @@ class NVEOptions(Integrator, extra="forbid"):
----------
name : Literal["nve"]
Name of the ensemble.
init_temperature : PositiveFloat, default = 298.15
Initialisation temperature in Kelvin (K).
"""

name: Literal["nve"]
init_temperature: PositiveFloat = 298.15 # K


class NVTOptions(Integrator, extra="forbid"):
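The new init_temperature field gives the NVE integrator an explicit velocity-initialisation temperature. A minimal sketch, assuming the pydantic class shown above and that any Integrator fields not visible in this hunk keep their defaults:

```python
from apax.config.md_config import NVEOptions

# Initialise velocities at 300 K instead of the 298.15 K default.
nve = NVEOptions(name="nve", init_temperature=300.0)
```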