Cache dataset #248

Merged · 14 commits · Apr 3, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -59,6 +59,7 @@ tmp/
.traj
.h5
events.out.*
*.schema.json

# Translations
*.mo
32 changes: 29 additions & 3 deletions README.md
@@ -60,14 +60,14 @@ See the [Jax installation instructions](https://github.com/google/jax#installati

In order to train a model, you need to run

```python
```bash
apax train config.yaml
```

We offer some input file templates to get new users started as quickly as possible.
Simply run the following commands and add the appropriate entries in the marked fields:

```python
```bash
apax template train # use --full for a template with all input options
```

@@ -79,7 +79,7 @@ The documentation can conveniently be accessed by running `apax docs`.
There are two ways in which `apax` models can be used for molecular dynamics out of the box.
High performance NVT simulations using JaxMD can be started with the CLI by running

```python
```bash
apax md config.yaml md_config.yaml
```

@@ -88,6 +88,32 @@ A template command for MD input files is provided as well.
The second way is to use the ASE calculator provided in `apax.md`.


## Input File Auto-Completion

Use the following command to generate JSON schemata for the training and MD input files:

```bash
apax schema
```

If you are using VSCode, you can use these schemata to lint and autocomplete your input files by registering them in `.vscode/settings.json`:

```json
{
    "yaml.schemas": {
        "/absolute/path/to/apaxtrain.schema.json": ["train.yaml"],
        "/absolute/path/to/apaxmd.schema.json": ["md.yaml"]
    }
}
```
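
Note that the `yaml.schemas` setting is read by the YAML extension for VSCode (Red Hat's `vscode-yaml`), which must be installed for linting and completion to work.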


## Authors
- Moritz René Schäfer
- Nico Segreto
1 change: 1 addition & 0 deletions apax/__init__.py
@@ -3,6 +3,7 @@
import jax

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
jax.config.update("jax_enable_x64", True)
from apax.utils.helpers import setup_ase

6 changes: 3 additions & 3 deletions apax/bal/api.py
@@ -8,7 +8,7 @@
from tqdm import trange

from apax.bal import feature_maps, kernel, selection, transforms
from apax.data.input_pipeline import InMemoryDataset
from apax.data.input_pipeline import OTFInMemoryDataset
from apax.model.builder import ModelBuilder
from apax.model.gmnn import EnergyModel
from apax.train.checkpoints import (
@@ -46,7 +46,7 @@ def create_feature_fn(
return feature_fn


def compute_features(feature_fn, dataset: InMemoryDataset):
def compute_features(feature_fn, dataset: OTFInMemoryDataset):
"""Compute the features of a dataset."""
features = []
n_data = dataset.n_data
@@ -85,7 +85,7 @@ def kernel_selection(
is_ensemble = n_models > 1

n_train = len(train_atoms)
dataset = InMemoryDataset(
dataset = OTFInMemoryDataset(
train_atoms + pool_atoms,
cutoff=config.model.r_max,
bs=processing_batch_size,
18 changes: 18 additions & 0 deletions apax/cli/apax_app.py
@@ -1,5 +1,6 @@
import importlib.metadata
import importlib.resources as pkg_resources
import json
import sys
from pathlib import Path

@@ -93,6 +94,23 @@ def docs():
typer.launch("https://apax.readthedocs.io/en/latest/")


@app.command()
def schema():
"""
Generate JSON schemata for autocompletion of train/md inputs in VSCode.
"""
console.print("Generating JSON schema")
from apax.config import Config, MDConfig

train_schema = Config.model_json_schema()
md_schema = MDConfig.model_json_schema()
with open("./apaxtrain.schema.json", "w") as f:
f.write(json.dumps(train_schema, indent=2))

with open("./apaxmd.schema.json", "w") as f:
f.write(json.dumps(md_schema, indent=2))


@validate_app.command("train")
def validate_train_config(
config_path: Path = typer.Argument(
5 changes: 4 additions & 1 deletion apax/config/train_config.py
@@ -50,6 +50,7 @@ class DataConfig(BaseModel, extra="forbid"):

directory: str
experiment: str
ds_type: Literal["cached", "otf"] = "cached"
data_path: Optional[str] = None
train_data_path: Optional[str] = None
val_data_path: Optional[str] = None
@@ -228,8 +229,10 @@ class LossConfig(BaseModel, extra="forbid"):
"""

name: str
loss_type: str = "structures"
loss_type: str = "mse"
weight: NonNegativeFloat = 1.0
atoms_exponent: NonNegativeFloat = 1
parameters: dict = {}


class CallbackConfig(BaseModel, frozen=True, extra="forbid"):
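
For reference, a minimal sketch of how the new `ds_type` field behaves, reconstructed from the hunk above. Only the fields visible in the diff are included (the real `DataConfig` has more), and the example values are made up:

```python
from typing import Literal, Optional

from pydantic import BaseModel


class DataConfig(BaseModel, extra="forbid"):
    directory: str
    experiment: str
    # New in this PR: selects the dataset implementation; caching is the default.
    ds_type: Literal["cached", "otf"] = "cached"
    data_path: Optional[str] = None


cfg = DataConfig(directory="models", experiment="run1")
assert cfg.ds_type == "cached"  # caching is now opt-out
cfg = DataConfig(directory="models", experiment="run1", ds_type="otf")  # on-the-fly
```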
108 changes: 93 additions & 15 deletions apax/data/input_pipeline.py
@@ -1,5 +1,7 @@
import logging
import uuid
from collections import deque
from pathlib import Path
from random import shuffle
from typing import Dict, Iterator

@@ -44,6 +46,7 @@ def __init__(
n_jit_steps=1,
pre_shuffle=False,
ignore_labels=False,
cache_path=".",
[Review thread on this line]
Contributor: Cache path maybe better in experiment dir?
Contributor (Author): It is set to the experiment dir in run.py. This is just the default, but we can remove that if you prefer.
Contributor: No, it's fine.
) -> None:
if pre_shuffle:
shuffle(atoms)
@@ -68,6 +71,7 @@ def __init__(
self.buffer = deque()
self.batch_size = self.validate_batch_size(bs)
self.n_jit_steps = n_jit_steps
self.file = Path(cache_path) / str(uuid.uuid4())

self.enqueue(min(self.buffer_size, self.n_data))

@@ -85,7 +89,6 @@ def validate_batch_size(self, batch_size: int) -> int:
f"requested batch size {batch_size} is larger than the number of data"
f" points {self.n_data}. Setting batch size = {self.n_data}"
)
print("Warning: " + msg)
log.warning(msg)
batch_size = self.n_data
return batch_size
@@ -125,20 +128,6 @@ def enqueue(self, num_elements):
self.buffer.append(data)
self.count += 1

def __iter__(self):
epoch = 0
while epoch < self.n_epochs or len(self.buffer) > 0:
yield self.buffer.popleft()

space = self.buffer_size - len(self.buffer)
if self.count + space > self.n_data:
space = self.n_data - self.count

if self.count >= self.n_data and epoch < self.n_epochs:
epoch += 1
self.count = 0
self.enqueue(space)

def make_signature(self) -> tf.TensorSpec:
input_signature = {}
input_signature["n_atoms"] = tf.TensorSpec((), dtype=tf.int16, name="n_atoms")
@@ -189,6 +178,89 @@ def init_input(self) -> Dict[str, np.ndarray]:
inputs = jax.tree_map(lambda x: jnp.array(x), inputs)
return inputs, np.array(box)

def __iter__(self):
raise NotImplementedError

def shuffle_and_batch(self):
raise NotImplementedError

def batch(self) -> Iterator[jax.Array]:
raise NotImplementedError

def cleanup(self):
pass


class CachedInMemoryDataset(InMemoryDataset):
def __iter__(self):
while self.count < self.n_data or len(self.buffer) > 0:
yield self.buffer.popleft()

space = self.buffer_size - len(self.buffer)
if self.count + space > self.n_data:
space = self.n_data - self.count
self.enqueue(space)

def shuffle_and_batch(self):
"""Shuffles and batches the inputs/labels. This function prepares the
inputs and labels for the whole training and prefetches the data.

Returns
-------
ds :
Iterator that returns inputs and labels of one batch in each step.
"""
ds = (
tf.data.Dataset.from_generator(
lambda: self, output_signature=self.make_signature()
)
.cache(self.file.as_posix())
.repeat(self.n_epochs)
)
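# cache() before repeat() means the Python generator is exhausted only once;
# every subsequent epoch is served from the on-disk cache at self.file.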

ds = ds.shuffle(
buffer_size=self.buffer_size, reshuffle_each_iteration=True
).batch(batch_size=self.batch_size)
if self.n_jit_steps > 1:
ds = ds.batch(batch_size=self.n_jit_steps)
ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
return ds

def batch(self) -> Iterator[jax.Array]:
ds = (
tf.data.Dataset.from_generator(
lambda: self, output_signature=self.make_signature()
)
.cache(self.file.as_posix())
.repeat(self.n_epochs)
)
ds = ds.batch(batch_size=self.batch_size)
ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
return ds

def cleanup(self):
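# tf.data's file cache writes <name>.data-* shard files plus a <name>.index
# file next to self.file; remove both so no cache artifacts are left behind.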
for p in self.file.parent.glob(f"{self.file.name}.data*"):
p.unlink()

index_file = self.file.parent / f"{self.file.name}.index"
index_file.unlink()


class OTFInMemoryDataset(InMemoryDataset):
def __iter__(self):
epoch = 0
while epoch < self.n_epochs or len(self.buffer) > 0:
yield self.buffer.popleft()

space = self.buffer_size - len(self.buffer)
if self.count + space > self.n_data:
space = self.n_data - self.count

if self.count >= self.n_data and epoch < self.n_epochs:
epoch += 1
self.count = 0
self.enqueue(space)

def shuffle_and_batch(self):
"""Shuffles and batches the inputs/labels. This function prepares the
inputs and labels for the whole training and prefetches the data.
@@ -217,3 +289,9 @@ def batch(self) -> Iterator[jax.Array]:
ds = ds.batch(batch_size=self.batch_size)
ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
return ds


dataset_dict = {
"cached": CachedInMemoryDataset,
"otf": OTFInMemoryDataset,
}
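
A hypothetical usage sketch of the new registry (not part of this diff). The positional arguments follow the callers shown in this PR, `n_epochs` is inferred from the attribute used in `repeat()`, and all concrete values are made up:

```python
from ase import Atoms

from apax.data.input_pipeline import dataset_dict

# Toy data; real callers pass structures read from an ase-readable file.
atoms = [Atoms("H2", positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 0.74]]) for _ in range(8)]

DatasetClass = dataset_dict["cached"]  # or dataset_dict[config.data.ds_type]
ds = DatasetClass(
    atoms,
    cutoff=5.0,  # callers above pass config.model.r_max here
    bs=4,
    n_epochs=2,
    cache_path="models/run1",  # hypothetical experiment directory
)
try:
    for batch in ds.shuffle_and_batch():
        ...  # one training step per batch
finally:
    ds.cleanup()  # CachedInMemoryDataset deletes its tf.data cache files
```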
4 changes: 2 additions & 2 deletions apax/md/ase_calc.py
@@ -11,7 +11,7 @@
from matscipy.neighbours import neighbour_list
from tqdm import trange

from apax.data.input_pipeline import InMemoryDataset
from apax.data.input_pipeline import OTFInMemoryDataset
from apax.model import ModelBuilder
from apax.train.checkpoints import check_for_ensemble, restore_parameters
from apax.utils.jax_md_reduced import partition, quantity, space
@@ -256,7 +256,7 @@ def batch_eval(
"""
if self.model is None:
self.initialize(atoms_list[0])
dataset = InMemoryDataset(
dataset = OTFInMemoryDataset(
atoms_list,
self.model_config.model.r_max,
batch_size,
4 changes: 2 additions & 2 deletions apax/train/eval.py
@@ -8,7 +8,7 @@
from tqdm import trange

from apax.config import parse_config
from apax.data.input_pipeline import InMemoryDataset
from apax.data.input_pipeline import OTFInMemoryDataset
from apax.model import ModelBuilder
from apax.train.callbacks import initialize_callbacks
from apax.train.checkpoints import restore_single_parameters
@@ -122,7 +122,7 @@ def eval_model(config_path, n_test=-1, log_file="eval.log", log_level="error"):
Metrics = initialize_metrics(config.metrics)

atoms_list = load_test_data(config, model_version_path, eval_path, n_test)
test_ds = InMemoryDataset(
test_ds = OTFInMemoryDataset(
atoms_list, config.model.r_max, config.data.valid_batch_size
)
