
[Infer/nfc] Add args for inference benchmarks #4674

Merged
56 changes: 29 additions & 27 deletions examples/inference/bench_bloom.py
@@ -1,21 +1,17 @@
+import argparse
 import os
 import time

 import torch
 from transformers import BloomForCausalLM, BloomTokenizerFast

 import colossalai
-from colossalai.cluster import ProcessGroupMesh
 from colossalai.inference.tensor_parallel.engine import TPInferEngine
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.shardformer import ShardConfig
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-TPSIZE = 1
-MAX_BATCH_SIZE = 32
-MAX_INPUT_LEN = 1024
-MAX_OUTPUT_LEN = 128


 def print_perf_stats(latency_set, config, bs, warmup=3):
@@ -37,31 +33,28 @@ def print_perf_stats(latency_set, config, bs, warmup=3):
     print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs))


-@parameterize('test_config', [{
-    'tp_size': TPSIZE,
-}])
-def bench_bloom(test_config):
+def bench_bloom(args):
+    model_path = args.path
+    max_batch_size = args.batch_size
+    max_input_len = args.input_len
+    max_output_len = args.output_len

-    model_path = "/home/lczyh/data3/models/bloom-7b1"
     tokenizer = BloomTokenizerFast.from_pretrained(model_path)
     tokenizer.pad_token = tokenizer.eos_token
     model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
     model = model.half()

     # init TPInferEngine and shard the original model
     # To benchmark torch original, comment out the line of optimizing model
-    shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
-                               inference_only=True)
-    infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     infer_engine.optimize_model()

     # prepare data for generation
-    batch_size = MAX_BATCH_SIZE
-    input_len = MAX_INPUT_LEN
-    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
+    generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
     input_tokens = {
-        "input_ids": torch.randint(10, 1000, (batch_size, input_len)),
-        "attention_mask": torch.ones((batch_size, input_len))
+        "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)),
+        "attention_mask": torch.ones((max_batch_size, max_input_len))
     }
     for t in input_tokens:
         if torch.is_tensor(input_tokens[t]):
@@ -78,22 +71,31 @@ def bench_bloom(test_config):
         end = time.time()
         out_len = outputs.shape[1]
         print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s")
-        times.append((end - start) / (out_len - input_len))
+        times.append((end - start) / (out_len - max_input_len))

-    print_perf_stats(times, model.config, batch_size)
+    print_perf_stats(times, model.config, max_batch_size)


-def check_bloom(rank, world_size, port):
+def check_bloom(rank, world_size, port, args):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    bench_bloom()
+    bench_bloom(args)


 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
-def test_bloom():
-    spawn(check_bloom, TPSIZE)
+def test_bloom(args):
+    spawn(check_bloom, args.tp_size, args=args)


 if __name__ == "__main__":
-    test_bloom()
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-p', '--path', type=str, help='Model path', required=True)
+    parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size')
+    parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size')
+    parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length')
+    parser.add_argument('--output_len', type=int, default=128, help='Maximum output length')
+
+    args = parser.parse_args()
+
+    test_bloom(args)
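For readers skimming the diff: the numbers printed by print_perf_stats reduce to simple arithmetic over the per-token latencies collected in the generation loop. Below is a minimal, self-contained sketch of that math; the helper name summarize_latencies and the sample latencies are illustrative and not part of the PR, but the formulas mirror the visible print statements, with the throughput line scaled by the batch size that is now passed in explicitly as bs.

    from statistics import mean

    def summarize_latencies(latency_set, bs, warmup=3):
        # drop warmup iterations, as print_perf_stats does
        trimmed = list(latency_set)[warmup:]
        avg = mean(trimmed)            # average seconds per generated token
        per_token_ms = avg * 1000      # "Avg Per Token Latency" in ms
        tokens_per_s = bs / avg        # "Avg Throughput": (1000 / (avg * 1000)) * bs simplifies to bs / avg
        return per_token_ms, tokens_per_s

    # e.g. ten timed iterations at roughly 25 ms per token with batch size 32
    print(summarize_latencies([0.030, 0.027, 0.026] + [0.025] * 7, bs=32))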
65 changes: 31 additions & 34 deletions examples/inference/bench_llama.py
@@ -1,23 +1,18 @@
+import argparse
 import os
 import time

 import torch
-import torch.distributed as dist
 from torch.profiler import ProfilerActivity, profile, record_function
 from transformers import LlamaForCausalLM, LlamaTokenizer

 import colossalai
-from colossalai.cluster import ProcessGroupMesh
 from colossalai.inference.tensor_parallel.engine import TPInferEngine
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.shardformer import ShardConfig
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-TPSIZE = 1
-BATCH_SIZE = 32
-MAX_INPUT_LEN = 1024
-MAX_OUTPUT_LEN = 256


 def init_to_get_rotary(self, base=10000):
@@ -43,7 +38,7 @@ def init_to_get_rotary(self, base=10000):
     return


-def print_perf_stats(latency_set, config, warmup=3):
+def print_perf_stats(latency_set, config, bs, warmup=3):
     # trim warmup queries
     latency_set = list(latency_set)
     latency_set = latency_set[warmup:]
@@ -58,15 +53,15 @@ def print_perf_stats(latency_set, config, warmup=3):

     print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
     print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
-    print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * BATCH_SIZE / 1e12))
+    print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12))


-@parameterize('test_config', [{
-    'tp_size': TPSIZE,
-}])
-def run_llama_test(test_config):
+def run_llama_test(args):
+    llama_model_path = args.path
+    max_batch_size = args.batch_size
+    max_input_len = args.input_len
+    max_output_len = args.output_len

-    llama_model_path = "/data/scratch/llama-7b-hf"
     tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
     tokenizer.pad_token_id = tokenizer.unk_token_id
     model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
@@ -75,19 +70,14 @@ def run_llama_test(test_config):

     model_config = model.config

-    shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
-                               inference_only=True)
-    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     infer_engine.optimize_model()

-    batch_size = 2
-    max_new_tokens = 128
-    input_len = 1024
-
-    generate_kwargs = dict(max_new_tokens=max_new_tokens, do_sample=False)
+    generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
     input_tokens = {
-        "input_ids": torch.randint(1, 1000, (batch_size, input_len), device='cuda'),
-        "attention_mask": torch.ones((batch_size, input_len), device='cuda')
+        "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'),
+        "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda')
     }

     iters = 10
@@ -101,12 +91,10 @@ def run_llama_test(test_config):
         end = time.time()
         out_len = outputs.shape[1]
         print("generation time {} s".format(str(end - start)))
-        times.append((end - start) / (out_len - input_len))
+        times.append((end - start) / (out_len - max_input_len))

-    print("outputs, ", len(outputs))
-    outputs = tokenizer.batch_decode(outputs)

-    print_perf_stats(times, model_config)
+    print_perf_stats(times, model_config, max_batch_size)

     with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
         with record_function("model_inference"):
@@ -116,17 +104,26 @@ def run_llama_test(test_config):
     print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


-def check_llama(rank, world_size, port):
+def check_llama(rank, world_size, port, args):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_llama_test()
+    run_llama_test(args)


 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
-def test_llama():
-    spawn(check_llama, TPSIZE)
+def test_llama(args):
+    spawn(check_llama, args.tp_size, args=args)


 if __name__ == "__main__":
-    test_llama()
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-p', '--path', type=str, help='Model path', required=True)
+    parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size')
+    parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size')
+    parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length')
+    parser.add_argument('--output_len', type=int, default=128, help='Maximum output length')
+
+    args = parser.parse_args()
+
+    test_llama(args)
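The tail of run_llama_test wraps one generation call in torch.profiler, as shown in the last hunk above. Below is a minimal, CPU-only sketch of that pattern; the matmul stands in for infer_engine.generate and the tensor size is an assumption for illustration. The benchmark itself additionally enables ProfilerActivity.CUDA and sorts the table by cuda_time_total.

    import torch
    from torch.profiler import ProfilerActivity, profile, record_function

    x = torch.randn(1024, 1024)

    # profile a single labelled region and print the aggregated operator table
    with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
        with record_function("model_inference"):
            y = x @ x

    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))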