diff --git a/scripts/inference/bloom_accelerate_inference.py b/scripts/inference/bloom_accelerate_inference.py
index 6b58c01d5..8fc5a6120 100644
--- a/scripts/inference/bloom_accelerate_inference.py
+++ b/scripts/inference/bloom_accelerate_inference.py
@@ -1,4 +1,4 @@
-import gc
+from argparse import Namespace
 from typing import List, Union
 
 import torch
@@ -6,28 +6,25 @@
 
 import utils
 from utils import (
-    Execute,
-    benchmark_generation,
+    Model,
+    benchmark_end_to_end,
     get_argument_parser,
-    get_benchmark_results,
-    get_dummy_batch,
-    print_rank_n,
-    run_and_log_time,
+    print_rank_n
 )
 
 
-class HFAccelerateModel:
-    def __init__(self, model_name: str, dtype: torch.dtype) -> None:
+class HFAccelerateModel(Model):
+    def __init__(self, args: Namespace) -> None:
         print_rank_n("Loading model...")
 
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name)
 
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
+            args.model_name,
             device_map="auto",
             max_memory=get_max_memory_per_gpu_dict(
-                dtype, model_name),
-            torch_dtype=dtype
+                args.dtype, args.model_name),
+            torch_dtype=args.dtype
         )
 
         self.model.eval()
@@ -126,84 +123,5 @@ def get_max_memory_per_gpu_dict(dtype, model_name):
     return {i: param_memory_per_gpu_in_bytes for i in range(torch.cuda.device_count())}
 
 
-def main():
-    args = get_args()
-
-    model, initialization_time = run_and_log_time(
-        Execute(
-            HFAccelerateModel,
-            {
-                "model_name": args.model_name,
-                "dtype": args.dtype,
-            }
-        )
-    )
-
-    if (args.generate_kwargs):
-        generate_kwargs = args.generate_kwargs
-    else:
-        generate_kwargs = {
-            "max_new_tokens": 100,
-            "do_sample": False
-        }
-
-    print_rank_n(
-        f"*** Starting to generate {generate_kwargs['max_new_tokens']} tokens with bs={args.batch_size}")
-
-    input_sentences = get_dummy_batch(args.batch_size)
-
-    print_rank_n(f"Generate args {generate_kwargs}")
-
-    # warmup is a must if measuring speed as it's when all the optimizations are performed
-    # e.g. on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs
-    model.generate(
-        input_sentences,
-        generate_kwargs
-    )
-
-    (output_text, num_generated_tokens), generation_time = run_and_log_time(
-        Execute(
-            model.generate,
-            {
-                "text": input_sentences,
-                "generate_kwargs": generate_kwargs
-            }
-        )
-    )
-    for i, (o, _) in zip(input_sentences, zip(output_text, num_generated_tokens)):
-        print_rank_n(f"{'-' * 60}\nin = {i}\nout = {o}\n")
-
-    if (args.benchmark_cycles > 0):
-        print_rank_n(f"*** Running benchmark")
-
-        torch.cuda.empty_cache()
-        gc.collect()
-
-        # warm up
-        model.generate(input_sentences, generate_kwargs)
-        torch.cuda.synchronize()
-
-        # benchmark
-        total_new_tokens_generated, benchmark_time = run_and_log_time(
-            Execute(
-                benchmark_generation,
-                {
-                    "input_sentences": input_sentences,
-                    "model": model,
-                    "generate_kwargs": generate_kwargs
-                }
-            )
-        )
-        print_rank_n(
-            get_benchmark_results(
-                benchmark_time,
-                initialization_time,
-                generation_time,
-                total_new_tokens_generated,
-                args.batch_size
-            )
-        )
-
-
 if (__name__ == "__main__"):
-    main()
+    benchmark_end_to_end(get_args(), HFAccelerateModel)
diff --git a/scripts/inference/utils.py b/scripts/inference/utils.py
index f6b7e0345..777f1a3ba 100644
--- a/scripts/inference/utils.py
+++ b/scripts/inference/utils.py
@@ -1,12 +1,12 @@
 import argparse
 import copy
+import gc
 import math
 import time
-from typing import Any, List, Tuple, Union
+from typing import Any, List, Union
 
 import torch
 import torch.distributed as dist
-from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 dummy_input_sentences = [
@@ -36,6 +36,17 @@ def __call__(self) -> Any:
         return self.func(**self.kwargs)
 
 
+class Model:
+    def __init__(self, args: argparse.Namespace) -> None:
+        raise NotImplementedError("This is a dummy class")
+
+    def generate(self,
+                 text: Union[str, List[str]],
+                 generate_kwargs: dict,
+                 remove_input_from_output: bool = False) -> Union[str, List[str]]:
+        raise NotImplementedError("This is a dummy class")
+
+
 def get_argument_parser():
     parser = argparse.ArgumentParser()
 
@@ -147,3 +158,74 @@ def get_benchmark_results(benchmark_time: float,
 Generation time per batch = {generation_time:.2f} secs
 Model loading time + generation time per batch = {initialization_time + generation_time:.2f} secs
 """
+
+
+def benchmark_end_to_end(args: argparse.Namespace, model_class: Model) -> None:
+    model, initialization_time = run_and_log_time(
+        Execute(model_class, {"args": args})
+    )
+
+    if (args.generate_kwargs):
+        generate_kwargs = args.generate_kwargs
+    else:
+        generate_kwargs = {
+            "max_new_tokens": 100,
+            "do_sample": False
+        }
+
+    print_rank_n(
+        f"*** Starting to generate {generate_kwargs['max_new_tokens']} tokens with bs={args.batch_size}")
+
+    input_sentences = get_dummy_batch(args.batch_size)
+
+    print_rank_n(f"Generate args {generate_kwargs}")
+
+    # warmup is a must if measuring speed as it's when all the optimizations are performed
+    # e.g. on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs
+    model.generate(
+        input_sentences,
+        generate_kwargs
+    )
+
+    (output_text, num_generated_tokens), generation_time = run_and_log_time(
+        Execute(
+            model.generate,
+            {
+                "text": input_sentences,
+                "generate_kwargs": generate_kwargs
+            }
+        )
+    )
+    for i, (o, _) in zip(input_sentences, zip(output_text, num_generated_tokens)):
+        print_rank_n(f"{'-' * 60}\nin = {i}\nout = {o}\n")
+
+    if (args.benchmark_cycles > 0):
+        print_rank_n(f"*** Running benchmark")
+
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        # warm up
+        model.generate(input_sentences, generate_kwargs)
+        torch.cuda.synchronize()
+
+        # benchmark
+        total_new_tokens_generated, benchmark_time = run_and_log_time(
+            Execute(
+                benchmark_generation,
+                {
+                    "input_sentences": input_sentences,
+                    "model": model,
+                    "generate_kwargs": generate_kwargs
+                }
+            )
+        )
+        print_rank_n(
+            get_benchmark_results(
+                benchmark_time,
+                initialization_time,
+                generation_time,
+                total_new_tokens_generated,
+                args.batch_size
+            )
+        )
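For context, here is a minimal sketch (not part of the diff above) of how another backend could plug into the refactored interface: subclass the Model dummy base class from utils.py and hand the class to benchmark_end_to_end. The HFSimpleModel name, the single-device loading, the generate() body, and the get_args() helper below are illustrative assumptions, not code from this change; the real scripts keep a per-script get_args(), presumably built on utils.get_argument_parser().

# hypothetical_simple_inference.py -- illustrative sketch only, not part of this diff
from argparse import Namespace
from typing import List, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import Model, benchmark_end_to_end, get_argument_parser, print_rank_n


class HFSimpleModel(Model):
    # hypothetical single-device backend mirroring the HFAccelerateModel pattern above
    def __init__(self, args: Namespace) -> None:
        print_rank_n("Loading model...")
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=args.dtype)
        self.model.eval()

    def generate(self,
                 text: Union[str, List[str]],
                 generate_kwargs: dict,
                 remove_input_from_output: bool = False) -> Union[str, List[str]]:
        # benchmark_end_to_end unpacks (output_text, num_generated_tokens) from this call,
        # so return both; assumes the tokenizer defines a pad token (true for BLOOM)
        if isinstance(text, str):
            text = [text]
        input_tokens = self.tokenizer(text, return_tensors="pt", padding=True)
        with torch.no_grad():
            output_ids = self.model.generate(**input_tokens, **generate_kwargs)
        input_length = input_tokens["input_ids"].shape[1]
        num_new_tokens = [len(ids) - input_length for ids in output_ids]
        if remove_input_from_output:
            output_ids = output_ids[:, input_length:]
        output_text = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return output_text, num_new_tokens


def get_args() -> Namespace:
    # hypothetical stand-in for the per-script get_args(); assumes --dtype arrives as a
    # string such as "float16"/"bfloat16" and converts it to a torch.dtype
    args = get_argument_parser().parse_args()
    if isinstance(args.dtype, str):
        args.dtype = getattr(torch, args.dtype)
    return args


if __name__ == "__main__":
    benchmark_end_to_end(get_args(), HFSimpleModel)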