diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py
index 6a3f961f7054..c02ccb6e54ce 100644
--- a/colossalai/inference/tensor_parallel/engine.py
+++ b/colossalai/inference/tensor_parallel/engine.py
@@ -19,9 +19,30 @@ class TPInferEngine:
+    """Engine class for tensor parallel inference.
+
+    Args:
+        model (Module): original model, e.g. a huggingface CausalLM
+        shard_config (ShardConfig): the config for sharding the original model
+        max_batch_size (int): maximum batch size
+        max_input_len (int): maximum length of the input sequence
+        max_output_len (int): maximum number of output tokens
+        dtype (torch.dtype): datatype used to init KV cache space
+        device (str): device on which the engine's KV cache is initialized
+
+    Examples:
+        >>> # define model and shard config for your inference
+        >>> model = ...
+        >>> generate_kwargs = ...
+        >>> shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
+        >>> infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+        >>> infer_engine.optimize_model()
+        >>> outputs = infer_engine.generate(input_ids, **generate_kwargs)
+    """

     def __init__(self,
                  model: nn.Module,
+                 shard_config: ShardConfig,
                  max_batch_size: int,
                  max_input_len: int,
                  max_output_len: int,
@@ -29,6 +50,7 @@ def __init__(self,
                  device: str = 'cuda') -> None:
         self.model = model
         self.model = self.model.to(device)
+        self.shard_config = shard_config
         self.sharded_model = None

         self.max_batch_size = max_batch_size
@@ -57,19 +79,25 @@ def _init_manager(self) -> None:
         self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim,
                                            self.layer_num)

-    def optimize_model(self, config: Optional[Dict[Any, Any]] = None) -> None:
-        """ Apply shardformer to optimize the model. In future generation, use sharded model instead of original model. """
-        tp_size = 1 if config is None else config.get('tp_size', 1)
+    def optimize_model(self) -> None:
+        """
+        Optimize the original model by sharding it with ShardFormer.
+        In subsequent generation, the sharded model is used instead of the original model.
+        """
         # NOTE we will change to use an inference config later with additional attrs we want
-        # tp_size = getattr(config, 'tp_size', 1)
-        shard_config = ShardConfig(enable_tensor_parallelism=True if tp_size > 1 else False, inference_only=True)
-        shardformer = ShardFormer(shard_config=shard_config)
-        self._prepare_with_shard_config(shard_config=shard_config)
+        assert self.shard_config.inference_only is True
+        shardformer = ShardFormer(shard_config=self.shard_config)
+        self._prepare_with_shard_config(shard_config=self.shard_config)
         self._shard_model_by(shardformer)
         self.model = None

     def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig:
-        """ Prepare the engine with a given ShardConfig, or create a default one with tp size 1 """
+        """ Prepare the engine with a given ShardConfig.
+
+        Args:
+            shard_config (ShardConfig): the shard config that specifies settings of the engine.
+                If not provided, a default ShardConfig with tp size 1 will be created.
+        """
         self.tp_size = 1
         if shard_config is None:
             shard_config = ShardConfig(
@@ -92,20 +120,30 @@ def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None)
         return shard_config

     def _shard_model_by(self, shardformer: ShardFormer) -> None:
-        """ Shard the model and store the sharded model by given ShardFormer """
+        """ Shard the original model with the given ShardFormer and store the sharded model. """
""" assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \ "Discrepancy between the tp size of TPInferEngine and the tp size of shard config" model_name = self.model.__class__.__name__ - assert model_name in self._supported_models(), f"Unsupported model cls {model_name} for TP inference." + assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." policy = get_autopolicy(self.model, inference_only=True) self.sharded_model, _ = shardformer.optimize(self.model, policy) self.sharded_model = self.sharded_model.cuda() - @staticmethod - def _supported_models() -> List[str]: + @property + def supported_models(self) -> List[str]: return _supported_models def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor], **generate_kwargs) -> torch.Tensor: + """Generate token sequence. + + Args: + input_tokens: could be one of the following types + 1. BatchEncoding or dict (e.g. tokenizer batch_encode) + 2. list of input token ids (e.g. appended result of tokenizer encode) + 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt') + Returns: + torch.Tensor: The returned sequence is given inputs + generated_tokens. + """ if isinstance(input_tokens, torch.Tensor): input_tokens = dict(input_ids=input_tokens, attention_mask=torch.ones_like(input_tokens, dtype=torch.bool)) for t in input_tokens: @@ -115,51 +153,14 @@ def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor], generate_kwargs.update(max_new_tokens=self.max_output_len) if self.sharded_model is not None: - return self.generate_by_set_infer_state(input_tokens, **generate_kwargs) - - return self.model.generate(input_tokens.get('input_ids'), **generate_kwargs) + return self._generate_by_set_infer_state(input_tokens, **generate_kwargs) - @torch.no_grad() - def generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch.Tensor: - """ - Generate output tokens by setting BatchInferState as an attribute to the model and calling model.generate - - Args: - inputs: should be one of the following types - 1. BatchEncoding or dict (e.g. tokenizer batch_encode) - 2. list of input token ids (e.g. appended result of tokenizer encode) - 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt') - """ - - # for testing, always use sharded model - assert self.sharded_model is not None, "sharded model does not exist" - - batch_infer_state = self.prepare_batch_state(input_tokens) - assert batch_infer_state.max_len_in_batch <= self.max_input_len, "max length in batch exceeds limit" - - # set BatchInferState for the current batch as attr to model - # NOTE this is not an expectable way to pass BatchInferState during inference - # we might want to rewrite generate function (e.g. 
-        # and pass BatchInferState via model forward
-        model = self.sharded_model
-        if isinstance(model, LlamaForCausalLM):
-            model = self.sharded_model.model
-        elif isinstance(model, BloomForCausalLM):
-            model = self.sharded_model.transformer
-        setattr(model, 'infer_state', batch_infer_state)
-
-        outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False)
-
-        # NOTE In future development, we're going to let the scheduler to handle the cache,
-        # instead of freeing space explicitly at the end of generation
-        self.cache_manager.free_all()
-
-        return outputs
+        return self.model.generate(**input_tokens, **generate_kwargs)

     def prepare_batch_state(self, inputs) -> BatchInferState:
         """ Create and prepare BatchInferState used for inference during model forwrad,
-            by processing each sequence of the given inputs
+            by processing each sequence of the given inputs.

         Args:
             inputs: should be one of the following types
@@ -216,7 +217,7 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
             max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
         block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device='cuda')
         batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
-        batch_infer_state.seq_len = seq_lengths.to('cuda')    # might want to assign specific device
+        batch_infer_state.seq_len = seq_lengths.to('cuda')
         batch_infer_state.start_loc = seq_start_indexes.to('cuda')
         batch_infer_state.block_loc = block_loc
         batch_infer_state.decode_layer_id = 0
@@ -225,42 +226,69 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
         batch_infer_state.set_cache_manager(self.cache_manager)
         return batch_infer_state

+    @torch.no_grad()
+    def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch.Tensor:
+        """
+        Generate output tokens by setting BatchInferState as an attribute to the model and calling model.generate
+
+        Args:
+            input_tokens: should be one of the following types
+                1. BatchEncoding or dict (e.g. tokenizer batch_encode)
+                2. list of input token ids (e.g. appended result of tokenizer encode)
+                3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt')
+        """
+
+        # for testing, always use sharded model
+        assert self.sharded_model is not None, "sharded model does not exist"
+
+        batch_infer_state = self.prepare_batch_state(input_tokens)
+        assert batch_infer_state.max_len_in_batch <= self.max_input_len, "max length in batch exceeds limit"
+
+        # set BatchInferState for the current batch as attr to model
+        # NOTE this is not a preferable way to pass BatchInferState during inference
+        # we might want to rewrite generate function (e.g. _generate_by_pass_infer_state)
+        # and pass BatchInferState via model forward
+        model = self.sharded_model
+        if isinstance(model, LlamaForCausalLM):
+            model = self.sharded_model.model
+        elif isinstance(model, BloomForCausalLM):
+            model = self.sharded_model.transformer
+        setattr(model, 'infer_state', batch_infer_state)
+
+        outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False)
+
+        # NOTE In future development, we're going to let the scheduler handle the cache,
+        # instead of freeing space explicitly at the end of generation
+        self.cache_manager.free_all()
+
+        return outputs
+
     # TODO might want to implement the func that generates output tokens by passing BatchInferState
-    # as an arg into model.forward
-    # requires rewriting model generate and replacing model forward
+    # as an arg into model.forward.
+    # It requires rewriting model generate and replacing model forward.
     @torch.no_grad()
-    def generate_by_pass_infer_state(self,
-                                     input_tokens,
-                                     max_out_length: int,
-                                     generation_config: Optional[GenerationConfig] = None,
-                                     stopping_criteria: Optional[StoppingCriteriaList] = None,
-                                     prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
-                                     **model_kwargs) -> torch.Tensor:
-        # if batch_size >= 4:
-        #     assert self.sharded_model is not None, "sharded model does not exist"
-        #     batch_infer_state = self.prepare_batch_state(input_tokens)
-        #     batch_size = batch_infer_state.batch_size
-        #     assert batch_infer_state.max_len_in_batch <= self.max_input_len
-        #     # record sequences finish status, add early stopping, etc,
-        #     for _ in range(min(max_out_length, self.max_output_len)):
-        #         # ...
-        #         self.sharded_model.forward(..., **model_kwargs)
-        # else:
-        #     Use original model to generate
+    def _generate_by_pass_infer_state(self,
+                                      input_tokens,
+                                      max_out_length: int,
+                                      generation_config: Optional[GenerationConfig] = None,
+                                      stopping_criteria: Optional[StoppingCriteriaList] = None,
+                                      prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+                                      **model_kwargs) -> torch.Tensor:
+        raise NotImplementedError("generate by passing BatchInferState is not implemented.")

-    # NOTE might want to use in rewritten generate method: use after model.forward
+    # might want to use in rewritten generate method: use after model.forward
     # BatchInferState is created and kept during generation
     # after each iter of model forward, we should update BatchInferState
-    def update_batch_state(self, infer_state: Optional[BatchInferState]) -> None:
+    def _update_batch_state(self, infer_state: Optional[BatchInferState]) -> None:
         batch_size = infer_state.batch_size
         device = infer_state.start_loc.device
         infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device)
         infer_state.seq_len += 1

-    # TODO might want to create a sequence pool
-    # add a single request/sequence/input text at a time and record its length
-    # In other words, store the actual length of input tokens representing a single input text
+    # might want to create a sequence pool
+    # add a single request/sequence/input text at a time and record its length
+    # In other words, store the actual length of input tokens representing a single input text
     # E.g. "Introduce landmarks in Beijing"
"Introduce landmarks in Beijing" # => add request # => record token length and other necessary information to be used diff --git a/examples/inference/bench_bloom.py b/examples/inference/bench_bloom.py index c07202ef882b..949e3030603a 100644 --- a/examples/inference/bench_bloom.py +++ b/examples/inference/bench_bloom.py @@ -1,7 +1,6 @@ import os import time -import pytest import torch from transformers import BloomForCausalLM, BloomTokenizerFast @@ -50,9 +49,11 @@ def bench_bloom(test_config): model = model.half() # init TPInferEngine and shard the original model - # To benchmark torch original, comment out lines of optimizing model - infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - infer_engine.optimize_model(test_config) + # To benchmark torch original, comment out the line of optimizing model + shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, + inference_only=True) + infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + infer_engine.optimize_model() # prepare data for generation batch_size = MAX_BATCH_SIZE @@ -88,7 +89,6 @@ def check_bloom(rank, world_size, port): bench_bloom() -@pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_bloom(): diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index c1ece952b099..6ed4ff8d24af 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -1,8 +1,6 @@ import os import time -import numpy as np -import pytest import torch import torch.distributed as dist from torch.profiler import ProfilerActivity, profile, record_function @@ -77,8 +75,10 @@ def run_llama_test(test_config): model_config = model.config - infer_engine = TPInferEngine(model, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - infer_engine.optimize_model(test_config) + shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, + inference_only=True) + infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + infer_engine.optimize_model() batch_size = 2 max_new_tokens = 128 @@ -111,7 +111,7 @@ def run_llama_test(test_config): with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: with record_function("model_inference"): torch.cuda.synchronize() - outputs = infer_engine.generate(input_tokens, generate_kwargs) + outputs = infer_engine.generate(input_tokens, **generate_kwargs) torch.cuda.synchronize() print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) @@ -122,7 +122,6 @@ def check_llama(rank, world_size, port): run_llama_test() -@pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_llama(): diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py index eb55d7d40778..f26f05abeb79 100644 --- a/tests/test_infer/test_bloom_infer.py +++ b/tests/test_infer/test_bloom_infer.py @@ -9,7 +9,9 @@ import colossalai from colossalai.inference.tensor_parallel import TPInferEngine from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo TP_SIZE = 2 MAX_BATCH_SIZE = 4 @@ -23,37 +25,25 @@ 'tp_size': TP_SIZE, }]) def run(test_config): - model_path = "/data3/models/bloom-7b1" - if 
-        return
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
-    tokenizer.pad_token = tokenizer.eos_token
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom_for_causal_lm')
+    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
+        orig_model = model_fn()
+        orig_model = orig_model.half()
+        data = data_gen_fn()

-    text1 = "Introduce some landmarks in Beijing"
-    text2 = "how is weather today?"
-    input_ids = tokenizer.batch_encode_plus([text1, text2], return_tensors='pt', padding=True)
+        shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
+                                   inference_only=True)
+        infer_engine = TPInferEngine(orig_model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+        infer_engine.optimize_model()

-    model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
-    model = model.half()
+        generate_kwargs = dict(do_sample=False)
+        outputs = infer_engine.generate(data, **generate_kwargs)

-    infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
-    infer_engine.optimize_model(test_config)
+        assert outputs is not None

-    generate_kwargs = dict(do_sample=False)
-    outputs = infer_engine.generate(input_ids, **generate_kwargs)
-    assert outputs is not None
-
-    if not dist.is_initialized() or dist.get_rank() == 0:
-        # output_text = tokenizer.decode(outputs[0])
-        # print(output_text)
-        for o in outputs:
-            output_text = tokenizer.decode(o)
-            # print(output_text)
-
-
-def check_engine(rank, world_size, port):
+def check_bloom(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run()
@@ -63,9 +53,9 @@
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
-def test_engine_infer():
-    spawn(check_engine, TP_SIZE)
+def test_bloom_infer():
+    spawn(check_bloom, TP_SIZE)


 if __name__ == '__main__':
-    test_engine_infer()
+    test_bloom_infer()
diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py
index b4feb10c4573..b1b3b57068c1 100644
--- a/tests/test_infer/test_infer_engine.py
+++ b/tests/test_infer/test_infer_engine.py
@@ -44,7 +44,8 @@ def __init__(self):
     dummy_config = DummyModelConfig()
     model = DummyModule(dummy_config)

-    infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    shard_config = ShardConfig(enable_tensor_parallelism=False)
+    infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

     input_ids_list = [[80540, 15473, 3331, 11970, 90472, 361, 61335], [80540, 15473, 3331, 11970],
                       [80540, 15473, 3331, 11970], [80540, 15473]]
@@ -86,7 +87,7 @@ def test_orig_generate():
     shard_config = ShardConfig(enable_tensor_parallelism=False)

     # init TPInferEngine
-    infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

     # original model generate
     generate_kwargs = dict(do_sample=False)
@@ -104,8 +105,10 @@ def run(test_config):
     model = model.half()
     model.to(torch.cuda.current_device())

-    infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
-    infer_engine.optimize_model(test_config)
+    shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
+                               inference_only=True)
+    infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    infer_engine.optimize_model()

     assert infer_engine.cache_manager is not None
     assert infer_engine.tp_size == TP_SIZE
diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py
index 3b9317cbceb6..7dfb63e16e8e 100644
--- a/tests/test_infer/test_llama_infer.py
+++ b/tests/test_infer/test_llama_infer.py
@@ -1,6 +1,6 @@
 import os
+import warnings

-import numpy as np
 import pytest
 import torch
 import torch.distributed as dist
@@ -8,11 +8,11 @@
 from transformers import LlamaForCausalLM, LlamaTokenizer

 import colossalai
-from colossalai.cluster import ProcessGroupMesh
 from colossalai.inference.tensor_parallel.engine import TPInferEngine
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer import ShardConfig
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo

 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
 TPSIZE = 2
@@ -51,31 +51,22 @@ def init_to_get_rotary(self, base=10000):
 }])
 def run_llama_test(test_config):
-    llama_model_path = "/data/scratch/llama-7b-hf"
-    if os.path.isdir(llama_model_path) is False:
-        return
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama_for_casual_lm')
+    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
+        orig_model = model_fn()
+        init_to_get_rotary(orig_model.model, base=10000)
+        orig_model = orig_model.half()
+        data = data_gen_fn()

-    tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
-    tokenizer.pad_token_id = tokenizer.unk_token_id
-    model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
-    init_to_get_rotary(model.model, base=10000)
-    model = model.half()
+        shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
+                                   inference_only=True)
+        infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+        infer_engine.optimize_model()

-    text = ["how is weather today?", "i am "]
-    input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True, device='cuda')
+        generate_kwargs = dict(do_sample=False)
+        outputs = infer_engine.generate(data, **generate_kwargs)

-    infer_engine = TPInferEngine(model, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
-    infer_engine.optimize_model(test_config)
-
-    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
-    outputs = infer_engine.generate(input_ids, **generate_kwargs)
-
-    assert outputs is not None
-
-    if not dist.is_initialized() or dist.get_rank() == 0:
-        for o in outputs:
-            output_text = tokenizer.decode(o)
-            # print(output_text)
+        assert outputs is not None


 def check_llama(rank, world_size, port):
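
For reviewers who want to try the new API end to end, the sketch below mirrors the class docstring example and the updated tests; it is a minimal illustration only, not part of the patch. The checkpoint path, tokenizer handling, and size constants are placeholders, and a tensor-parallel process group is assumed to have been set up beforehand (e.g. via colossalai.launch across tp_size GPUs).

# Minimal usage sketch of the updated TPInferEngine API (assumptions: distributed
# environment already initialized, "path/to/llama" replaced with a real checkpoint).
from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig

MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN = 4, 32, 128    # placeholder limits

tokenizer = LlamaTokenizer.from_pretrained("path/to/llama")
tokenizer.pad_token_id = tokenizer.unk_token_id
model = LlamaForCausalLM.from_pretrained("path/to/llama").half()

# ShardConfig is now passed to the engine at construction time;
# optimize_model() no longer takes a config dict.
shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.optimize_model()

inputs = tokenizer(["how is weather today?"], return_tensors='pt', padding=True)
outputs = infer_engine.generate(inputs, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))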