Commit 80207e6

add TensorBoard logging with loss and wps
ghstack-source-id: d0828f16c06747a5af2586630e5205bf786de1c4
Pull Request resolved: #57
1 parent 40c93e9 commit 80207e6

File tree

README.md
requirements.txt
torchtrain/metrics.py
torchtrain/parallelisms/__init__.py
torchtrain/train_configs/train_config.toml
torchtrain/utils.py
train.py

7 files changed: +135 -2 lines changed


README.md (+18)

@@ -22,3 +22,21 @@ run the llama debug model locally to verify the setup is correct:
 ```
 ./run_llama_train.sh
 ```
+
+# TensorBoard
+
+To visualize training metrics on TensorBoard:
+
+1. (by default) set `enable_tensorboard = true` in `torchtrain/train_configs/train_config.toml`
+
+2. set up SSH tunneling
+```
+ssh -L 6006:127.0.0.1:6006 [username]@[hostname]
+```
+
+3. then in the torchtrain repo
+```
+tensorboard --logdir=./torchtrain/outputs/tb
+```
+
+4. go to the URL it provides OR to http://localhost:6006/
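
If the TensorBoard UI is inconvenient (e.g. on a headless node), the same event files can be inspected programmatically. The sketch below is not part of this commit; it assumes the layout produced by `build_metric_logger` in this diff (`<dump_folder>/tb/<YYYYmmdd-HHMM>/rank_N`) and that `dump_folder` resolves to `./torchtrain/outputs`, matching the `--logdir` used above.

```python
# Not part of this commit: read back the event files MetricLogger writes,
# assuming the <dump_folder>/tb/<date>/rank_0 layout from build_metric_logger.
import glob
import os

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

log_root = "./torchtrain/outputs/tb"  # assumption: dump_folder is ./torchtrain/outputs
run_dir = sorted(glob.glob(os.path.join(log_root, "*", "rank_0")))[-1]  # latest run

ea = EventAccumulator(run_dir)
ea.Reload()  # parse the event files on disk

for event in ea.Scalars("wps"):  # tag names match the metrics dict in train.py
    print(f"step={event.step} wps={event.value:.1f}")
```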

requirements.txt (+1)

@@ -2,3 +2,4 @@ torch >= 2.2.0.dev
 sentencepiece
 datasets
 tomli >= 1.1.0 ; python_version < "3.11"
+tensorboard

torchtrain/metrics.py (+44)

@@ -4,10 +4,17 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved

+import os
 from collections import namedtuple
+from datetime import datetime
+from typing import Any, Dict, Optional

 import torch
 import torch.nn as nn
+from torch.utils.tensorboard import SummaryWriter
+
+from torchtrain.logging_utils import rank0_log
+from torchtrain.profiling import get_config_from_toml

 _gb_in_bytes = 1024 * 1024 * 1024
 _mb_in_bytes = 1024 * 1024
@@ -187,3 +194,40 @@ def get_num_params(model: nn.Module, only_trainable: bool = False) -> int:
         param_list = [p for p in param_list if p.requires_grad]
     unique_params = {p.data_ptr(): p for p in param_list}.values()
     return sum(p.numel() for p in unique_params)
+
+
+class MetricLogger:
+    def __init__(self, log_dir, tag, enable_tb):
+        self.tag = tag
+        self.writer: Optional[SummaryWriter] = None
+        if enable_tb:
+            self.writer = SummaryWriter(log_dir, max_queue=1000)
+
+    def log(self, metrics: Dict[str, Any], step: int):
+        if self.writer is not None:
+            for k, v in metrics.items():
+                tag = k if self.tag is None else f"{self.tag}/{k}"
+                self.writer.add_scalar(tag, v, step)
+
+    def close(self):
+        if self.writer is not None:
+            self.writer.close()
+
+
+def build_metric_logger(tag: Optional[str] = None):
+    config = get_config_from_toml()
+
+    dump_dir = config["global"]["dump_folder"]
+    save_tb_folder = config["metrics"]["save_tb_folder"]
+    # since we don't have run id yet, use current minute as identifier
+    datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
+    log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)
+
+    enable_tb = config["metrics"].get("enable_tensorboard", False)
+    if enable_tb:
+        rank0_log(
+            f"Metrics logging active. Tensorboard logs will be saved at {log_dir}."
+        )
+
+    rank_str = f"rank_{torch.distributed.get_rank()}"
+    return MetricLogger(os.path.join(log_dir, rank_str), tag, enable_tb)
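
`MetricLogger` is a thin wrapper around `SummaryWriter`: with `enable_tb=False` it degrades to a no-op, and a non-None `tag` becomes a prefix on every scalar name. A minimal standalone sketch of that behavior (not from the commit; the directory and values are illustrative only):

```python
# Standalone sketch of MetricLogger usage; directory and values are made up.
from torchtrain.metrics import MetricLogger

logger = MetricLogger(log_dir="/tmp/tb_demo/rank_0", tag="train", enable_tb=True)

for step in range(1, 4):
    # with tag="train", these appear in TensorBoard as "train/loss" and "train/wps"
    logger.log({"loss": 4.2 / step, "wps": 1000.0 * step}, step=step)

logger.close()  # flushes and closes the underlying SummaryWriter

# With enable_tb=False the same calls are no-ops, so callers never need to branch.
noop_logger = MetricLogger(log_dir="/tmp/unused", tag=None, enable_tb=False)
noop_logger.log({"loss": 1.0}, step=1)
noop_logger.close()
```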

torchtrain/parallelisms/__init__.py (+5)

@@ -3,6 +3,7 @@

 import logging
 from dataclasses import dataclass
+from functools import cached_property

 from torch.distributed.device_mesh import init_device_mesh

@@ -61,3 +62,7 @@ def sp_enabled(self):
     @property
     def pp_enabled(self):
         return self.pp > 1
+
+    @cached_property
+    def model_parallel_size(self):
+        return self.sp * self.pp
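
`model_parallel_size` is simply the product of the `sp` and `pp` degrees, i.e. how many ranks cooperate on a single model replica, and `cached_property` computes it once per `ParallelDims` instance. A toy sketch of the same pattern (this is not the real class, which carries more fields and validation):

```python
# Toy illustration of the cached_property used by ParallelDims;
# the real class lives in torchtrain/parallelisms/__init__.py.
from dataclasses import dataclass
from functools import cached_property


@dataclass
class ToyParallelDims:  # hypothetical stand-in for ParallelDims
    dp: int
    sp: int
    pp: int

    @cached_property
    def model_parallel_size(self) -> int:
        # ranks that cooperate on one model replica = sp * pp
        return self.sp * self.pp


dims = ToyParallelDims(dp=2, sp=4, pp=2)
print(dims.model_parallel_size)  # 8, computed once and then cached on the instance
```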

torchtrain/train_configs/train_config.toml (+4)

@@ -7,3 +7,7 @@ run_profiler = true
 save_traces_folder = "profiling/traces"
 # profiling frequency - example: 10 means every 10th iter will be profiled
 profile_every_x_iter = 10
+
+[metrics]
+enable_tensorboard = true
+save_tb_folder = "tb"
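
`build_metric_logger` consumes this `[metrics]` table via `get_config_from_toml`. A rough sketch of what that lookup amounts to, assuming the file is plain TOML (the real code goes through the project's helper rather than parsing directly):

```python
# Rough sketch of reading the [metrics] table above; the real code uses
# torchtrain.profiling.get_config_from_toml instead of parsing here.
import sys

if sys.version_info >= (3, 11):
    import tomllib
else:
    import tomli as tomllib  # matches the tomli requirement in requirements.txt

with open("torchtrain/train_configs/train_config.toml", "rb") as f:
    config = tomllib.load(f)

enable_tb = config["metrics"].get("enable_tensorboard", False)  # true in this commit
save_tb_folder = config["metrics"]["save_tb_folder"]            # "tb"
print(enable_tb, save_tb_folder)
```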

torchtrain/utils.py (+19)

@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from typing import Union
+
+import torch
+import torch.distributed._functional_collectives as funcol
+import torch.distributed.distributed_c10d as c10d
+from torch.distributed.device_mesh import DeviceMesh
+
+
+def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
+    tensor = torch.tensor(x).cuda()
+    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh)
+
+
+def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float:
+    tensor = torch.tensor(x).cuda()
+    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh)
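
Both helpers wrap a Python scalar in a CUDA tensor and run a functional all-reduce over the given device mesh, so every rank ends up with the same global value. A hedged usage sketch, not from the commit: it assumes a `torchrun` launch with one GPU per process, and the one-dimensional mesh below stands in for the `world_mesh` that train.py builds.

```python
# Sketch only: run under `torchrun --nproc_per_node=<N>` with CUDA available.
# The 1-D mesh is a stand-in for the world_mesh constructed in train.py.
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from torchtrain.utils import dist_max, dist_mean

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

mesh = init_device_mesh("cuda", (dist.get_world_size(),))

local_loss = 1.0 + dist.get_rank()   # pretend per-rank average loss
print(dist_mean(local_loss, mesh))   # same mean on every rank
print(dist_max(local_loss, mesh))    # same max on every rank

dist.destroy_process_group()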

train.py (+44 -2)

@@ -4,8 +4,11 @@
 import argparse
 import os
 from dataclasses import dataclass, field
+from timeit import default_timer as timer
 from typing import Any, Dict, List, Union

+import numpy as np
+
 # torch imports
 import torch
 import torch.nn.functional as F
@@ -18,12 +21,13 @@
 from torchtrain.datasets import create_tokenizer, dataloader_fn
 from torchtrain.logging_utils import init_logger, rank0_log
 from torchtrain.lr_scheduling import get_lr_scheduler
-from torchtrain.metrics import get_num_params, GPUMemoryMonitor
+from torchtrain.metrics import build_metric_logger, get_num_params, GPUMemoryMonitor

 from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtrain.parallelisms import models_parallelize_fns, ParallelDims

 from torchtrain.profiling import maybe_run_profiler
+from torchtrain.utils import dist_max, dist_mean


 @dataclass
@@ -126,7 +130,7 @@ def main(args):

     scaler = build_grad_scaler(model)

-    # TODO: add metrics
+    metric_logger = build_metric_logger()

     # torch.compile model for improved performance
     if args.compile:
@@ -156,13 +160,18 @@

     with maybe_run_profiler() as torch_profiler:
         checkpoint.reset()
+        # variables used to keep info for metrics logging
+        losses_since_last_log: List[float] = []
+        nwords_since_last_log = 0
+        time_last_log = timer()
         while train_state.step < args.steps or args.steps == -1:
             train_state.step += 1
             # get batch
             batch = next(iter(data_loader))
             input_ids, labels = batch
             input_ids = input_ids.cuda()
             labels = labels.cuda()
+            nwords_since_last_log += labels.numel()

             optimizer.zero_grad()

@@ -194,6 +203,32 @@

             train_state.current_loss = loss.item()
             train_state.losses.append(train_state.current_loss)
+            losses_since_last_log.append(train_state.current_loss)
+
+            # log metrics
+            if (train_state.step - 1) % args.log_freq == 0:
+                avg_loss, max_loss = np.mean(losses_since_last_log), np.max(
+                    losses_since_last_log
+                )
+                global_avg_loss, global_max_loss = dist_mean(
+                    avg_loss, world_mesh
+                ), dist_max(max_loss, world_mesh)
+
+                time_delta = timer() - time_last_log
+                wps = nwords_since_last_log / (
+                    time_delta * parallel_dims.model_parallel_size
+                )
+
+                metrics = {
+                    "global_avg_loss": global_avg_loss,
+                    "global_max_loss": global_max_loss,
+                    "wps": wps,
+                }
+                metric_logger.log(metrics, step=train_state.step)
+
+                losses_since_last_log.clear()
+                nwords_since_last_log = 0
+                time_last_log = timer()

             rank0_log(
                 f"step: {train_state.step}, current loss: {train_state.current_loss}, lr: {scheduler.get_last_lr()}"
@@ -202,6 +237,7 @@

             checkpoint.save(train_state.step, force=(train_state.step == args.steps))

+    metric_logger.close()
     rank0_log(f"{gpu_metrics.get_current_stats()}")


@@ -294,6 +330,12 @@ def main(args):
                "is an empty string, checkpointing is disabled."
            ),
        )
+    parser.add_argument(
+        "--log_freq",
+        type=int,
+        default=10,
+        help="how often to log metrics to TensorBoard",
+    )

     args = parser.parse_args()
     main(args)
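
The logged `wps` is this rank's token count since the last log, divided by the elapsed wall time and by `model_parallel_size` (sp * pp), so tokens that a model-parallel group processes jointly are not credited to each of its ranks in full. Logging fires on step 1 and then every `--log_freq` steps, since the condition is `(step - 1) % log_freq == 0`. A small worked example of the formula, with made-up numbers rather than results from a real run:

```python
# Worked example of the wps computation from the training loop above; numbers are illustrative.
log_freq = 10            # --log_freq: steps between metric logs
batch_tokens = 8 * 2048  # labels.numel() per step: e.g. batch size 8, sequence length 2048
time_delta = 4.0         # seconds since the last log (timer() - time_last_log)
model_parallel_size = 4  # sp * pp from ParallelDims

nwords_since_last_log = log_freq * batch_tokens           # 163_840 tokens on this rank
wps = nwords_since_last_log / (time_delta * model_parallel_size)
print(f"wps = {wps:.0f}")  # 10240 words/sec
```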
