
Commit

improve logging
ghstack-source-id: de61ec093b43a2ccbf1156c76ba81ecd698a6a8a
Pull Request resolved: #132
tianyu-l committed Mar 13, 2024
1 parent 2722865 commit 2369861
Showing 15 changed files with 99 additions and 124 deletions.
16 changes: 9 additions & 7 deletions torchtrain/checkpoint.py
@@ -17,7 +17,7 @@
set_model_state_dict,
set_optimizer_state_dict,
)
from torchtrain.logging_utils import rank0_log
from torchtrain.logging_utils import logger


class IntervalType(enum.Enum):
@@ -109,13 +109,13 @@ def save(self, curr_step: int, force: bool = False) -> None:
self.work = None
self.doit = None

rank0_log(f"Saving a checkpoint in step {curr_step}.")
logger.info(f"Saving a checkpoint at step {curr_step}")
begin = time.monotonic()
dcp.save(self.states, checkpoint_id=self.create_checkpoint_id(curr_step))
self.reset()
rank0_log(
f"Finish saving the checkpoint in step {curr_step}. "
f"{time.monotonic() - begin} seconds"
logger.info(
f"Finished saving the checkpoint at step {curr_step} "
f"in {time.monotonic() - begin} seconds"
)

def load(self, step: int = -1) -> bool:
@@ -136,11 +136,13 @@ def load(self, step: int = -1) -> bool:
return False
step = max(step_counts)

rank0_log("Loading a checkpoint.")
logger.info("Loading a checkpoint")
begin = time.monotonic()
dcp.load(
self.states,
checkpoint_id=self.create_checkpoint_id(step),
)
rank0_log(f"Finish loading a checkpoint. {time.monotonic() - begin} seconds.")
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin} seconds"
)
return True
27 changes: 9 additions & 18 deletions torchtrain/datasets/hf_datasets.py
@@ -7,8 +7,7 @@
from torch.utils.data import DataLoader, IterableDataset

from torchtrain.datasets.tokenizer import TokenizerIf
from torchtrain.logging_utils import rank0_log
from torchtrain.utils import Color
from torchtrain.logging_utils import logger

from datasets import load_dataset, load_from_disk
from datasets.distributed import split_dataset_by_node
@@ -97,21 +96,17 @@ def __init__(
) -> None:
if dataset_name not in _supported_datasets:
raise ValueError(
f"Dataset {dataset_name} is not supported. Supported datasets are: {_supported_datasets.keys()}"
f"Dataset {dataset_name} is not supported. Supported datasets are: {_supported_datasets.keys()}."
)

# TODO: This is a temporary solution for small datasets like Alpaca.
# For larger datasets we need to use a more scalable approach.
if dataset_path:
rank0_log(
f"{Color.green}Loading '{dataset_name}' dataset locally from {dataset_path}...{Color.reset}"
)
logger.info(f"Loading {dataset_name} dataset locally from {dataset_path}")
ds = load_from_disk(dataset_path)
else:
rank0_log(
f"{Color.green}Preparing '{dataset_name}' dataset from HuggingFace...{Color.reset}"
)
# Setting `streaming=True` works for large dataset, but the speed is slow.
logger.info(f"Preparing {dataset_name} dataset from HuggingFace")
# Setting `streaming=True` works for large dataset, but is slightly slower and unstable.
# c4 is huge, and requires both streaming and language selection (we default to en)
if dataset_name == "c4":
ds = load_dataset(
Expand Down Expand Up @@ -147,16 +142,12 @@ def __iter__(self):
label = x[1:]
yield input, label
if not self.infinite:
rank0_log(
f"{Color.red}WARNING:{Color.reset} dataset {Color.yellow}'{self.dataset_name}'{Color.reset} has "
f"run out of data.{Color.reset}"
)
logger.warning(f"Dataset {self.dataset_name} has run out of data.")
break
else:
# we are re-looping on the same dataset, warn user
rank0_log(
f"{Color.red}WARNING:{Color.reset} dataset {Color.yellow}'{self.dataset_name}'{Color.reset} is "
f"being re-looped. Loss related metrics might be misleading.{Color.reset}"
logger.warning(
f"Dataset {self.dataset_name} is being re-looped. "
"Loss related metrics might be misleading."
)


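For context, the `__iter__` loop above yields next-token training pairs by chunking a rolling token buffer. A minimal, self-contained sketch of that pattern, assuming a plain token list in place of the streamed HuggingFace samples (function and variable names here are illustrative, not from the repo):

```python
# Sketch of the (input, label) chunking from __iter__; only the x[:-1]/x[1:]
# split and the infinite/re-loop behavior come from the diff above.
from typing import Iterator, List, Tuple


def iter_chunks(
    tokens: List[int], seq_len: int, infinite: bool = False
) -> Iterator[Tuple[List[int], List[int]]]:
    buffer: List[int] = []
    while True:
        for tok in tokens:
            buffer.append(tok)
            if len(buffer) >= seq_len + 1:
                x = buffer[: seq_len + 1]
                buffer = buffer[seq_len + 1 :]
                # shift by one: the label at each position is the next token
                yield x[:-1], x[1:]
        if not infinite:
            # the real dataset logs a "has run out of data" warning here
            break
        # otherwise it re-loops and warns that loss metrics may be misleading
```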
22 changes: 10 additions & 12 deletions torchtrain/datasets/tokenizer.py
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -8,13 +8,11 @@

import os
from abc import ABC, abstractmethod
from logging import getLogger
from typing import List

from sentencepiece import SentencePieceProcessor


logger = getLogger()
from torchtrain.logging_utils import logger


class TokenizerIf(ABC):
@@ -40,34 +38,34 @@ def n_words(self) -> int:


def create_tokenizer(tokenizer_type: str, tokenizer_path: str) -> TokenizerIf:
logger.info(f"Building {tokenizer_type} tokenizer locally from {tokenizer_path}")
if tokenizer_type == "sentencepiece":
return SentencePieceTokenizer(tokenizer_path)
else:
raise ValueError(f"Unknown tokenizer type: {args.type}")


class SentencePieceTokenizer(TokenizerIf):
"""tokenizing and encoding/decoding text using SentencePiece."""
"""
Tokenizing and encoding/decoding text based on a SentencePiece model.
Args:
tokenizer_path (str): The path to the SentencePiece model file.
"""

def __init__(self, tokenizer_path: str):
"""
Initializes the Tokenizer with a SentencePiece model.

Args:
tokenizer_path (str): The path to the SentencePiece model file.
"""
super().__init__(tokenizer_path)
# reload tokenizer
self.sp_model = SentencePieceProcessor(model_file=tokenizer_path)
logger.info(f"Reloaded SentencePiece model from {tokenizer_path}")

# BOS / EOS token IDs
self._n_words: int = self.sp_model.vocab_size()
self.bos_id: int = self.sp_model.bos_id()
self.eos_id: int = self.sp_model.eos_id()
self.pad_id: int = self.sp_model.pad_id()
logger.info(
f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
f"SentencePieceTokenizer built: #words {self.n_words}, BOS ID {self.bos_id}, EOS ID {self.eos_id}"
)
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

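A short, hedged usage sketch of the tokenizer factory above (the model path is a placeholder; only `create_tokenizer`, the "sentencepiece" type, and `n_words` appear in the diff):

```python
# Hedged usage sketch; the tokenizer.model path below is a placeholder.
from torchtrain.datasets.tokenizer import create_tokenizer

tokenizer = create_tokenizer("sentencepiece", "path/to/tokenizer.model")
print(tokenizer.n_words)  # vocab size, also reported by the log line above
```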
5 changes: 2 additions & 3 deletions torchtrain/float8_linear.py
@@ -5,9 +5,8 @@
# All rights reserved

from torchtrain.config_manager import JobConfig
from torchtrain.logging_utils import rank0_log
from torchtrain.logging_utils import logger
from torchtrain.models.llama import Transformer
from torchtrain.utils import Color


def build_fp8_linear(model: Transformer, job_config: JobConfig):
@@ -42,4 +41,4 @@ def build_fp8_linear(model: Transformer, job_config: JobConfig):

# Mutates the model inplace replacing instances of torch.nn.Linear with float8_linear_type
swap_linear_with_float8_linear(model, float8_linear_type)
rank0_log(f"{Color.green}Using {linear_type} float8 linear layers{Color.reset}")
logger.info(f"Swapped to {linear_type} float8 linear layers")
9 changes: 3 additions & 6 deletions torchtrain/logging_utils.py
@@ -1,15 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import logging

import torch

logger = logging.getLogger()


def rank0_log(msg):
if torch.distributed.get_rank() == 0:
logger.info(msg)


def init_logger():
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
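After this change, logging_utils.py exposes a single shared `logger` plus `init_logger()`, and callers import the logger directly. A minimal sketch of the resulting module, assuming a formatter and handler attachment that the diff does not show:

```python
# Sketch of the shared logging module after this commit. getLogger() and the
# first two lines of init_logger() are in the diff; the formatter string and
# addHandler call below are assumptions added for completeness.
import logging

logger = logging.getLogger()


def init_logger():
    logger.setLevel(logging.INFO)
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    formatter = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)
```

Note that, unlike the removed `rank0_log` helper, plain `logger.info(...)` calls are not guarded by `torch.distributed.get_rank() == 0`, so any rank filtering would have to happen at handler or init time.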
17 changes: 8 additions & 9 deletions torchtrain/metrics.py
@@ -13,8 +13,7 @@
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torchtrain.config_manager import JobConfig

from torchtrain.logging_utils import rank0_log
from torchtrain.logging_utils import logger

# note that GiB (gibibyte) is 1024, vs GB is 1000
_gib_in_bytes = 1024 * 1024 * 1024
@@ -122,18 +121,18 @@ def get_current_stats(self, return_data: bool = False):
self.device_active_memory_usage, self.device_capacity, precision=2
)

display_str = ""
display_str += f"Current Memory: {self.device_name} ({self.device_index}): Reserved: {self.device_reserved_memory_pct}%, "
display_str += f"Alloc {self.device_alloc_memory_pct}%, Active: {self.device_active_memory_pct}%\n"
display_str = f"{self.device_name} ({self.device_index}). "
display_str += f"Current memory: reserved {self.device_reserved_memory_pct}%, "
display_str += f"alloc {self.device_alloc_memory_pct}%, active {self.device_active_memory_pct}%. "

self.get_peak_stats(curr_mem)

peak_active_pct = self.get_pct_memory(self.peak_active_memory)
peak_allocated_pct = self.get_pct_memory(self.peak_allocated_memory)
peak_reserved_pct = self.get_pct_memory(self.peak_reserved_memory)
display_str += f"Peak Memory: Reserved {peak_reserved_pct}%, Alloc {peak_allocated_pct}%, Active: {peak_active_pct}%\n"
display_str += f"Peak memory: reserved {peak_reserved_pct}%, alloc {peak_allocated_pct}%, active {peak_active_pct}%. "

display_str += f"num retries: {self.num_retries}, num ooms: {self.num_ooms}"
display_str += f"Num retries: {self.num_retries}. Num ooms: {self.num_ooms}."
if self.num_retries > 0:
display_str += f"\nWARNING: {self.num_retries} retries -- recommend lowering batch size for max performance\n"

@@ -224,8 +223,8 @@ def build_metric_logger(config: JobConfig, tag: Optional[str] = None):

enable_tb = config.metrics.enable_tensorboard
if enable_tb:
rank0_log(
f"Metrics logging active. Tensorboard logs will be saved at {log_dir}."
logger.info(
f"Metrics logging active. Tensorboard logs will be saved at {log_dir}"
)

rank_str = f"rank_{torch.distributed.get_rank()}"
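The display string above relies on GiB and percentage conversions. A hedged sketch of those helpers (the name `format_to_gib` is hypothetical; `get_pct_memory` appears in the diff but its body does not):

```python
# Hedged sketch of the conversion helpers implied by get_current_stats above.
_gib_in_bytes = 1024 * 1024 * 1024  # GiB, as noted at the top of metrics.py


def format_to_gib(memory_in_bytes: float, precision: int = 2) -> float:
    return round(memory_in_bytes / _gib_in_bytes, precision)


def get_pct_memory(
    memory_in_bytes: float, capacity_in_bytes: float, precision: int = 2
) -> float:
    return round(100 * memory_in_bytes / capacity_in_bytes, precision)
```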
5 changes: 0 additions & 5 deletions torchtrain/models/llama/model.py
@@ -8,8 +8,6 @@
import torch.nn.functional as F
from torch import nn

from torchtrain.logging_utils import rank0_log


@dataclass
class ModelArgs:
@@ -476,8 +474,6 @@ def __init__(self, model_args: ModelArgs):

# self.reset_parameters()

rank0_log(f"Model built with: {self.model_args}")

def reset_parameters(
self,
):
@@ -493,7 +489,6 @@ def reset_parameters(
a=-cutoff_factor * final_out_std,
b=cutoff_factor * final_out_std,
)
rank0_log("Model fully initialized via reset_params")

def forward(self, tokens: torch.Tensor):
"""
8 changes: 2 additions & 6 deletions torchtrain/parallelisms/__init__.py
@@ -1,17 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import logging
from dataclasses import dataclass
from functools import cached_property

from torch.distributed.device_mesh import init_device_mesh

from torchtrain.logging_utils import logger
from torchtrain.parallelisms.parallelize_llama import parallelize_llama

logger = logging.getLogger(__name__)


models_parallelize_fns = {
"llama": parallelize_llama,
}
@@ -48,8 +44,8 @@ def build_mesh(self, device_type):
if d > 1:
dims.append(d)
names.append(name)
names = tuple(names)
logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
names = tuple(names)
return init_device_mesh(device_type, dims, mesh_dim_names=names)

@property
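The hunk above shows only the tail of `build_mesh`. A hedged sketch of the full dimension-filtering pattern (the `ParallelDims` field names and their ordering are assumptions; only the `d > 1` filter, the log message, and the `init_device_mesh` call come from the diff):

```python
# Hedged sketch of build_mesh; field names (pp, dp, sp) and their order are
# guesses, the size-1 filtering and init_device_mesh call follow the diff.
from dataclasses import dataclass

from torch.distributed.device_mesh import init_device_mesh

from torchtrain.logging_utils import logger


@dataclass
class ParallelDimsSketch:
    pp: int
    dp: int
    sp: int

    def build_mesh(self, device_type: str):
        dims, names = [], []
        for d, name in ((self.pp, "pp"), (self.dp, "dp"), (self.sp, "sp")):
            if d > 1:  # drop degenerate mesh dimensions of size 1
                dims.append(d)
                names.append(name)
        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
        names = tuple(names)
        return init_device_mesh(device_type, dims, mesh_dim_names=names)
```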
12 changes: 5 additions & 7 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -4,7 +4,6 @@
# this file applies the PTD parallelisms and various training techniques to the
# llama model, i.e. activation checkpoint, etc.

import logging
from collections import defaultdict

import torch
@@ -15,7 +14,6 @@
Replicate,
Shard,
)

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
CheckpointImpl,
@@ -34,11 +32,9 @@
RowwiseParallel,
)
from torchtrain.config_manager import JobConfig
from torchtrain.logging_utils import rank0_log
from torchtrain.logging_utils import logger
from torchtrain.meta_init import meta_to_real_init_fn

logger = logging.getLogger(__name__)


def distribute_rmsnorm(module, device_mesh):
# temp sharding API until PTD API is added
@@ -195,7 +191,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
parallelize_plan=layer_plan,
)

rank0_log("Applied Sequence Parallelism to the model...")
logger.info("Applied Sequence Parallelism to the model")

if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"]
@@ -228,12 +224,14 @@
# wrap the rest layers with FSDP
model = wrap(model)

rank0_log("Applied FSDP to the model...")
logger.info("Applied FSDP to the model")
else:
meta_to_real_init_fn(model)
model.cuda()

# we have now moved from meta to device,
# reset parameters for proper initialization
model.reset_parameters()
logger.info("Model fully initialized via reset_parameters")

return model
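The `else` branch above materializes the model from the meta device and re-initializes it. A hedged sketch of that general pattern, using plain PyTorch `to_empty` as a stand-in for `meta_to_real_init_fn`, whose body is not shown in this diff:

```python
# Hedged sketch of meta-device materialization; to_empty() is a stand-in for
# meta_to_real_init_fn, and nn.Linear is a placeholder for the Transformer.
import torch
import torch.nn as nn

with torch.device("meta"):
    model = nn.Linear(4096, 4096)  # placeholder model built on the meta device

model = model.to_empty(device="cuda")  # allocate real, uninitialized storage
model.reset_parameters()               # proper initialization, as logged above
```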
6 changes: 3 additions & 3 deletions torchtrain/profiling.py
@@ -6,7 +6,7 @@

import torch
from torchtrain.config_manager import JobConfig
from torchtrain.logging_utils import rank0_log
from torchtrain.logging_utils import logger


@contextlib.contextmanager
@@ -31,11 +31,11 @@ def trace_handler(prof):
curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
if not os.path.exists(curr_trace_dir):
os.makedirs(curr_trace_dir, exist_ok=True)
rank0_log(f"exporting profile traces to {curr_trace_dir}")
logger.info(f"Exporting profile traces to {curr_trace_dir}")

prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")

rank0_log(f"Profiling active. Traces will be saved at {trace_dir}")
logger.info(f"Profiling active. Traces will be saved at {trace_dir}")

if not os.path.exists(trace_dir):
os.makedirs(trace_dir, exist_ok=True)
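A hedged sketch of how `trace_handler` above might be wired into `torch.profiler` (the activities and the absence of an explicit schedule are assumptions; the handler body, directory handling, and per-rank trace file name mirror the diff):

```python
# Hedged sketch; only the trace_handler wiring and directory handling follow
# the diff, the profiler activities below are assumptions.
import contextlib
import os

import torch


@contextlib.contextmanager
def maybe_run_profiler_sketch(trace_dir: str, enabled: bool, rank: int = 0):
    if not enabled:
        yield None
        return

    def trace_handler(prof):
        curr_trace_dir = os.path.join(trace_dir, f"iteration_{prof.step_num}")
        os.makedirs(curr_trace_dir, exist_ok=True)
        prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")

    os.makedirs(trace_dir, exist_ok=True)
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        on_trace_ready=trace_handler,
    ) as prof:
        yield prof
```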