integrate chat tokenizer and add llama3-8B model option (pytorch#1110)
lessw2020 authored Sep 6, 2024
1 parent d58923e commit 6de408c
Showing 1 changed file with 52 additions and 10 deletions.
dist_run.py
@@ -4,30 +4,44 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict

# Run command:
# torchrun --nproc-per-node 4 dist_run.py
import torch
import torch.distributed as dist

from distributed.logging_utils import setup_logging

# TODO - these are not distributed specific, consider moving to new package
from distributed.safetensor_utils import (
    get_hf_config_file,
    get_hf_weight_map_and_path,
    load_safetensor_weights,
)
from distributed.utils import Color as color
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
from distributed.verification_utils import find_cpu_tensors
from torchchat.cli.builder import TokenizerArgs, _initialize_tokenizer
from torchchat.model import ModelArgs, Transformer
from torchchat.utils.build_utils import set_precision

try:
    from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer
except ImportError:
    TiktokenTokenizer = None
try:
    from sentencepiece import SentencePieceProcessor
except ImportError:
    SentencePieceProcessor = None


logger = setup_logging(__name__)

MODEL_NAME = "Transformer-2-7b-chat-hf"
NAME_TO_HF_MODEL_ID_AND_DTYPE = {
    "Transformer-2-7b-chat-hf": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
    "Meta-Llama-3-8B": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
}
CACHE_PRECISION = torch.bfloat16
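
A usage sketch, not part of the diff: selecting the newly added Llama 3 option only requires pointing the hardcoded MODEL_NAME constant above at the new table entry.

    MODEL_NAME = "Meta-Llama-3-8B"
    hf_model_name, model_dtype = NAME_TO_HF_MODEL_ID_AND_DTYPE[MODEL_NAME]
    # -> ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16)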

@@ -45,6 +59,33 @@ def _create_device_mesh(mesh_dimensions):
    return dist.init_device_mesh("cuda", mesh_dimensions, mesh_dim_names=("pp", "tp"))


def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:
    return SimpleNamespace(**dictionary)
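
This helper exists because TokenizerArgs.from_args reads attribute-style fields rather than dict keys; a one-line illustration (assumed, not in the diff):

    args = dict_to_args({"model": "llama3"})
    assert args.model == "llama3"  # attribute access, not args["model"]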


def _build_chat_tokenizer(
    model_base_name: str = "llama3",
) -> SentencePieceProcessor | TiktokenTokenizer:
    # Create base args for tokenizer
    default_model_dir = Path(
        os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
    ).expanduser()

    tokenconfig = {
        "model_directory": default_model_dir,
        "model": model_base_name,
        "tokenizer_path": None,
    }
    args = dict_to_args(tokenconfig)
    tokenizer_args = TokenizerArgs.from_args(args)
    tokenizer = _initialize_tokenizer(tokenizer_args)
    assert tokenizer is not None, f"Failed to get tokenizer using {tokenconfig=}"
    logger.info(
        f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
    )
    return tokenizer
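
A quick sketch of using the returned tokenizer (illustrative, not in the diff; both backends expose encode/decode, but the bos/eos keywords below follow the tiktoken-style interface and are an assumption — the SentencePiece backend takes different options):

    tokenizer = _build_chat_tokenizer("llama3")
    ids = tokenizer.encode("Hello, world", bos=True, eos=False)  # tiktoken-style kwargs (assumed)
    text = tokenizer.decode(ids)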


def _load_model_weights(stage_module, hf_model_name, device, model_config):
    """Load the weights from the safetensor file(s) into the model stage.
    Model config is needed because we permute wq and wk weights based on attention heads.
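
For context, the wq/wk permutation the docstring mentions is presumably the standard Hugging Face checkpoint conversion that regroups rows so rotary-embedding channel pairs line up; a common form of it (a sketch under that assumption — the hypothetical _permute_for_rope below is not the collapsed implementation):

    def _permute_for_rope(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
        # split rows into (heads, pair-index, 2), swap so paired channels interleave
        return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)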
@@ -77,8 +118,9 @@ def main():

    config = ModelArgs.from_name(MODEL_NAME).text_transformer_args
    logger.info(f"Chat Model Config: {config}")
    # TODO - should we make this work...atm returns float32
    # torchchat_precision = get_precision()

    tokenizer = _build_chat_tokenizer()
    logger.info(f"built tokenizer {tokenizer=}")

    hf_model_name, model_dtype = NAME_TO_HF_MODEL_ID_AND_DTYPE[MODEL_NAME]
    logger.info(f"Using HF model weights from {hf_model_name} and dtype {model_dtype}")
