From 677cbfacf8ef11f423ec1f5216083675615ab85d Mon Sep 17 00:00:00 2001
From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Date: Thu, 30 May 2024 13:48:46 +0800
Subject: [PATCH] [Fix/Example] Fix Llama Inference Loading Data Type (#5763)

* [fix/example] fix llama inference loading dtype

* revise loading dtype of benchmark llama3
---
 examples/inference/llama/benchmark_llama3.py | 12 +++++++++++-
 examples/inference/llama/llama_generation.py |  9 ++++++++-
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/examples/inference/llama/benchmark_llama3.py b/examples/inference/llama/benchmark_llama3.py
index 07ebdb2b1bfb..76d9c6a42000 100644
--- a/examples/inference/llama/benchmark_llama3.py
+++ b/examples/inference/llama/benchmark_llama3.py
@@ -17,6 +17,13 @@
 MEGABYTE = 1024**2
 N_WARMUP_STEPS = 2
 
+TORCH_DTYPE_MAP = {
+    "fp16": torch.float16,
+    "fp32": torch.float32,
+    "bf16": torch.bfloat16,
+}
+
+
 CONFIG_MAP = {
     "toy": transformers.LlamaConfig(num_hidden_layers=4),
     "llama-7b": transformers.LlamaConfig(
@@ -104,10 +111,13 @@ def print_details_info(model_config, whole_end2end, total_token_num, dtype, coor
 def benchmark_inference(args):
     coordinator = DistCoordinator()
 
+    torch_dtype = TORCH_DTYPE_MAP.get(args.dtype, None)
     config = CONFIG_MAP[args.model]
+    config.torch_dtype = torch_dtype
     config.pad_token_id = config.eos_token_id
+
     if args.model_path is not None:
-        model = transformers.LlamaForCausalLM.from_pretrained(args.model_path)
+        model = transformers.LlamaForCausalLM.from_pretrained(args.model_path, torch_dtype=torch_dtype)
         tokenizer = AutoTokenizer.from_pretrained(args.model_path)
     else:
         # Random weights
diff --git a/examples/inference/llama/llama_generation.py b/examples/inference/llama/llama_generation.py
index c0a1a585a1b9..9326f717cc00 100644
--- a/examples/inference/llama/llama_generation.py
+++ b/examples/inference/llama/llama_generation.py
@@ -1,5 +1,6 @@
 import argparse
 
+from torch import bfloat16, float16, float32
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 
 import colossalai
@@ -12,6 +13,12 @@
 MODEL_CLS = AutoModelForCausalLM
 POLICY_CLS = NoPaddingLlamaModelInferPolicy
 
+TORCH_DTYPE_MAP = {
+    "fp16": float16,
+    "fp32": float32,
+    "bf16": bfloat16,
+}
+
 
 def infer(args):
     # ==============================
@@ -24,7 +31,7 @@
     # Load model and tokenizer
     # ==============================
     model_path_or_name = args.model
-    model = MODEL_CLS.from_pretrained(model_path_or_name)
+    model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None))
     tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
     tokenizer.pad_token = tokenizer.eos_token
     # coordinator.print_on_master(f"Model Config:\n{model.config}")
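
Note (not part of the patch): the sketch below illustrates the pattern both files now share — resolve an "fp16"/"fp32"/"bf16" string from the command line into a torch dtype and forward it to from_pretrained, so the checkpoint is loaded in the requested precision instead of the library default. The --model default and the --dtype flag here are assumptions for illustration; the actual argument parsing lives in the example scripts themselves.

# Minimal standalone sketch of the dtype-resolution pattern (assumed CLI flags).
import argparse

import torch
from transformers import AutoModelForCausalLM

TORCH_DTYPE_MAP = {
    "fp16": torch.float16,
    "fp32": torch.float32,
    "bf16": torch.bfloat16,
}

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3-8B")  # assumed model path
parser.add_argument("--dtype", type=str, default="fp16", choices=list(TORCH_DTYPE_MAP))
args = parser.parse_args()

# Unrecognized strings fall back to None (the transformers default behavior),
# mirroring TORCH_DTYPE_MAP.get(args.dtype, None) in the patch.
model = AutoModelForCausalLM.from_pretrained(
    args.model, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None)
)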