Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PBP dataset performance enhancements #361

Merged
merged 30 commits into from
Nov 25, 2024
Merged
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
9e41172
reworked pbp datset to use a side thread
M-R-Schaefer Nov 3, 2024
70d8fd8
reduced memory allocations in BatchProcessor
M-R-Schaefer Nov 3, 2024
31e3ccf
switched to vesin
M-R-Schaefer Nov 3, 2024
2bd520f
test
M-R-Schaefer Nov 4, 2024
16169a6
fixed thread not ending after training is done
M-R-Schaefer Nov 4, 2024
1b9dc9a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 4, 2024
ed42a70
switched NL computation in preprocessing and ASE to vesin
M-R-Schaefer Nov 17, 2024
924de56
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2024
b5e9cb2
remoed barrier wait
M-R-Schaefer Nov 17, 2024
f576771
Merge branch 'threading' of https://github.com/apax-hub/apax into thr…
M-R-Schaefer Nov 17, 2024
71fb826
Merge branch 'main' into threading
M-R-Schaefer Nov 17, 2024
bdf809a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2024
b1071fb
added vesin
M-R-Schaefer Nov 17, 2024
6bc8b73
fixed ase clac test
M-R-Schaefer Nov 17, 2024
7617acf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2024
02acad6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2024
6315eaa
updated config
M-R-Schaefer Nov 17, 2024
77cbc4d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2024
91bb8a1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2024
42ee7f8
Merge branch 'main' into threading
M-R-Schaefer Nov 19, 2024
7caad5d
removed matscipy dependency, poetry update
M-R-Schaefer Nov 19, 2024
8510b05
Merge branch 'threading' of https://github.com/apax-hub/apax into thr…
M-R-Schaefer Nov 19, 2024
5e20e7b
fixed imports, docstring, NL padding
M-R-Schaefer Nov 19, 2024
0c720b7
updated ase calc docstring to vesin
M-R-Schaefer Nov 19, 2024
ecbb4f0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2024
1519188
added stress to batch processor
M-R-Schaefer Nov 20, 2024
010dc91
Merge branch 'threading' of https://github.com/apax-hub/apax into thr…
M-R-Schaefer Nov 20, 2024
00388f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2024
ca5fbd7
Merge branch 'main' into threading
PythonFZ Nov 22, 2024
d5747ce
Merge branch 'main' into threading
M-R-Schaefer Nov 25, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions apax/config/train_config.py
Original file line number Diff line number Diff line change
@@ -76,14 +76,16 @@ class PBPDatset(DatasetConfig, extra="forbid"):
----------
num_workers : int
| Number of batches to be processed in parallel.
reset_every : int
| Number of epochs before reinitializing the ProcessPoolExcecutor.
| Avoids memory leaks.
atom_padding : int
| Next nearest integer to which to pad per-atom arrays (positions, forces, ...).
nl_padding: int
| Next nearest integer to which to pad neighborlists.
"""

processing: Literal["pbp"] = "pbp"
num_workers: PositiveInt = 10
reset_every: PositiveInt = 10
atom_padding: PositiveInt = 10
nl_padding: PositiveInt = 2000


class DataConfig(BaseModel, extra="forbid"):
220 changes: 132 additions & 88 deletions apax/data/input_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import logging
import multiprocessing
import time
import uuid
from collections import deque
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from pathlib import Path
from queue import Queue
from random import shuffle
from threading import Event
from typing import Dict, Iterator, Optional

import jax
@@ -358,87 +361,101 @@ def next_power_of_two(x):
return 1 << (int(x) - 1).bit_length()


def round_up_to_multiple(value, multiple):
"""
Rounds up the given integer `value` to the next multiple of `multiple`.
Parameters:
- value (int): The integer to round up.
- multiple (int): The multiple to round up to.
Returns:
- int: The rounded-up value.
"""
return int(np.ceil(value / multiple) * multiple)


class BatchProcessor:
def __init__(self, cutoff, forces=True, stress=False) -> None:
def __init__(
self, cutoff, atom_padding: int, nl_padding: int, forces=True, stress=False
) -> None:
self.cutoff = cutoff
self.atom_padding = atom_padding
self.nl_padding = nl_padding

self.forces = forces
self.stress = stress

def __call__(self, samples: list[dict]):
n_samples = len(samples)
max_atoms = np.max([inp[0]["n_atoms"] for inp in samples])
max_atoms = round_up_to_multiple(max_atoms, self.atom_padding)

inputs = {
"numbers": [],
"n_atoms": [],
"positions": [],
"box": [],
"idx": [],
"offsets": [],
"numbers": np.zeros((n_samples, max_atoms), dtype=np.int16),
"n_atoms": np.zeros(n_samples, dtype=np.int16),
"positions": np.zeros((n_samples, max_atoms, 3), dtype=np.float64),
"box": np.zeros((n_samples, 3, 3), dtype=np.float32),
}

labels = {
"energy": [],
"energy": np.zeros(n_samples, dtype=np.float64),
}

if self.forces:
labels["forces"] = []
labels["forces"] = np.zeros((n_samples, max_atoms, 3), dtype=np.float64)
if self.stress:
labels["stress"] = []
labels["stress"] = np.zeros((n_samples, 3, 3), dtype=np.float64)

for sample in samples:
inp, lab = sample
idxs = []
offsets = []
for i, (inp, lab) in enumerate(samples):
inputs["numbers"][i, : inp["n_atoms"]] = inp["numbers"]
inputs["n_atoms"][i] = inp["n_atoms"]
inputs["positions"][i, : inp["n_atoms"]] = inp["positions"]
inputs["box"][i] = inp["box"]

inputs["numbers"].append(inp["numbers"])
inputs["n_atoms"].append(inp["n_atoms"])
inputs["positions"].append(inp["positions"])
inputs["box"].append(inp["box"])
idx, offsets = compute_nl(inp["positions"], inp["box"], self.cutoff)
inputs["idx"].append(idx)
inputs["offsets"].append(offsets)
idx, offset = compute_nl(inp["positions"], inp["box"], self.cutoff)
idxs.append(idx)
offsets.append(offset)

labels["energy"].append(lab["energy"])
labels["energy"][i] = lab["energy"]
if self.forces:
labels["forces"].append(lab["forces"])
labels["forces"][i, : inp["n_atoms"]] = lab["forces"]
if self.stress:
labels["stress"].append(lab["stress"])

max_atoms = np.max(inputs["n_atoms"])
max_nbrs = np.max([idx.shape[1] for idx in inputs["idx"]])
labels["stress"][i] = lab["stress"]

max_atoms = next_power_of_two(max_atoms)
max_nbrs = next_power_of_two(max_nbrs)
max_nbrs = np.max([idx.shape[1] for idx in idxs])
max_nbrs = round_up_to_multiple(max_nbrs, self.nl_padding)

for i in range(len(inputs["n_atoms"])):
inputs["idx"][i], inputs["offsets"][i] = pad_nl(
inputs["idx"][i], inputs["offsets"][i], max_nbrs
)

zeros_to_add = max_atoms - inputs["numbers"][i].shape[0]
inputs["positions"][i] = np.pad(
inputs["positions"][i], ((0, zeros_to_add), (0, 0)), "constant"
)
inputs["numbers"][i] = np.pad(
inputs["numbers"][i], (0, zeros_to_add), "constant"
).astype(np.int16)
inputs["idx"] = np.zeros((n_samples, 2, max_nbrs), dtype=np.int16)
inputs["offsets"] = np.zeros((n_samples, max_nbrs, 3), dtype=np.float64)

if "forces" in labels:
labels["forces"][i] = np.pad(
labels["forces"][i], ((0, zeros_to_add), (0, 0)), "constant"
)
for i, (idx, offset) in enumerate(zip(idxs, offsets)):
inputs["idx"][i, :, : idx.shape[1]] = idx
inputs["offsets"][i, : offset.shape[0], :] = offset

inputs = {k: np.array(v) for k, v in inputs.items()}
labels = {k: np.array(v) for k, v in labels.items()}
return inputs, labels


class PerBatchPaddedDataset(InMemoryDataset):
"""Dataset which pads everything (atoms, neighbors)
to the next larges power of two.
This limits the compute wasted due to padding at the (negligible)
cost of some recompilations.
The NL is computed on-the-fly in parallel for `num_workers` of batches.
"""Dataset with padding that leverages multiprocessing and optimized buffering.
Per-atom and per-neighbor arrays are padded to the next multiple of a user specified integer.
This limits the compute wasted due to padding at the (negligible) cost of some recompilations.
Since the padding occurs on a per-batch basis, it is the most performant option for datasets with significantly differently sized systems (e.g. MaterialsProject, SPICE).
Further, the neighborlist is computed on-the-fly in parallel on a side thread.
Does not use tf.data.
Most performant option for datasets with significantly differently sized systems
(e.g. MaterialsProject, SPICE).
Attributes
----------
num_workers : int
Number of processes to use for preprocessing batches.
atom_padding : int
Pad extensive arrays (positions, etc.) to next multiple of this integer.
nl_padding : int
Pad neighborlist arrays to next multiple of this integer.
"""

def __init__(
@@ -449,18 +466,13 @@ def __init__(
n_epochs,
n_jit_steps=1,
num_workers: Optional[int] = None,
reset_every: int = 10,
atom_padding: int = 10,
nl_padding: int = 2000,
pos_unit: str = "Ang",
energy_unit: str = "eV",
pre_shuffle=False,
) -> None:
self.cutoff = cutoff

if n_jit_steps > 1:
raise NotImplementedError(
"PerBatchPaddedDataset is not yet compatible with multi step jit"
)

self.n_jit_steps = n_jit_steps
self.n_epochs = n_epochs
self.n_data = len(atoms_list)
@@ -470,13 +482,12 @@ def __init__(
if num_workers:
self.num_workers = num_workers
else:
self.num_workers = multiprocessing.cpu_count()
self.buffer_size = num_workers * 2
self.batch_size = bs
self.num_workers = multiprocessing.cpu_count() - 1

self.sample_atoms = atoms_list[0]
self.inputs = atoms_to_inputs(atoms_list, pos_unit)

# Transform atoms into inputs and labels
self.inputs = atoms_to_inputs(atoms_list, pos_unit)
self.labels = atoms_to_labels(atoms_list, pos_unit, energy_unit)
label_keys = self.labels.keys()

@@ -488,61 +499,85 @@ def __init__(

forces = "forces" in label_keys
stress = "stress" in label_keys
self.prepare_batch = BatchProcessor(cutoff, forces, stress)
self.prepare_batch = BatchProcessor(
cutoff, atom_padding, nl_padding, forces, stress
)

self.count = 0
self.reset_every = reset_every

self.max_count = self.n_epochs * self.steps_per_epoch()
self.buffer = deque()

self.buffer_size = min(600, self.steps_per_epoch())
self.buffer = Queue(maxsize=self.buffer_size)

self.process_pool = ProcessPoolExecutor(self.num_workers)
self.thread_pool = ThreadPoolExecutor(1) # Single thread for buffering batches
self.epoch_finished = False
self.enqueue_future = None
self.needs_data = Event()

def enqueue_batches(self):
"""Function to enqueue batches on a side thread."""
while self.count < self.steps_per_epoch() * self.n_epochs:
self.needs_data.wait()
if self.epoch_finished:
break
num_batches = min(
self.buffer_size - self.buffer.qsize(),
self.steps_per_epoch() - self.count,
)
if num_batches > 0:
self.enqueue(num_batches)
self.needs_data.clear() # Reset event

def enqueue(self, num_batches):
start = self.count * self.batch_size

# Split data into chunks and submit tasks to the process pool
dataset_chunks = [
self.data[start + self.batch_size * i : start + self.batch_size * (i + 1)]
for i in range(0, num_batches)
for i in range(num_batches)
]

# Using submit and as_completed for faster batch retrieval
futures = [
self.process_pool.submit(self.prepare_batch, chunk)
for chunk in dataset_chunks
]
for batch in self.process_pool.map(self.prepare_batch, dataset_chunks):
self.buffer.append(batch)
for future in as_completed(futures):
batch = future.result()
self.buffer.put(batch)

self.count += num_batches

def __iter__(self):
for n in range(self.n_epochs):
self.count = 0
self.buffer = deque()

# reinitialize PPE from time to time to avoid memory leak
if n % self.reset_every == 0:
self.process_pool = ProcessPoolExecutor(self.num_workers)
self.buffer.queue.clear() # Reset buffer
self.epoch_finished = False

if self.should_shuffle:
shuffle(self.data)

self.enqueue(min(self.buffer_size, self.n_data // self.batch_size))
# Start pre-filling the buffer
self.enqueue_future = self.thread_pool.submit(self.enqueue_batches)

for i in range(self.steps_per_epoch()):
batch = self.buffer.popleft()
yield batch

current_buffer_len = len(self.buffer)
space = self.buffer_size - current_buffer_len
if self.buffer.qsize() < (self.buffer_size * 0.75):
self.needs_data.set() # Trigger buffer refill
while self.buffer.empty():
time.sleep(0.001)
yield self.buffer.get()

if space >= self.num_workers:
more_data = min(space, self.steps_per_epoch() - self.count)
more_data = max(more_data, 0)
if more_data > 0:
self.enqueue(more_data)
self.epoch_finished = True
self.needs_data.set()
self.enqueue_future.result()

def shuffle_and_batch(self, sharding):
self.should_shuffle = True

ds = prefetch_to_single_device(
iter(self), 2, sharding, n_step_jit=self.n_jit_steps > 1
)

return ds

def batch(self, sharding) -> Iterator[jax.Array]:
@@ -555,6 +590,15 @@ def batch(self, sharding) -> Iterator[jax.Array]:
def make_signature(self) -> None:
pass

def cleanup(self):
self.epoch_finished = True
self.needs_data.set()
self.enqueue_future.result()
self.needs_data.clear()
self.thread_pool.shutdown(wait=True, cancel_futures=True)
self.process_pool.shutdown(wait=True, cancel_futures=True)
self.buffer.queue.clear()


dataset_dict = {
"cached": CachedInMemoryDataset,
21 changes: 9 additions & 12 deletions apax/data/preprocessing.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@
import jax
import jax.numpy as jnp
import numpy as np
from matscipy.neighbours import neighbour_list
from vesin import NeighborList

log = logging.getLogger(__name__)

@@ -32,14 +32,10 @@ def compute_nl(positions, box, r_max):
"""
if np.all(box < 1e-6):
box, box_origin = get_shrink_wrapped_cell(positions)
idxs_i, idxs_j = neighbour_list(
"ij",
positions=positions,
cutoff=r_max,
cell=box,
cell_origin=box_origin,
pbc=[False, False, False],
box, _ = get_shrink_wrapped_cell(positions)
calculator = NeighborList(cutoff=r_max, full_list=True)
idxs_i, idxs_j = calculator.compute(
points=positions, box=box, periodic=False, quantities="ij"
)

neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int16)
@@ -49,10 +45,11 @@ def compute_nl(positions, box, r_max):

else:
positions = positions @ box
idxs_i, idxs_j, offsets = neighbour_list(
"ijS", positions=positions, cutoff=r_max, cell=box, pbc=[True, True, True]
calculator = NeighborList(cutoff=r_max, full_list=True)
idxs_i, idxs_j, offsets = calculator.compute(
points=positions, box=box, periodic=True, quantities="ijS"
)
neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int16)
neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int32)
offsets = np.matmul(offsets, box)
return neighbor_idxs, offsets

22 changes: 16 additions & 6 deletions apax/md/ase_calc.py
Original file line number Diff line number Diff line change
@@ -9,8 +9,8 @@
from ase.calculators.calculator import Calculator, all_changes
from ase.calculators.singlepoint import SinglePointCalculator
from flax.core.frozen_dict import freeze, unfreeze
from matscipy.neighbours import neighbour_list
from tqdm import trange
from vesin import NeighborList

from apax.data.input_pipeline import (
CachedInMemoryDataset,
@@ -149,7 +149,7 @@ def __init__(
Function transformations applied on top of the EnergyDerivativeModel.
Transfomrations are implemented under `apax.md.transformations`.
padding_factor:
Multiple of the fallback Matscipy NL's amount of neighbors.
Multiple of the fallback vesin's amount of neighbors.
This NL will be padded to `len(neighbors) * padding_factor`
on NL initialization.
"""
@@ -209,15 +209,25 @@ def initialize(self, atoms):
else:
self.neighbors = self.neighbor_fn.allocate(positions)
else:
idxs_i = neighbour_list("i", atoms, self.r_max)
calculator = NeighborList(cutoff=self.r_max, full_list=True)
idxs_i, _, _ = calculator.compute(
points=atoms.positions,
box=atoms.cell.array,
periodic=np.any(atoms.pbc),
quantities="ijS",
)
self.padded_length = int(len(idxs_i) * self.padding_factor)

def set_neighbours_and_offsets(self, atoms, box):
idxs_i, idxs_j, offsets = neighbour_list("ijS", atoms, self.r_max)

calculator = NeighborList(cutoff=self.r_max, full_list=True)
idxs_i, idxs_j, offsets = calculator.compute(
points=atoms.positions,
box=atoms.cell.array,
periodic=np.any(atoms.pbc),
quantities="ijS",
)
if len(idxs_i) > self.padded_length:
print("neighbor list overflowed, extending.")
self.padded_length = int(len(idxs_i) * self.padding_factor)
self.initialize(atoms)

zeros_to_add = self.padded_length - len(idxs_i)
2,932 changes: 1,374 additions & 1,558 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -22,7 +22,6 @@ clu = "^0.0.7"
jaxtyping = "^0.2.8"
typer = "^0.7.0"
lazy-loader = "^0.2"
matscipy = "^0.8.0"
znh5md = "^0.3"
pydantic = "^2.3.0"
jax = "^0.4.25"
@@ -32,6 +31,7 @@ orbax-checkpoint = "0.5.16"
flax = "0.8.4"
uncertainty-toolbox = "^0.1.1"
e3x = "^1.0.2"
vesin = "^0.2.0"

[tool.poetry.extras]
zntrack = ["zntrack"]