 import argparse
 import os
 from dataclasses import dataclass, field
+from timeit import default_timer as timer
 from typing import Any, Dict, List, Union

+import numpy as np
+
 # torch imports
 import torch
 import torch.nn.functional as F
@@ -18,12 +21,13 @@
 from torchtrain.datasets import create_tokenizer, dataloader_fn
 from torchtrain.logging_utils import init_logger, rank0_log
 from torchtrain.lr_scheduling import get_lr_scheduler
-from torchtrain.metrics import get_num_params, GPUMemoryMonitor
+from torchtrain.metrics import build_metric_logger, get_num_params, GPUMemoryMonitor

 from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtrain.parallelisms import models_parallelize_fns, ParallelDims

 from torchtrain.profiling import maybe_run_profiler
+from torchtrain.utils import dist_max, dist_mean


 @dataclass
@@ -126,7 +130,7 @@ def main(args):

     scaler = build_grad_scaler(model)

-    # TODO: add metrics
+    metric_logger = build_metric_logger()

     # torch.compile model for improved performance
     if args.compile:
@@ -156,13 +160,18 @@ def main(args):

     with maybe_run_profiler() as torch_profiler:
         checkpoint.reset()
+        # variables used to keep info for metrics logging
+        losses_since_last_log: List[float] = []
+        nwords_since_last_log = 0
+        time_last_log = timer()
         while train_state.step < args.steps or args.steps == -1:
             train_state.step += 1
             # get batch
             batch = next(iter(data_loader))
             input_ids, labels = batch
             input_ids = input_ids.cuda()
             labels = labels.cuda()
+            nwords_since_last_log += labels.numel()

             optimizer.zero_grad()

@@ -194,6 +203,32 @@ def main(args):

             train_state.current_loss = loss.item()
             train_state.losses.append(train_state.current_loss)
+            losses_since_last_log.append(train_state.current_loss)
+
+            # log metrics
+            if (train_state.step - 1) % args.log_freq == 0:
+                avg_loss, max_loss = np.mean(losses_since_last_log), np.max(
+                    losses_since_last_log
+                )
+                global_avg_loss, global_max_loss = dist_mean(
+                    avg_loss, world_mesh
+                ), dist_max(max_loss, world_mesh)
+
+                time_delta = timer() - time_last_log
+                wps = nwords_since_last_log / (
+                    time_delta * parallel_dims.model_parallel_size
+                )
+
+                metrics = {
+                    "global_avg_loss": global_avg_loss,
+                    "global_max_loss": global_max_loss,
+                    "wps": wps,
+                }
+                metric_logger.log(metrics, step=train_state.step)
+
+                losses_since_last_log.clear()
+                nwords_since_last_log = 0
+                time_last_log = timer()

             rank0_log(
                 f"step: {train_state.step}, current loss: {train_state.current_loss}, lr: {scheduler.get_last_lr()}"
@@ -202,6 +237,7 @@ def main(args):

             checkpoint.save(train_state.step, force=(train_state.step == args.steps))

+    metric_logger.close()
     rank0_log(f"{gpu_metrics.get_current_stats()}")


@@ -294,6 +330,12 @@ def main(args):
             "is an empty string, checkpointing is disabled."
         ),
     )
+    parser.add_argument(
+        "--log_freq",
+        type=int,
+        default=10,
+        help="how often to log metrics to TensorBoard",
+    )

     args = parser.parse_args()
     main(args)
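
For context on the two reduction helpers this diff imports from torchtrain.utils: below is a minimal sketch of what dist_mean and dist_max could look like, assuming they all-reduce a scalar metric over the ranks of the DeviceMesh passed in as world_mesh. This is an illustrative guess, not the repository's implementation; the use of torch.distributed.all_reduce and a one-dimensional mesh are assumptions.

```python
# Illustrative sketch only -- not the torchtrain implementation.
# Assumes `mesh` is a one-dimensional torch.distributed.device_mesh.DeviceMesh
# and that each rank passes a Python float computed from its local batch.
import torch
import torch.distributed as dist


def dist_max(x: float, mesh) -> float:
    # elementwise MAX of the scalar across all ranks in the mesh
    t = torch.tensor(x, dtype=torch.float32, device="cuda")
    dist.all_reduce(t, op=dist.ReduceOp.MAX, group=mesh.get_group())
    return t.item()


def dist_mean(x: float, mesh) -> float:
    # SUM across ranks, then divide by the number of ranks to get the mean
    t = torch.tensor(x, dtype=torch.float32, device="cuda")
    dist.all_reduce(t, op=dist.ReduceOp.SUM, group=mesh.get_group())
    return t.item() / mesh.size()
```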
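Similarly, the --log_freq help text indicates the metrics go to TensorBoard, so build_metric_logger presumably returns a thin wrapper around torch.utils.tensorboard.SummaryWriter exposing the log(metrics, step=...) and close() methods used in the training loop. The sketch below is an assumption about that interface; the class name, constructor argument, and log directory are made up, and the real helper lives in torchtrain/metrics.py.

```python
# Hypothetical sketch of the metric-logger interface used in the diff; only
# the log(metrics, step=...) and close() calls are taken from the patch.
from typing import Dict

from torch.utils.tensorboard import SummaryWriter


class MetricLogger:
    def __init__(self, log_dir: str = "./outputs/tb"):
        # one SummaryWriter per process; callers may restrict logging to rank 0
        self.writer = SummaryWriter(log_dir=log_dir)

    def log(self, metrics: Dict[str, float], step: int) -> None:
        # write each metric as a separate scalar at the same global step
        for name, value in metrics.items():
            self.writer.add_scalar(name, value, step)

    def close(self) -> None:
        self.writer.close()


def build_metric_logger() -> MetricLogger:
    return MetricLogger()
```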