JIT multiple train steps #220

Merged
merged 9 commits into from Jan 16, 2024
4 changes: 2 additions & 2 deletions .github/workflows/linting.yaml
@@ -13,7 +13,7 @@ jobs:
uses: psf/black@stable
with:
src: "./apax"
version: "22.10.0"
version: "22.12.0"
Contributor

use pre-commit.ci now that the repo is public

Contributor

Will this be part of this PR or a separate one? I looked over it and the rest looks good to me.

Contributor Author


> Will this be part of this PR or a separate one? I looked over it and the rest looks good to me.

separate


isort:
runs-on: ubuntu-latest
@@ -25,7 +25,7 @@ jobs:

- name: Install isort
run: |
pip install isort==5.10.1
pip install isort==5.12.0

- name: run isort
run: |
5 changes: 5 additions & 0 deletions apax/config/train_config.py
@@ -279,7 +279,11 @@ class Config(BaseModel, frozen=True, extra="forbid"):
----------

n_epochs: Number of training epochs.
patience: Number of epochs without improvement before training is terminated.
seed: Random seed.
n_models: Number of models to be trained at once.
n_jitted_steps: Number of train batches to be processed in a compiled loop.
Can yield significant speedups for small structures or small batch sizes.
data: :class:`Data <config.DataConfig>` configuration.
model: :class:`Model <config.ModelConfig>` configuration.
metrics: List of :class:`metric <config.MetricsConfig>` configurations.
@@ -294,6 +298,7 @@ class Config(BaseModel, frozen=True, extra="forbid"):
patience: Optional[PositiveInt] = None
seed: int = 1
n_models: int = 1
n_jitted_steps: int = 1

data: DataConfig
model: ModelConfig = ModelConfig()
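
For orientation, a minimal sketch of what the new option could look like in a user config; YAML input is assumed here, the required data and model sections are omitted, and all values are placeholders rather than recommendations:

n_epochs: 100
seed: 1
n_models: 1
n_jitted_steps: 2  # compile a loop over 2 train batches per step; default is 1

Leaving n_jitted_steps at 1 reproduces the previous behaviour of one jitted update per batch.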
15 changes: 11 additions & 4 deletions apax/data/input_pipeline.py
@@ -174,6 +174,7 @@ def __init__(
"""
self.n_epoch = n_epoch
self.batch_size = None
self.n_jit_steps = 1
self.buffer_size = buffer_size

max_atoms, max_nbrs = find_largest_system(inputs)
@@ -187,6 +188,9 @@ def __init__(
def set_batch_size(self, batch_size: int):
self.batch_size = self.validate_batch_size(batch_size)

def batch_multiple_steps(self, n_steps: int):
self.n_jit_steps = n_steps

def _check_batch_size(self):
if self.batch_size is None:
raise ValueError("Dataset Batch Size has not been set yet")
@@ -208,7 +212,7 @@ def steps_per_epoch(self) -> int:
number of steps, and all batches have the same length. To do so, some training
data are dropped in each epoch.
"""
return self.n_data // self.batch_size
return self.n_data // self.batch_size // self.n_jit_steps

def init_input(self) -> Dict[str, np.ndarray]:
"""Returns first batch of inputs and labels to init the model."""
@@ -240,15 +244,18 @@ def shuffle_and_batch(self) -> Iterator[jax.Array]:
Iterator that returns inputs and labels of one batch in each step.
"""
self._check_batch_size()
shuffled_ds = (
ds = (
self.ds.shuffle(buffer_size=self.buffer_size)
.repeat(self.n_epoch)
.batch(batch_size=self.batch_size)
.map(PadToSpecificSize(self.max_atoms, self.max_nbrs))
)

shuffled_ds = prefetch_to_single_device(shuffled_ds.as_numpy_iterator(), 2)
return shuffled_ds
if self.n_jit_steps > 1:
ds = ds.batch(batch_size=self.n_jit_steps)

ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
return ds

def batch(self) -> Iterator[jax.Array]:
self._check_batch_size()
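
The second batch() call above is what creates the extra leading axis that the scanned train step later iterates over, and it is also why steps_per_epoch() now divides by n_jit_steps: every element the iterator yields accounts for n_jit_steps optimizer updates. A standalone sketch of that shape bookkeeping with tf.data (toy data, arbitrary sizes):

import numpy as np
import tensorflow as tf

batch_size, n_jit_steps = 4, 2

# 16 scalar samples -> 4 batches of 4 -> 2 "batches of batches" of shape (2, 4)
ds = tf.data.Dataset.from_tensor_slices(np.arange(16, dtype=np.float32))
ds = ds.batch(batch_size)
ds = ds.batch(n_jit_steps)

for chunk in ds.as_numpy_iterator():
    print(chunk.shape)  # (2, 4), printed twice: 16 // 4 // 2 = 2 steps per epoch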
10 changes: 8 additions & 2 deletions apax/train/metrics.py
@@ -9,7 +9,13 @@
log = logging.getLogger(__name__)


class RootAverage(metrics.Average):
class Averagefp64(metrics.Average):
@classmethod
def empty(cls) -> metrics.Metric:
return cls(total=jnp.array(0, jnp.float64), count=jnp.array(0, jnp.int64))


class RootAverage(Averagefp64):
"""
Modifies the `compute` method of `metrics.Average` to obtain the root of the average.
Meant to be used with `mse_fn`.
@@ -59,7 +65,7 @@ def make_single_metric(key: str, reduction: str) -> metrics.Average:
if reduction == "rmse":
metric = RootAverage
else:
metric = metrics.Average
metric = Averagefp64

reduction_fn = reduction_fns[reduction]
reduction_fn = partial(reduction_fn, key=key)
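
For context: with JAX's default 32-bit precision, a long-running float32 loss sum can round away small per-batch contributions; the subclass above only changes empty() so the accumulators start out as 64-bit arrays instead. A short sketch of the effect (assumes x64 mode is enabled, otherwise jnp.float64 silently falls back to float32):

import jax
import jax.numpy as jnp
from clu import metrics

jax.config.update("jax_enable_x64", True)  # required for real float64 arrays


class Averagefp64(metrics.Average):
    @classmethod
    def empty(cls) -> metrics.Metric:
        return cls(total=jnp.array(0, jnp.float64), count=jnp.array(0, jnp.int64))


acc = Averagefp64.empty()
print(acc.total.dtype, acc.count.dtype)  # float64 int64

# Why it matters: in float32, 1e8 + 1.0 == 1e8, so small batch losses added to a
# large running total can vanish; a 64-bit total keeps them.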
1 change: 1 addition & 0 deletions apax/train/run.py
@@ -117,4 +117,5 @@ def run(user_config, log_level="error"):
patience=config.patience,
disable_pbar=config.progress_bar.disable_epoch_pbar,
is_ensemble=config.n_models > 1,
n_jitted_steps=config.n_jitted_steps,
)
48 changes: 32 additions & 16 deletions apax/train/trainer.py
@@ -1,32 +1,36 @@
import functools
import logging
import time
from functools import partial
from typing import Callable
from typing import Callable, Optional

import jax
import jax.numpy as jnp
import numpy as np
from clu import metrics
from tqdm import trange

from apax.data.input_pipeline import AtomisticDataset
from apax.train.checkpoints import CheckpointManager, load_state

log = logging.getLogger(__name__)


def fit(
state,
train_ds,
train_ds: AtomisticDataset,
loss_fn,
Metrics,
callbacks,
n_epochs,
Metrics: metrics.Collection,
callbacks: list,
n_epochs: int,
ckpt_dir,
ckpt_interval: int = 1,
val_ds=None,
val_ds: Optional[AtomisticDataset] = None,
sam_rho=0.0,
patience=None,
patience: Optional[int] = None,
disable_pbar: bool = False,
is_ensemble=False,
n_jitted_steps=1,
):
log.info("Beginning Training")
callbacks.on_train_begin()
@@ -38,13 +42,16 @@ def fit(
train_step, val_step = make_step_fns(
loss_fn, Metrics, model=state.apply_fn, sam_rho=sam_rho, is_ensemble=is_ensemble
)
if n_jitted_steps > 1:
train_step = jax.jit(functools.partial(jax.lax.scan, train_step))

state, start_epoch = load_state(state, latest_dir)
if start_epoch >= n_epochs:
raise ValueError(
f"n_epochs <= current epoch from checkpoint ({n_epochs} <= {start_epoch})"
)

train_ds.batch_multiple_steps(n_jitted_steps)
train_steps_per_epoch = train_ds.steps_per_epoch()
batch_train_ds = train_ds.shuffle_and_batch()
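
Wrapping the step in functools.partial(jax.lax.scan, train_step) works because the refactored train_step further down follows the (carry, x) -> (carry, y) convention that scan expects: the carry is (state, batch_metrics), each x is one batch taken from the leading n_jitted_steps axis added by the input pipeline, and the per-step losses come back stacked as y. A self-contained toy version of the pattern (the scalar parameter, loss, and learning rate are illustrative, not apax's):

import functools

import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    return jnp.mean((batch - params) ** 2)


def train_step(carry, batch):
    params, n_seen = carry
    loss, grad = jax.value_and_grad(loss_fn)(params, batch)
    new_carry = (params - 0.1 * grad, n_seen + batch.shape[0])
    return new_carry, loss


# One compiled call now performs n_jitted_steps updates back to back.
multi_step = jax.jit(functools.partial(jax.lax.scan, train_step))

n_jitted_steps, batch_size = 4, 8
batches = jnp.ones((n_jitted_steps, batch_size))  # leading axis is scanned over
carry = (jnp.array(0.0), jnp.array(0))
carry, losses = multi_step(carry, batches)
print(losses.shape)      # (4,): one loss per inner step
print(jnp.mean(losses))  # matches how the loop accumulates jnp.mean(batch_loss)

When n_jitted_steps is 1 the unwrapped, jitted train_step is used directly, but it already takes the same (carry, batch) arguments, so the training loop below handles both cases with the same call.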

@@ -68,12 +75,16 @@ def fit(
for batch_idx in range(train_steps_per_epoch):
callbacks.on_train_batch_begin(batch=batch_idx)

inputs, labels = next(batch_train_ds)
batch_loss, train_batch_metrics, state = train_step(
state, inputs, labels, train_batch_metrics
batch = next(batch_train_ds)
(
(state, train_batch_metrics),
batch_loss,
) = train_step(
(state, train_batch_metrics),
batch,
)

epoch_loss["train_loss"] += batch_loss
epoch_loss["train_loss"] += jnp.mean(batch_loss)
callbacks.on_train_batch_end(batch=batch_idx)

epoch_loss["train_loss"] /= train_steps_per_epoch
@@ -88,10 +99,10 @@ def fit(
epoch_loss.update({"val_loss": 0.0})
val_batch_metrics = Metrics.empty()
for batch_idx in range(val_steps_per_epoch):
inputs, labels = next(batch_val_ds)
batch = next(batch_val_ds)

batch_loss, val_batch_metrics = val_step(
state.params, inputs, labels, val_batch_metrics
state.params, batch, val_batch_metrics
)
epoch_loss["val_loss"] += batch_loss

@@ -213,17 +224,22 @@ def update_step(state, inputs, labels):
eval_fn = loss_calculator

@jax.jit
def train_step(state, inputs, labels, batch_metrics):
def train_step(carry, batch):
state, batch_metrics = carry
inputs, labels = batch
loss, predictions, state = update_fn(state, inputs, labels)

new_batch_metrics = Metrics.single_from_model_output(
label=labels, prediction=predictions
)
batch_metrics = batch_metrics.merge(new_batch_metrics)
return loss, batch_metrics, state

new_carry = (state, batch_metrics)
return new_carry, loss

@jax.jit
def val_step(params, inputs, labels, batch_metrics):
def val_step(params, batch, batch_metrics):
inputs, labels = batch
loss, predictions = eval_fn(params, inputs, labels)

new_batch_metrics = Metrics.single_from_model_output(