
improve logging #132

Merged
2 commits merged on Mar 13, 2024
16 changes: 9 additions & 7 deletions torchtrain/checkpoint.py
@@ -17,7 +17,7 @@
set_model_state_dict,
set_optimizer_state_dict,
)
from torchtrain.logging_utils import rank0_log
from torchtrain.logging_utils import logger


class IntervalType(enum.Enum):
@@ -109,13 +109,13 @@ def save(self, curr_step: int, force: bool = False) -> None:
self.work = None
self.doit = None

rank0_log(f"Saving a checkpoint in step {curr_step}.")
logger.info(f"Saving a checkpoint at step {curr_step}")
begin = time.monotonic()
dcp.save(self.states, checkpoint_id=self.create_checkpoint_id(curr_step))
self.reset()
rank0_log(
f"Finish saving the checkpoint in step {curr_step}. "
f"{time.monotonic() - begin} seconds"
logger.info(
f"Finished saving the checkpoint at step {curr_step} "
f"in {time.monotonic() - begin} seconds"
)

def load(self, step: int = -1) -> bool:
@@ -136,11 +136,13 @@ def load(self, step: int = -1) -> bool:
return False
step = max(step_counts)

rank0_log("Loading a checkpoint.")
logger.info("Loading a checkpoint")
begin = time.monotonic()
dcp.load(
self.states,
checkpoint_id=self.create_checkpoint_id(step),
)
rank0_log(f"Finish loading a checkpoint. {time.monotonic() - begin} seconds.")
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin} seconds"
)
return True
27 changes: 9 additions & 18 deletions torchtrain/datasets/hf_datasets.py
@@ -7,8 +7,7 @@
from torch.utils.data import DataLoader, IterableDataset

from torchtrain.datasets.tokenizer import TokenizerIf
from torchtrain.logging_utils import rank0_log
from torchtrain.utils import Color
from torchtrain.logging_utils import logger

from datasets import load_dataset, load_from_disk
from datasets.distributed import split_dataset_by_node
@@ -86,21 +85,17 @@ def __init__(
) -> None:
if dataset_name not in _supported_datasets:
raise ValueError(
f"Dataset {dataset_name} is not supported. Supported datasets are: {_supported_datasets.keys()}"
f"Dataset {dataset_name} is not supported. Supported datasets are: {_supported_datasets.keys()}."
)

# TODO: This is a temporary solution for small datasets like Alpaca.
# For larger datasets we need to use a more scalable approach.
if dataset_path:
rank0_log(
f"{Color.green}Loading '{dataset_name}' dataset locally from {dataset_path}...{Color.reset}"
)
logger.info(f"Loading {dataset_name} dataset locally from {dataset_path}")
ds = load_from_disk(dataset_path)
else:
rank0_log(
f"{Color.green}Preparing '{dataset_name}' dataset from HuggingFace...{Color.reset}"
)
# Setting `streaming=True` works for large dataset, but the speed is slow.
logger.info(f"Preparing {dataset_name} dataset from HuggingFace")
# Setting `streaming=True` works for large dataset, but is slightly slower and unstable.
# c4 is huge, and requires both streaming and language selection (we default to en)
if dataset_name == "c4":
ds = load_dataset(
@@ -136,16 +131,12 @@ def __iter__(self):
label = x[1:]
yield input, label
if not self.infinite:
rank0_log(
f"{Color.red}WARNING:{Color.reset} dataset {Color.yellow}'{self.dataset_name}'{Color.reset} has "
f"run out of data.{Color.reset}"
)
logger.warning(f"Dataset {self.dataset_name} has run out of data.")
break
else:
# we are re-looping on the same dataset, warn user
rank0_log(
f"{Color.red}WARNING:{Color.reset} dataset {Color.yellow}'{self.dataset_name}'{Color.reset} is "
f"being re-looped. Loss related metrics might be misleading.{Color.reset}"
logger.warning(
f"Dataset {self.dataset_name} is being re-looped. "
"Loss related metrics might be misleading."
)


22 changes: 10 additions & 12 deletions torchtrain/datasets/tokenizer.py
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -8,13 +8,11 @@

import os
from abc import ABC, abstractmethod
from logging import getLogger
from typing import List

from sentencepiece import SentencePieceProcessor


logger = getLogger()
from torchtrain.logging_utils import logger


class TokenizerIf(ABC):
@@ -40,34 +38,34 @@ def n_words(self) -> int:


def create_tokenizer(tokenizer_type: str, tokenizer_path: str) -> TokenizerIf:
logger.info(f"Building {tokenizer_type} tokenizer locally from {tokenizer_path}")
if tokenizer_type == "sentencepiece":
return SentencePieceTokenizer(tokenizer_path)
else:
raise ValueError(f"Unknown tokenizer type: {args.type}")


class SentencePieceTokenizer(TokenizerIf):
"""tokenizing and encoding/decoding text using SentencePiece."""
"""
Tokenizing and encoding/decoding text based on a SentencePiece model.

Args:
tokenizer_path (str): The path to the SentencePiece model file.
"""

def __init__(self, tokenizer_path: str):
"""
Initializes the Tokenizer with a SentencePiece model.

Args:
tokenizer_path (str): The path to the SentencePiece model file.
"""
super().__init__(tokenizer_path)
# reload tokenizer
self.sp_model = SentencePieceProcessor(model_file=tokenizer_path)
logger.info(f"Reloaded SentencePiece model from {tokenizer_path}")

# BOS / EOS token IDs
self._n_words: int = self.sp_model.vocab_size()
self.bos_id: int = self.sp_model.bos_id()
self.eos_id: int = self.sp_model.eos_id()
self.pad_id: int = self.sp_model.pad_id()
logger.info(
f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
f"SentencePieceTokenizer built: #words {self.n_words}, BOS ID {self.bos_id}, EOS ID {self.eos_id}"
)
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

5 changes: 2 additions & 3 deletions torchtrain/float8_linear.py
@@ -5,9 +5,8 @@
# All rights reserved

from torchtrain.config_manager import JobConfig
from torchtrain.logging_utils import rank0_log
from torchtrain.logging_utils import logger
from torchtrain.models.llama import Transformer
from torchtrain.utils import Color


def build_fp8_linear(model: Transformer, job_config: JobConfig):
@@ -42,4 +41,4 @@ def build_fp8_linear(model: Transformer, job_config: JobConfig):

# Mutates the model inplace replacing instances of torch.nn.Linear with float8_linear_type
swap_linear_with_float8_linear(model, float8_linear_type)
rank0_log(f"{Color.green}Using {linear_type} float8 linear layers{Color.reset}")
logger.info(f"Swapped to {linear_type} float8 linear layers")
9 changes: 3 additions & 6 deletions torchtrain/logging_utils.py
@@ -1,15 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import logging

import torch

logger = logging.getLogger()


def rank0_log(msg):
if torch.distributed.get_rank() == 0:
logger.info(msg)


def init_logger():
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
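To summarize the pattern this diff converges on, here is a minimal sketch of how call sites are expected to use the shared logger after the change. It is an illustration, not a file from this PR: `init_logger` and `logger` are the names defined in torchtrain/logging_utils.py above, and the log messages are placeholders.

```python
# Illustrative sketch only (not a file in this PR): call sites import the
# module-level logger instead of rank0_log.
from torchtrain.logging_utils import init_logger, logger

init_logger()  # configure level and handler once, typically at program startup
logger.info("Starting job")                    # replaces rank0_log("...")
logger.warning("Dataset has run out of data")  # warnings now carry an explicit level
```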
17 changes: 8 additions & 9 deletions torchtrain/metrics.py
@@ -13,8 +13,7 @@
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torchtrain.config_manager import JobConfig

from torchtrain.logging_utils import rank0_log
from torchtrain.logging_utils import logger

# note that GiB (gibibyte) is 1024, vs GB is 1000
_gib_in_bytes = 1024 * 1024 * 1024
@@ -122,18 +121,18 @@ def get_current_stats(self, return_data: bool = False):
self.device_active_memory_usage, self.device_capacity, precision=2
)

display_str = ""
display_str += f"Current Memory: {self.device_name} ({self.device_index}): Reserved: {self.device_reserved_memory_pct}%, "
display_str += f"Alloc {self.device_alloc_memory_pct}%, Active: {self.device_active_memory_pct}%\n"
display_str = f"{self.device_name} ({self.device_index}). "
display_str += f"Current memory: reserved {self.device_reserved_memory_pct}%, "
display_str += f"alloc {self.device_alloc_memory_pct}%, active {self.device_active_memory_pct}%. "

self.get_peak_stats(curr_mem)

peak_active_pct = self.get_pct_memory(self.peak_active_memory)
peak_allocated_pct = self.get_pct_memory(self.peak_allocated_memory)
peak_reserved_pct = self.get_pct_memory(self.peak_reserved_memory)
display_str += f"Peak Memory: Reserved {peak_reserved_pct}%, Alloc {peak_allocated_pct}%, Active: {peak_active_pct}%\n"
display_str += f"Peak memory: reserved {peak_reserved_pct}%, alloc {peak_allocated_pct}%, active {peak_active_pct}%. "

display_str += f"num retries: {self.num_retries}, num ooms: {self.num_ooms}"
display_str += f"Num retries: {self.num_retries}. Num ooms: {self.num_ooms}."
if self.num_retries > 0:
display_str += f"\nWARNING: {self.num_retries} retries -- recommend lowering batch size for max performance\n"

@@ -224,8 +223,8 @@ def build_metric_logger(config: JobConfig, tag: Optional[str] = None):

enable_tb = config.metrics.enable_tensorboard
if enable_tb:
rank0_log(
f"Metrics logging active. Tensorboard logs will be saved at {log_dir}."
logger.info(
f"Metrics logging active. Tensorboard logs will be saved at {log_dir}"
)

rank_str = f"rank_{torch.distributed.get_rank()}"
5 changes: 0 additions & 5 deletions torchtrain/models/llama/model.py
@@ -8,8 +8,6 @@
import torch.nn.functional as F
from torch import nn

from torchtrain.logging_utils import rank0_log


@dataclass
class ModelArgs:
@@ -476,8 +474,6 @@ def __init__(self, model_args: ModelArgs):

# self.reset_parameters()

rank0_log(f"Model built with: {self.model_args}")
Contributor:
While I think the output of this line could be improved, it's still important that the user sees the model details of what they are building, to double-check and avoid negative surprises. That is, I would not want to remove this line.

Contributor Author:
Definitely didn't want to remove this info. Actually it was moved to train.py L133-L135, because I think we don't really need to do printing in model.py. Thanks for the review! @lessw2020
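For illustration, a hypothetical sketch of what the relocated logging in train.py might look like; the train.py hunk is not part of this diff, and the config values below are placeholders.

```python
# Hypothetical sketch (train.py is not shown in this diff): the model details
# are logged from the training script instead of inside model.py.
from torchtrain.logging_utils import logger

model_config = {"dim": 4096, "n_layers": 32, "n_heads": 32}  # placeholder values
logger.info(f"Building llama model with {model_config}")
```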


def reset_parameters(
self,
):
@@ -493,7 +489,6 @@ def reset_parameters(
a=-cutoff_factor * final_out_std,
b=cutoff_factor * final_out_std,
)
rank0_log("Model fully initialized via reset_params")

def forward(self, tokens: torch.Tensor):
"""
8 changes: 2 additions & 6 deletions torchtrain/parallelisms/__init__.py
@@ -1,17 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import logging
from dataclasses import dataclass
from functools import cached_property

from torch.distributed.device_mesh import init_device_mesh

from torchtrain.logging_utils import logger
from torchtrain.parallelisms.parallelize_llama import parallelize_llama

logger = logging.getLogger(__name__)


models_parallelize_fns = {
"llama": parallelize_llama,
}
@@ -48,8 +44,8 @@ def build_mesh(self, device_type):
if d > 1:
dims.append(d)
names.append(name)
names = tuple(names)
logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
names = tuple(names)
return init_device_mesh(device_type, dims, mesh_dim_names=names)

@property
12 changes: 5 additions & 7 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -4,7 +4,6 @@
# this file applies the PTD parallelisms and various training techniques to the
# llama model, i.e. activation checkpoint, etc.

import logging
from collections import defaultdict

import torch
@@ -15,7 +14,6 @@
Replicate,
Shard,
)

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
CheckpointImpl,
@@ -34,11 +32,9 @@
RowwiseParallel,
)
from torchtrain.config_manager import JobConfig
from torchtrain.logging_utils import rank0_log
from torchtrain.logging_utils import logger
from torchtrain.meta_init import meta_to_real_init_fn

logger = logging.getLogger(__name__)


def distribute_rmsnorm(module, device_mesh):
# temp sharding API until PTD API is added
@@ -195,7 +191,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
parallelize_plan=layer_plan,
)

rank0_log("Applied Sequence Parallelism to the model...")
logger.info("Applied Sequence Parallelism to the model")

if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"]
@@ -228,12 +224,14 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
# wrap the rest layers with FSDP
model = wrap(model)

rank0_log("Applied FSDP to the model...")
logger.info("Applied FSDP to the model")
else:
meta_to_real_init_fn(model)
model.cuda()

# we have now moved from meta to device,
# reset parameters for proper initialization
model.reset_parameters()
logger.info("Model fully initialized via reset_parameters")

return model
6 changes: 3 additions & 3 deletions torchtrain/profiling.py
@@ -6,7 +6,7 @@

import torch
from torchtrain.config_manager import JobConfig
from torchtrain.logging_utils import rank0_log
from torchtrain.logging_utils import logger


@contextlib.contextmanager
@@ -31,11 +31,11 @@ def trace_handler(prof):
curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
if not os.path.exists(curr_trace_dir):
os.makedirs(curr_trace_dir, exist_ok=True)
rank0_log(f"exporting profile traces to {curr_trace_dir}")
logger.info(f"Exporting profile traces to {curr_trace_dir}")

prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")

rank0_log(f"Profiling active. Traces will be saved at {trace_dir}")
logger.info(f"Profiling active. Traces will be saved at {trace_dir}")

if not os.path.exists(trace_dir):
os.makedirs(trace_dir, exist_ok=True)