@@ -1,32 +1,36 @@
+import functools
 import logging
 import time
 from functools import partial
-from typing import Callable
+from typing import Callable, Optional
 
 import jax
 import jax.numpy as jnp
 import numpy as np
+from clu import metrics
 from tqdm import trange
 
+from apax.data.input_pipeline import AtomisticDataset
 from apax.train.checkpoints import CheckpointManager, load_state
 
 log = logging.getLogger(__name__)
 
 
 def fit(
     state,
-    train_ds,
+    train_ds: AtomisticDataset,
     loss_fn,
-    Metrics,
-    callbacks,
-    n_epochs,
+    Metrics: metrics.Collection,
+    callbacks: list,
+    n_epochs: int,
     ckpt_dir,
     ckpt_interval: int = 1,
-    val_ds=None,
+    val_ds: Optional[AtomisticDataset] = None,
     sam_rho=0.0,
-    patience=None,
+    patience: Optional[int] = None,
     disable_pbar: bool = False,
     is_ensemble=False,
+    n_jitted_steps=1,
 ):
     log.info("Beginning Training")
     callbacks.on_train_begin()
@@ -38,13 +42,16 @@ def fit(
     train_step, val_step = make_step_fns(
         loss_fn, Metrics, model=state.apply_fn, sam_rho=sam_rho, is_ensemble=is_ensemble
     )
+    if n_jitted_steps > 1:
+        train_step = jax.jit(functools.partial(jax.lax.scan, train_step))
 
     state, start_epoch = load_state(state, latest_dir)
     if start_epoch >= n_epochs:
         raise ValueError(
             f"n_epochs <= current epoch from checkpoint ({n_epochs} <= {start_epoch})"
         )
 
+    train_ds.batch_multiple_steps(n_jitted_steps)
     train_steps_per_epoch = train_ds.steps_per_epoch()
     batch_train_ds = train_ds.shuffle_and_batch()
 
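
Note: `jax.lax.scan` expects a step function with signature `(carry, x) -> (carry, y)`, applies it along the leading axis of `x`, and stacks the `y` outputs. Partially applying it over `train_step` therefore folds `n_jitted_steps` updates into a single jitted call, amortizing dispatch overhead. A minimal standalone sketch of the pattern (toy step function and data, not apax code):

    import functools

    import jax
    import jax.numpy as jnp

    def step(carry, batch):
        # One "training step": fold a batch into the running carry.
        new_carry = carry + jnp.sum(batch)
        return new_carry, new_carry  # (updated carry, per-step output)

    multi_step = jax.jit(functools.partial(jax.lax.scan, step))

    # Data stacked along a leading axis of length n_jitted_steps, which is
    # presumably what train_ds.batch_multiple_steps(n_jitted_steps) arranges.
    init = jnp.zeros(())
    batches = jnp.arange(6.0).reshape(3, 2)  # 3 scanned steps of 2 values
    final_carry, per_step_outputs = multi_step(init, batches)
    # final_carry == 15.0; per_step_outputs == [1., 6., 15.]
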
@@ -68,12 +75,16 @@ def fit(
         for batch_idx in range(train_steps_per_epoch):
             callbacks.on_train_batch_begin(batch=batch_idx)
 
-            inputs, labels = next(batch_train_ds)
-            batch_loss, train_batch_metrics, state = train_step(
-                state, inputs, labels, train_batch_metrics
+            batch = next(batch_train_ds)
+            (
+                (state, train_batch_metrics),
+                batch_loss,
+            ) = train_step(
+                (state, train_batch_metrics),
+                batch,
             )
 
-            epoch_loss["train_loss"] += batch_loss
+            epoch_loss["train_loss"] += jnp.mean(batch_loss)
             callbacks.on_train_batch_end(batch=batch_idx)
 
         epoch_loss["train_loss"] /= train_steps_per_epoch
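
The `jnp.mean` accounts for the scanned case: `jax.lax.scan` stacks the per-step losses along a leading axis, so `batch_loss` has shape `(n_jitted_steps,)` rather than being a scalar, while for an unscanned 0-d loss `jnp.mean` is a harmless no-op. A small illustration:

    import jax.numpy as jnp

    batch_loss = jnp.array([0.9, 0.7, 0.5])  # stacked losses, n_jitted_steps = 3
    scalar_loss = jnp.mean(batch_loss)       # reduce to a scalar for logging
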
@@ -88,10 +99,10 @@ def fit(
             epoch_loss.update({"val_loss": 0.0})
             val_batch_metrics = Metrics.empty()
             for batch_idx in range(val_steps_per_epoch):
-                inputs, labels = next(batch_val_ds)
+                batch = next(batch_val_ds)
 
                 batch_loss, val_batch_metrics = val_step(
-                    state.params, inputs, labels, val_batch_metrics
+                    state.params, batch, val_batch_metrics
                 )
                 epoch_loss["val_loss"] += batch_loss
 
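
For reference, the `Metrics.empty()` / `single_from_model_output()` / `merge()` calls follow the standard `clu.metrics.Collection` API: an empty collection immutably accumulates per-batch updates. A toy sketch of that pattern (`ToyMetrics` is illustrative, not apax's actual `Metrics` collection):

    import flax
    import jax.numpy as jnp
    from clu import metrics

    @flax.struct.dataclass
    class ToyMetrics(metrics.Collection):
        loss: metrics.Average.from_output("loss")

    batch_metrics = ToyMetrics.empty()
    for step_loss in (0.9, 0.7):
        update = ToyMetrics.single_from_model_output(loss=jnp.asarray(step_loss))
        batch_metrics = batch_metrics.merge(update)
    # batch_metrics.compute() -> {"loss": 0.8}
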
@@ -213,17 +224,22 @@ def update_step(state, inputs, labels):
     eval_fn = loss_calculator
 
     @jax.jit
-    def train_step(state, inputs, labels, batch_metrics):
+    def train_step(carry, batch):
+        state, batch_metrics = carry
+        inputs, labels = batch
         loss, predictions, state = update_fn(state, inputs, labels)
 
         new_batch_metrics = Metrics.single_from_model_output(
             label=labels, prediction=predictions
         )
         batch_metrics = batch_metrics.merge(new_batch_metrics)
-        return loss, batch_metrics, state
+
+        new_carry = (state, batch_metrics)
+        return new_carry, loss
 
     @jax.jit
-    def val_step(params, inputs, labels, batch_metrics):
+    def val_step(params, batch, batch_metrics):
+        inputs, labels = batch
         loss, predictions = eval_fn(params, inputs, labels)
 
         new_batch_metrics = Metrics.single_from_model_output(