
Commit 7d9e872

Authored Jan 17, 2024
Merge pull request #223 from apax-hub/dev
Version 0.3.0 changes
2 parents a27da8e + 2b51835 commit 7d9e872

File tree

12 files changed: +1119 -1006 lines changed

‎.github/workflows/linting.yaml

+2 -2

@@ -13,7 +13,7 @@ jobs:
         uses: psf/black@stable
         with:
           src: "./apax"
-          version: "22.10.0"
+          version: "22.12.0"

   isort:
     runs-on: ubuntu-latest
@@ -25,7 +25,7 @@ jobs:

       - name: Install isort
        run: |
-          pip install isort==5.10.1
+          pip install isort==5.12.0

      - name: run isort
        run: |

‎apax/config/train_config.py

+5

@@ -279,7 +279,11 @@ class Config(BaseModel, frozen=True, extra="forbid"):
     ----------

     n_epochs: Number of training epochs.
+    patience: Number of epochs without improvement before training is terminated.
     seed: Random seed.
+    n_models: Number of models to be trained at once.
+    n_jitted_steps: Number of train batches to be processed in a compiled loop.
+        Can yield significant speedups for small structures or small batch sizes.
     data: :class: `Data` <config.DataConfig> configuration.
     model: :class: `Model` <config.ModelConfig> configuration.
     metrics: List of :class: `metric` <config.MetricsConfig> configurations.
@@ -294,6 +298,7 @@ class Config(BaseModel, frozen=True, extra="forbid"):
     patience: Optional[PositiveInt] = None
     seed: int = 1
     n_models: int = 1
+    n_jitted_steps: int = 1

     data: DataConfig
     model: ModelConfig = ModelConfig()
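
Since the hunks above only show the new fields, here is a minimal, self-contained sketch of the same pattern: an option such as n_jitted_steps added to a frozen pydantic model with a default of 1, so existing configs that omit the key keep validating. TrainSettings is a made-up name for illustration, not the actual apax Config, and the nested data/model sections are left out.

# Hypothetical, simplified stand-in for the Config class above.
from typing import Optional

from pydantic import BaseModel, PositiveInt


class TrainSettings(BaseModel, frozen=True, extra="forbid"):
    n_epochs: PositiveInt
    patience: Optional[PositiveInt] = None
    seed: int = 1
    n_models: int = 1
    n_jitted_steps: int = 1  # new option: batches processed per compiled train step


settings = TrainSettings(n_epochs=100, n_jitted_steps=4)
print(settings.n_jitted_steps)  # 4; omitting the key falls back to the default of 1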

‎apax/data/input_pipeline.py

+11 -4

@@ -174,6 +174,7 @@ def __init__(
         """
         self.n_epoch = n_epoch
         self.batch_size = None
+        self.n_jit_steps = 1
         self.buffer_size = buffer_size

         max_atoms, max_nbrs = find_largest_system(inputs)
@@ -187,6 +188,9 @@ def __init__(
     def set_batch_size(self, batch_size: int):
         self.batch_size = self.validate_batch_size(batch_size)

+    def batch_multiple_steps(self, n_steps: int):
+        self.n_jit_steps = n_steps
+
     def _check_batch_size(self):
         if self.batch_size is None:
             raise ValueError("Dataset Batch Size has not been set yet")
@@ -208,7 +212,7 @@ def steps_per_epoch(self) -> int:
         number of steps, and all batches have the same length. To do so, some training
         data are dropped in each epoch.
         """
-        return self.n_data // self.batch_size
+        return self.n_data // self.batch_size // self.n_jit_steps

     def init_input(self) -> Dict[str, np.ndarray]:
         """Returns first batch of inputs and labels to init the model."""
@@ -240,15 +244,18 @@ def shuffle_and_batch(self) -> Iterator[jax.Array]:
         Iterator that returns inputs and labels of one batch in each step.
         """
         self._check_batch_size()
-        shuffled_ds = (
+        ds = (
             self.ds.shuffle(buffer_size=self.buffer_size)
             .repeat(self.n_epoch)
             .batch(batch_size=self.batch_size)
             .map(PadToSpecificSize(self.max_atoms, self.max_nbrs))
         )

-        shuffled_ds = prefetch_to_single_device(shuffled_ds.as_numpy_iterator(), 2)
-        return shuffled_ds
+        if self.n_jit_steps > 1:
+            ds = ds.batch(batch_size=self.n_jit_steps)
+
+        ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
+        return ds

     def batch(self) -> Iterator[jax.Array]:
         self._check_batch_size()
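
The extra .batch(batch_size=self.n_jit_steps) call stacks several ordinary batches into one element, which is why steps_per_epoch is now also divided by n_jit_steps. A rough illustration of that bookkeeping with made-up numbers; tf.data is assumed here only because the pipeline above uses Dataset-style chaining and as_numpy_iterator.

# Toy double-batching example, not apax code: a range dataset stands in for
# padded atomistic batches.
import tensorflow as tf

n_data, batch_size, n_jit_steps = 1000, 32, 4

ds = (
    tf.data.Dataset.range(n_data)
    .batch(batch_size, drop_remainder=True)
    .batch(n_jit_steps, drop_remainder=True)
)
print(next(iter(ds)).shape)  # (4, 32): one element now carries 4 inner batches

steps_per_epoch = n_data // batch_size // n_jit_steps
print(steps_per_epoch)  # 7 outer steps per epoch; the leftover samples are dropped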

‎apax/md/ase_calc.py

+5

@@ -112,6 +112,7 @@ def __init__(
         self.model_config, self.params = restore_parameters(model_dir)
         self.n_models = check_for_ensemble(self.params)
         self.padding_factor = padding_factor
+        self.padded_length = 0

         if self.model_config.model.calc_stress:
             self.implemented_properties.append("stress")
@@ -148,6 +149,10 @@ def initialize(self, atoms):
         self.step = get_step_fn(model, atoms, self.neigbor_from_jax)
         self.neighbor_fn = neighbor_fn

+        if self.neigbor_from_jax:
+            positions = jnp.asarray(atoms.positions, dtype=jnp.float64)
+            self.neighbors = self.neighbor_fn.allocate(positions)
+
     def set_neighbours_and_offsets(self, atoms, box):
         idxs_i, idxs_j, offsets = neighbour_list("ijS", atoms, self.r_max)
‎apax/md/nvt.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def body_fn(i, state):
271271
)
272272
ckpt = {"state": state, "step": step}
273273
checkpoints.save_checkpoint(
274-
ckpt_dir=ckpt_dir,
274+
ckpt_dir=ckpt_dir.resolve(),
275275
target=ckpt,
276276
step=step,
277277
overwrite=True,
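
The switch to ckpt_dir.resolve() is presumably there because the flax/orbax checkpointing stack expects absolute paths; resolving makes the configured directory absolute regardless of the current working directory. A one-line illustration with a made-up path:

# Illustration only: .resolve() turns a relative checkpoint directory into an
# absolute one.
from pathlib import Path

ckpt_dir = Path("models/nvt_run/ckpts")
print(ckpt_dir.is_absolute())            # False
print(ckpt_dir.resolve().is_absolute())  # True, anchored at the current working directory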

‎apax/train/checkpoints.py

+7 -3

@@ -89,9 +89,9 @@ class CheckpointManager:
     def __init__(self) -> None:
         self.async_manager = checkpoints.AsyncManager()

-    def save_checkpoint(self, ckpt, epoch: int, path: str) -> None:
+    def save_checkpoint(self, ckpt, epoch: int, path: Path) -> None:
         checkpoints.save_checkpoint(
-            ckpt_dir=path,
+            ckpt_dir=path.resolve(),
             target=ckpt,
             step=epoch,
             overwrite=True,
@@ -147,7 +147,11 @@ def restore_single_parameters(model_dir: Path) -> Tuple[Config, FrozenDict]:
     """
     model_dir = Path(model_dir)
     model_config = parse_config(model_dir / "config.yaml")
-    model_config.data.directory = model_dir.parent.resolve().as_posix()
+
+    if model_config.data.experiment == "":
+        model_config.data.directory = model_dir.resolve().as_posix()
+    else:
+        model_config.data.directory = model_dir.parent.resolve().as_posix()

     ckpt_dir = model_config.data.model_version_path
     return model_config, load_params(ckpt_dir)

‎apax/train/loss.py

+9

@@ -52,13 +52,21 @@ def force_angle_exponential_weight(
     return (1.0 - dotp) * jnp.exp(-F_0_norm) / divisor


+def stress_tril(label, prediction, divisor=1.0):
+    idxs = jnp.tril_indices(3)
+    label_tril = label[:, idxs[0], idxs[1]]
+    prediction_tril = prediction[:, idxs[0], idxs[1]]
+    return (label_tril - prediction_tril) ** 2 / divisor
+
+
 loss_functions = {
     "molecules": weighted_squared_error,
     "structures": weighted_squared_error,
     "vibrations": weighted_squared_error,
     "cosine_sim": force_angle_loss,
     "cosine_sim_div_magnitude": force_angle_div_force_label,
     "cosine_sim_exp_magnitude": force_angle_exponential_weight,
+    "tril": stress_tril,
 }
@@ -101,6 +109,7 @@ def determine_divisor(self, n_atoms: jnp.array) -> jnp.array:
                 n_atoms, "batch -> batch 1 1"
             ),
             "stress_structures": einops.repeat(n_atoms**2, "batch -> batch 1 1"),
+            "stress_tril": einops.repeat(n_atoms**2, "batch -> batch 1 1"),
             "stress_vibrations": einops.repeat(n_atoms, "batch -> batch 1 1"),
         }
         divisor = divisor_dict.get(divisor_id, jnp.array(1.0))
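
The new "tril" loss compares only the six independent components of the symmetric stress tensor by indexing its lower triangle rather than all nine entries. A quick numerical check of that indexing with a made-up batch of one tensor:

# Made-up data to show what stress_tril computes; not part of the commit.
import jax.numpy as jnp

label = jnp.arange(9.0).reshape(1, 3, 3)
prediction = label + 0.1

idxs = jnp.tril_indices(3)  # (0,0), (1,0), (1,1), (2,0), (2,1), (2,2)
diff = label[:, idxs[0], idxs[1]] - prediction[:, idxs[0], idxs[1]]
print(diff.shape)        # (1, 6): six independent components per structure
print(jnp.sum(diff**2))  # ~0.06, i.e. six components times 0.1**2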

‎apax/train/metrics.py

+8 -2

@@ -9,7 +9,13 @@
 log = logging.getLogger(__name__)


-class RootAverage(metrics.Average):
+class Averagefp64(metrics.Average):
+    @classmethod
+    def empty(cls) -> metrics.Metric:
+        return cls(total=jnp.array(0, jnp.float64), count=jnp.array(0, jnp.int64))
+
+
+class RootAverage(Averagefp64):
     """
     Modifies the `compute` method of `metrics.Average` to obtain the root of the average.
     Meant to be used with `mse_fn`.
@@ -59,7 +65,7 @@ def make_single_metric(key: str, reduction: str) -> metrics.Average:
     if reduction == "rmse":
         metric = RootAverage
     else:
-        metric = metrics.Average
+        metric = Averagefp64

     reduction_fn = reduction_fns[reduction]
     reduction_fn = partial(reduction_fn, key=key)
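
Averagefp64 only overrides empty() so the running total and count start as 64-bit values instead of clu's float32 defaults; over many batches a float32 running sum can stop absorbing small updates. A minimal NumPy illustration of the underlying precision issue (this is not the clu.metrics API, and on the JAX side 64-bit accumulation additionally requires enabling jax_enable_x64):

# Why accumulate in fp64: once a float32 running sum reaches 2**24, adding 1.0
# no longer changes it, so long metric accumulations silently saturate.
import numpy as np

total32 = np.float32(2**24)
total64 = np.float64(2**24)
print(total32 + np.float32(1.0) == total32)  # True: the update is lost in fp32
print(total64 + 1.0 == total64)              # False: fp64 still resolves it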

‎apax/train/run.py

+1

@@ -117,4 +117,5 @@ def run(user_config, log_level="error"):
         patience=config.patience,
         disable_pbar=config.progress_bar.disable_epoch_pbar,
         is_ensemble=config.n_models > 1,
+        n_jitted_steps=config.n_jitted_steps,
     )

‎apax/train/trainer.py

+32 -16

@@ -1,32 +1,36 @@
+import functools
 import logging
 import time
 from functools import partial
-from typing import Callable
+from typing import Callable, Optional

 import jax
 import jax.numpy as jnp
 import numpy as np
+from clu import metrics
 from tqdm import trange

+from apax.data.input_pipeline import AtomisticDataset
 from apax.train.checkpoints import CheckpointManager, load_state

 log = logging.getLogger(__name__)


 def fit(
     state,
-    train_ds,
+    train_ds: AtomisticDataset,
     loss_fn,
-    Metrics,
-    callbacks,
-    n_epochs,
+    Metrics: metrics.Collection,
+    callbacks: list,
+    n_epochs: int,
     ckpt_dir,
     ckpt_interval: int = 1,
-    val_ds=None,
+    val_ds: Optional[AtomisticDataset] = None,
     sam_rho=0.0,
-    patience=None,
+    patience: Optional[int] = None,
     disable_pbar: bool = False,
     is_ensemble=False,
+    n_jitted_steps=1,
 ):
     log.info("Beginning Training")
     callbacks.on_train_begin()
@@ -38,13 +42,16 @@ def fit(
     train_step, val_step = make_step_fns(
         loss_fn, Metrics, model=state.apply_fn, sam_rho=sam_rho, is_ensemble=is_ensemble
     )
+    if n_jitted_steps > 1:
+        train_step = jax.jit(functools.partial(jax.lax.scan, train_step))

     state, start_epoch = load_state(state, latest_dir)
     if start_epoch >= n_epochs:
         raise ValueError(
             f"n_epochs <= current epoch from checkpoint ({n_epochs} <= {start_epoch})"
         )

+    train_ds.batch_multiple_steps(n_jitted_steps)
     train_steps_per_epoch = train_ds.steps_per_epoch()
     batch_train_ds = train_ds.shuffle_and_batch()

@@ -68,12 +75,16 @@ def fit(
         for batch_idx in range(train_steps_per_epoch):
             callbacks.on_train_batch_begin(batch=batch_idx)

-            inputs, labels = next(batch_train_ds)
-            batch_loss, train_batch_metrics, state = train_step(
-                state, inputs, labels, train_batch_metrics
+            batch = next(batch_train_ds)
+            (
+                (state, train_batch_metrics),
+                batch_loss,
+            ) = train_step(
+                (state, train_batch_metrics),
+                batch,
             )

-            epoch_loss["train_loss"] += batch_loss
+            epoch_loss["train_loss"] += jnp.mean(batch_loss)
             callbacks.on_train_batch_end(batch=batch_idx)

         epoch_loss["train_loss"] /= train_steps_per_epoch
@@ -88,10 +99,10 @@ def fit(
             epoch_loss.update({"val_loss": 0.0})
             val_batch_metrics = Metrics.empty()
             for batch_idx in range(val_steps_per_epoch):
-                inputs, labels = next(batch_val_ds)
+                batch = next(batch_val_ds)

                 batch_loss, val_batch_metrics = val_step(
-                    state.params, inputs, labels, val_batch_metrics
+                    state.params, batch, val_batch_metrics
                 )
                 epoch_loss["val_loss"] += batch_loss

@@ -213,17 +224,22 @@ def update_step(state, inputs, labels):
     eval_fn = loss_calculator

     @jax.jit
-    def train_step(state, inputs, labels, batch_metrics):
+    def train_step(carry, batch):
+        state, batch_metrics = carry
+        inputs, labels = batch
         loss, predictions, state = update_fn(state, inputs, labels)

         new_batch_metrics = Metrics.single_from_model_output(
             label=labels, prediction=predictions
         )
         batch_metrics = batch_metrics.merge(new_batch_metrics)
-        return loss, batch_metrics, state
+
+        new_carry = (state, batch_metrics)
+        return new_carry, loss

     @jax.jit
-    def val_step(params, inputs, labels, batch_metrics):
+    def val_step(params, batch, batch_metrics):
+        inputs, labels = batch
         loss, predictions = eval_fn(params, inputs, labels)

         new_batch_metrics = Metrics.single_from_model_output(