Commit

Merge pull request #264 from apax-hub/dev
Cumulative 0.4.0 Changes
M-R-Schaefer authored Apr 10, 2024
2 parents 97340bc + f43e2f4 commit 687ab99
Showing 82 changed files with 10,650 additions and 2,405 deletions.
32 changes: 0 additions & 32 deletions .github/workflows/documentation.yaml

This file was deleted.

2 changes: 2 additions & 0 deletions .gitignore
@@ -57,7 +57,9 @@ main.py
tmp/
.npz
.traj
.h5
events.out.*
*.schema.json

# Translations
*.mo
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
@@ -14,15 +14,18 @@ repos:
rev: 24.3.0
hooks:
- id: black
exclude: ^apax/utils/jax_md_reduced/

- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]
exclude: ^apax/utils/jax_md_reduced/

- repo: https://github.com/pycqa/flake8
rev: 7.0.0
hooks:
- id: flake8
additional_dependencies: [ flake8-isort ]
exclude: ^apax/utils/jax_md_reduced/
2 changes: 1 addition & 1 deletion .readthedocs.yaml
@@ -13,7 +13,7 @@ build:
post_install:
- pip install poetry
- poetry config virtualenvs.create false
- poetry install --with=docs
- VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH poetry install --with docs #poetry install --with=docs

# Build documentation in the docs/ directory with Sphinx
sphinx:
60 changes: 35 additions & 25 deletions README.md
@@ -12,40 +12,24 @@ It is based on [JAX](https://jax.readthedocs.io/en/latest/) and uses [JaxMD](htt

## Installation

You can install [Poetry](https://python-poetry.org/) via
Apax is available on PyPI with a CPU version of JAX.

```bash
curl -sSL https://install.python-poetry.org | python3 -
pip install apax
```

Now you can install apax in your project by running
For more detailed instructions, please refer to the [documentation](https://apax.readthedocs.io/en/latest/).

```bash
poetry add git+https://github.com/apax-hub/apax.git
```

As a developer, you can clone the repository and install it via

```bash
git clone https://github.com/apax-hub/apax.git <dest_dir>
cd <dest_dir>
poetry install
```

### CUDA Support
Note that the above only installs the CPU version.
If you want to enable GPU support, please overwrite the jaxlib version:
If you want to enable GPU support (only on Linux), please overwrite the jaxlib version:

```bash
pip install --upgrade pip
```

CUDA 12 installation. Wheels only available on linux.
CUDA 12:
```bash
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

CUDA 11 installation. Wheels only available on linux.
CUDA 11:
```bash
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
@@ -60,14 +44,14 @@ See the [Jax installation instructions](https://github.com/google/jax#installati

In order to train a model, you need to run

```python
```bash
apax train config.yaml
```

We offer some input file templates to get new users started as quickly as possible.
Simply run the following commands and add the appropriate entries in the marked fields.

```python
```bash
apax template train # use --full for a template with all input options
```

@@ -79,7 +79,7 @@ The documentation can conveniently be accessed by running `apax docs`.
There are two ways in which `apax` models can be used for molecular dynamics out of the box.
High performance NVT simulations using JaxMD can be started with the CLI by running

```python
```bash
apax md config.yaml md_config.yaml
```

@@ -88,6 +72,32 @@ A template command for MD input files is provided as well.
The second way is to use the ASE calculator provided in `apax.md`.
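
A minimal sketch of that route follows; the class name `ASECalculator`, its import path, and the model directory are assumptions not confirmed by this diff:

```python
from ase.io import read

from apax.md import ASECalculator  # assumed import path and class name

atoms = read("structure.xyz")                   # any ASE-readable structure
atoms.calc = ASECalculator("models/my_model")   # path to a trained apax model directory (placeholder)
print(atoms.get_potential_energy())             # energy predicted by the trained model
```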


## Input File Auto-Completion

Use the following command to generate JSON schemata for the training and MD input files:

```bash
apax schema
```

If you are using VSCode, you can use them to lint and autocomplete your input files by adding them to `.vscode/settings.json`:

```json
{
  "yaml.schemas": {
    "/absolute/path/to/apaxtrain.schema.json": [
      "train.yaml"
    ],
    "/absolute/path/to/apaxmd.schema.json": [
      "md.yaml"
    ]
  }
}
```


## Authors
- Moritz René Schäfer
- Nico Segreto
8 changes: 6 additions & 2 deletions apax/__init__.py
@@ -1,6 +1,10 @@
import os

from jax.config import config as jax_config
import jax

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
jax_config.update("jax_enable_x64", True)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
jax.config.update("jax_enable_x64", True)
from apax.utils.helpers import setup_ase

setup_ase()
80 changes: 68 additions & 12 deletions apax/bal/api.py
@@ -5,25 +5,24 @@
import numpy as np
from ase import Atoms
from click import Path
from flax.core.frozen_dict import FrozenDict
from tqdm import trange

from apax.bal import feature_maps, kernel, selection, transforms
from apax.data.initialization import RawDataset
from apax.data.input_pipeline import AtomisticDataset
from apax.data.input_pipeline import OTFInMemoryDataset
from apax.model.builder import ModelBuilder
from apax.model.gmnn import EnergyModel
from apax.train.checkpoints import (
canonicalize_energy_model_parameters,
check_for_ensemble,
restore_parameters,
)
from apax.train.run import initialize_dataset


def create_feature_fn(
model: EnergyModel,
params,
base_feature_map,
params: FrozenDict,
base_feature_map: feature_maps.FeatureTransformation,
feature_transforms=[],
is_ensemble: bool = False,
):
@@ -33,6 +32,19 @@ def create_feature_fn(
All transformations are applied on the feature function, not on computed features.
Only the final function is jit compiled.
Attributes
----------
model: EnergyModel
Model to be transformed.
params: FrozenDict
Model parameters
base_feature_map: FeatureTransformation
Class that transforms the model into a `FeatureMap`
is_ensemble: bool
Whether or not to apply the ensemble transformation, i.e.
an averaging of kernels for model ensembles.
"""
feature_fn = base_feature_map.apply(model)

@@ -48,14 +60,24 @@ def create_feature_fn(
return feature_fn


def compute_features(feature_fn, dataset: AtomisticDataset):
"""Compute the features of a dataset."""
def compute_features(
feature_fn: feature_maps.FeatureMap, dataset: OTFInMemoryDataset
) -> np.ndarray:
"""Compute the features of a dataset.
Attributes
----------
feature_fn:
Function to compute the features with.
dataset:
Dataset to compute the features for.
"""
features = []
n_data = dataset.n_data
ds = dataset.batch()

pbar = trange(n_data, desc="Computing features", ncols=100, leave=True)
for i, (inputs, _) in enumerate(ds):
for inputs in ds:
g = feature_fn(inputs)
features.append(np.asarray(g))
pbar.update(g.shape[0])
@@ -74,7 +96,36 @@ def kernel_selection(
feature_transforms: list = [],
selection_batch_size: int = 10,
processing_batch_size: int = 64,
):
) -> list[int]:
"""
Main function to facilitate batch data selection.
Currently only the last layer gradient features and MaxDist selection method
are available.
More can be added as needed as this function is agnostic of the feature
map/selection method internals.
Attributes
----------
model_dir: Union[Path, List[Path]]
Path to the trained model or models which should be used to compute features.
train_atoms: List[Atoms]
List of `ase.Atoms` used to train the models.
pool_atoms: List[Atoms]
List of `ase.Atoms` to select new data from.
base_fm_options:
Settings for the base feature map.
selection_method:
Currently only "max_dist" is supported.
feature_transforms:
Feature transforms to be applied on top of the
base feature map transform.
Examples include multiplication by or addition of a constant.
selection_batch_size:
Number of new data points to be selected from `pool_atoms`.
processing_batch_size:
Number of data points to compute the features for at once.
Does not affect results, only the speed of processing.
"""
selection_fn = {
"max_dist": selection.max_dist_selection,
}[selection_method]
@@ -87,10 +138,15 @@ def kernel_selection(
is_ensemble = n_models > 1

n_train = len(train_atoms)
dataset = initialize_dataset(
config, RawDataset(atoms_list=train_atoms + pool_atoms), calc_stats=False
dataset = OTFInMemoryDataset(
train_atoms + pool_atoms,
cutoff=config.model.r_max,
bs=processing_batch_size,
n_epochs=1,
ignore_labels=True,
pos_unit=config.data.pos_unit,
energy_unit=config.data.energy_unit,
)
dataset.set_batch_size(processing_batch_size)

_, init_box = dataset.init_input()

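
As a reading aid, here is a hedged usage sketch of the `kernel_selection` entry point defined above. The file names and model directory are placeholders, and passing a plain string as `model_dir` and a plain dict as `base_fm_options` are assumptions:

```python
from ase.io import read

from apax.bal.api import kernel_selection

train_atoms = read("train.extxyz", index=":")      # structures the model was trained on
pool_atoms = read("candidates.extxyz", index=":")  # structures to select new data from

selected = kernel_selection(
    "models/my_model",                             # trained model directory (placeholder)
    train_atoms,
    pool_atoms,
    base_fm_options={"name": "ll_grad", "layer_name": "dense_2"},
    selection_method="max_dist",
    selection_batch_size=10,    # number of new structures to pick
    processing_batch_size=64,   # only affects speed, not the result
)
```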
20 changes: 15 additions & 5 deletions apax/bal/feature_maps.py
@@ -1,10 +1,20 @@
from typing import Literal, Tuple, Union
from typing import Callable, Literal, Tuple, Union

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.traverse_util import flatten_dict, unflatten_dict
from pydantic import BaseModel, TypeAdapter

from apax.model.gmnn import EnergyModel

FeatureMap = Callable[[FrozenDict, dict], jax.Array]


class FeatureTransformation(BaseModel):
def apply(self, model: EnergyModel) -> FeatureMap:
return model


def extract_feature_params(params: dict, layer_name: str) -> Tuple[dict, dict]:
"""Separate params into those belonging to a selected layer
@@ -22,7 +32,7 @@ def extract_feature_params(params: dict, layer_name: str) -> Tuple[dict, dict]:
return feature_layer_params, remaining_params


class LastLayerGradientFeatures(BaseModel, extra="forbid"):
class LastLayerGradientFeatures(FeatureTransformation, extra="forbid"):
"""
Model transformation which computes the gradient of the output
w.r.t. the specified layer.
@@ -32,7 +42,7 @@ class LastLayerGradientFeatures(BaseModel, extra="forbid"):
name: Literal["ll_grad"]
layer_name: str = "dense_2"

def apply(self, model):
def apply(self, model: EnergyModel) -> FeatureMap:
def ll_grad(params, inputs):
ll_params, remaining_params = extract_feature_params(params, self.layer_name)

@@ -67,12 +77,12 @@ def inner(ll_params):
return ll_grad


class IdentityFeatures(BaseModel, extra="forbid"):
class IdentityFeatures(FeatureTransformation, extra="forbid"):
"""Identity feature map. For debugging purposes"""

name: Literal["identity"]

def apply(self, model):
def apply(self, model: EnergyModel) -> FeatureMap:
return model.apply


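
For orientation, a minimal sketch of how the feature-map classes defined in this file might be instantiated; the constructor usage is inferred from the pydantic fields shown above, and how apax builds these objects from `base_fm_options` internally is an assumption:

```python
from apax.bal.feature_maps import IdentityFeatures, LastLayerGradientFeatures

# Last-layer gradient features, matching the defaults shown above
ll_features = LastLayerGradientFeatures(name="ll_grad", layer_name="dense_2")

# Identity features, intended for debugging
debug_features = IdentityFeatures(name="identity")

# Both expose .apply(model), which returns a FeatureMap,
# i.e. a Callable[[FrozenDict, dict], jax.Array].
```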
7 changes: 7 additions & 0 deletions apax/bal/selection.py
@@ -11,6 +11,13 @@ def max_dist_selection(matrix: KernelMatrix, batch_size: int):
https://arxiv.org/pdf/2203.09410.pdf
https://doi.org/10.1039/D2DD00034B
Attributes
----------
matrix: KernelMatrix
Kernel used to compare structures.
batch_size: int
Number of new data points to be selected.
"""
n_train = matrix.n_train
