[Infer] Fix TPInferEngine init & inference tests, benchmarks (#4670)
* fix engine funcs

* TPInferEngine: receive shard config in init

* benchmarks: revise TPInferEngine init

* benchmarks: remove pytest decorator

* trivial fix

* use small model for tests
yuanheng-zhao authored Sep 8, 2023
1 parent 2a98d75 commit e2e96d4
Showing 6 changed files with 157 additions and 146 deletions.
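
The core API change in this commit is that TPInferEngine now receives a ShardConfig at construction time and optimize_model() no longer takes a config dict. Below is a minimal sketch of the revised flow, mirroring the class docstring and the updated benchmarks; the import paths, the small BLOOM checkpoint, and the size limits are assumptions for illustration, and a CUDA device plus an already-initialized distributed context (as in the benchmarks) are assumed.

# Minimal sketch of the revised TPInferEngine flow (not part of the commit).
# Assumes: import paths inferred from this repo's layout, a CUDA device,
# and an already-initialized distributed context as in the benchmarks.
from transformers import BloomForCausalLM, BloomTokenizerFast

from colossalai.inference.tensor_parallel.engine import TPInferEngine  # path assumed from engine.py location
from colossalai.shardformer import ShardConfig

MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN = 4, 1024, 128  # illustrative limits

tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")  # small checkpoint chosen for illustration
model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m").half()

# New: the caller builds the ShardConfig and passes it to the engine at init.
shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.optimize_model()  # no longer takes a config dict

input_tokens = tokenizer(["Introduce landmarks in Beijing"], return_tensors="pt")
outputs = infer_engine.generate(input_tokens, do_sample=False)
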
184 changes: 106 additions & 78 deletions colossalai/inference/tensor_parallel/engine.py
@@ -19,16 +19,38 @@


class TPInferEngine:
"""Engine class for tensor parallel inference.
Args:
model (Module): original model, e.g. huggingface CausalLM
shard_config (ShardConfig): The config for sharding original model
max_batch_size (int): maximum batch size
max_input_len (int): maximum input length of sequence
max_output_len (int): maximum output length of output tokens
dtype (torch.dtype): datatype used to init KV cache space
device (str): device on which the engine's KV cache is initialized
Examples:
>>> # define model and shard config for your inference
>>> model = ...
>>> generate_kwargs = ...
>>> shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
>>> infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
>>> infer_engine.optimize_model()
>>> outputs = infer_engine.generate(input_ids, **generate_kwargs)
"""

def __init__(self,
model: nn.Module,
shard_config: ShardConfig,
max_batch_size: int,
max_input_len: int,
max_output_len: int,
dtype: torch.dtype = torch.float16,
device: str = 'cuda') -> None:
self.model = model
self.model = self.model.to(device)
self.shard_config = shard_config
self.sharded_model = None

self.max_batch_size = max_batch_size
@@ -57,19 +79,25 @@ def _init_manager(self) -> None:
self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim,
self.layer_num)

def optimize_model(self, config: Optional[Dict[Any, Any]] = None) -> None:
""" Apply shardformer to optimize the model. In future generation, use sharded model instead of original model. """
tp_size = 1 if config is None else config.get('tp_size', 1)
def optimize_model(self) -> None:
"""
Optimize the original model by sharding with ShardFormer.
In subsequent generation, use the sharded model instead of the original model.
"""
# NOTE we will change to use an inference config later with additional attrs we want
# tp_size = getattr(config, 'tp_size', 1)
shard_config = ShardConfig(enable_tensor_parallelism=True if tp_size > 1 else False, inference_only=True)
shardformer = ShardFormer(shard_config=shard_config)
self._prepare_with_shard_config(shard_config=shard_config)
assert self.shard_config.inference_only is True
shardformer = ShardFormer(shard_config=self.shard_config)
self._prepare_with_shard_config(shard_config=self.shard_config)
self._shard_model_by(shardformer)
self.model = None

def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig:
""" Prepare the engine with a given ShardConfig, or create a default one with tp size 1 """
""" Prepare the engine with a given ShardConfig.
Args:
shard_config (ShardConfig): shard config given to specify settings of the engine.
If not provided, a default ShardConfig with tp size 1 will be created.
"""
self.tp_size = 1
if shard_config is None:
shard_config = ShardConfig(
@@ -92,20 +120,30 @@ def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None)
return shard_config

def _shard_model_by(self, shardformer: ShardFormer) -> None:
""" Shard the model and store the sharded model by given ShardFormer """
""" Shard original model by the given ShardFormer and store the sharded model. """
assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \
"Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = self.model.__class__.__name__
assert model_name in self._supported_models(), f"Unsupported model cls {model_name} for TP inference."
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
policy = get_autopolicy(self.model, inference_only=True)
self.sharded_model, _ = shardformer.optimize(self.model, policy)
self.sharded_model = self.sharded_model.cuda()

@staticmethod
def _supported_models() -> List[str]:
@property
def supported_models(self) -> List[str]:
return _supported_models

def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor], **generate_kwargs) -> torch.Tensor:
"""Generate token sequence.
Args:
input_tokens: could be one of the following types
1. BatchEncoding or dict (e.g. tokenizer batch_encode)
2. list of input token ids (e.g. appended result of tokenizer encode)
3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt')
Returns:
torch.Tensor: The returned sequence consists of the given inputs followed by the generated tokens.
"""
if isinstance(input_tokens, torch.Tensor):
input_tokens = dict(input_ids=input_tokens, attention_mask=torch.ones_like(input_tokens, dtype=torch.bool))
for t in input_tokens:
@@ -115,51 +153,14 @@ def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor],
generate_kwargs.update(max_new_tokens=self.max_output_len)

if self.sharded_model is not None:
return self.generate_by_set_infer_state(input_tokens, **generate_kwargs)

return self.model.generate(input_tokens.get('input_ids'), **generate_kwargs)
return self._generate_by_set_infer_state(input_tokens, **generate_kwargs)

@torch.no_grad()
def generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch.Tensor:
"""
Generate output tokens by setting BatchInferState as an attribute to the model and calling model.generate
Args:
inputs: should be one of the following types
1. BatchEncoding or dict (e.g. tokenizer batch_encode)
2. list of input token ids (e.g. appended result of tokenizer encode)
3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt')
"""

# for testing, always use sharded model
assert self.sharded_model is not None, "sharded model does not exist"

batch_infer_state = self.prepare_batch_state(input_tokens)
assert batch_infer_state.max_len_in_batch <= self.max_input_len, "max length in batch exceeds limit"

# set BatchInferState for the current batch as attr to model
# NOTE this is not an expectable way to pass BatchInferState during inference
# we might want to rewrite generate function (e.g. generate_by_pass_infer_state)
# and pass BatchInferState via model forward
model = self.sharded_model
if isinstance(model, LlamaForCausalLM):
model = self.sharded_model.model
elif isinstance(model, BloomForCausalLM):
model = self.sharded_model.transformer
setattr(model, 'infer_state', batch_infer_state)

outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False)

# NOTE In future development, we're going to let the scheduler to handle the cache,
# instead of freeing space explicitly at the end of generation
self.cache_manager.free_all()

return outputs
return self.model.generate(**input_tokens, **generate_kwargs)

def prepare_batch_state(self, inputs) -> BatchInferState:
"""
Create and prepare BatchInferState used for inference during model forward,
by processing each sequence of the given inputs
by processing each sequence of the given inputs.
Args:
inputs: should be one of the following types
@@ -216,7 +217,7 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device='cuda')
batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device
batch_infer_state.seq_len = seq_lengths.to('cuda')
batch_infer_state.start_loc = seq_start_indexes.to('cuda')
batch_infer_state.block_loc = block_loc
batch_infer_state.decode_layer_id = 0
@@ -225,42 +226,69 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
batch_infer_state.set_cache_manager(self.cache_manager)
return batch_infer_state

@torch.no_grad()
def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch.Tensor:
"""
Generate output tokens by setting BatchInferState as an attribute to the model and calling model.generate
Args:
inputs: should be one of the following types
1. BatchEncoding or dict (e.g. tokenizer batch_encode)
2. list of input token ids (e.g. appended result of tokenizer encode)
3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt')
"""

# for testing, always use sharded model
assert self.sharded_model is not None, "sharded model does not exist"

batch_infer_state = self.prepare_batch_state(input_tokens)
assert batch_infer_state.max_len_in_batch <= self.max_input_len, "max length in batch exceeds limit"

# set BatchInferState for the current batch as attr to model
# NOTE this is not a preferable way to pass BatchInferState during inference
# we might want to rewrite generate function (e.g. _generate_by_pass_infer_state)
# and pass BatchInferState via model forward
model = self.sharded_model
if isinstance(model, LlamaForCausalLM):
model = self.sharded_model.model
elif isinstance(model, BloomForCausalLM):
model = self.sharded_model.transformer
setattr(model, 'infer_state', batch_infer_state)

outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False)

# NOTE In future development, we're going to let the scheduler to handle the cache,
# instead of freeing space explicitly at the end of generation
self.cache_manager.free_all()

return outputs

# TODO might want to implement the func that generates output tokens by passing BatchInferState
# as an arg into model.forward
# requires rewriting model generate and replacing model forward
# as an arg into model.forward.
# It requires rewriting model generate and replacing model forward.
@torch.no_grad()
def generate_by_pass_infer_state(self,
input_tokens,
max_out_length: int,
generation_config: Optional[GenerationConfig] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
**model_kwargs) -> torch.Tensor:
# if batch_size >= 4:
# assert self.sharded_model is not None, "sharded model does not exist"
# batch_infer_state = self.prepare_batch_state(input_tokens)
# batch_size = batch_infer_state.batch_size
# assert batch_infer_state.max_len_in_batch <= self.max_input_len
# # record sequences finish status, add early stopping, etc,
# for _ in range(min(max_out_length, self.max_output_len)):
# # ...
# self.sharded_model.forward(..., **model_kwargs)
# else:
# Use original model to generate
def _generate_by_pass_infer_state(self,
input_tokens,
max_out_length: int,
generation_config: Optional[GenerationConfig] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
**model_kwargs) -> torch.Tensor:

raise NotImplementedError("generate by passing BatchInferState is not implemented.")

# NOTE might want to use in rewritten generate method: use after model.forward
# might want to use in rewritten generate method: use after model.forward
# BatchInferState is created and kept during generation
# after each iter of model forward, we should update BatchInferState
def update_batch_state(self, infer_state: Optional[BatchInferState]) -> None:
def _update_batch_state(self, infer_state: Optional[BatchInferState]) -> None:
batch_size = infer_state.batch_size
device = infer_state.start_loc.device
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device)
infer_state.seq_len += 1

# TODO might want to create a sequence pool
# add a single request/sequence/input text at a time and record its length
# In other words, store the actual length of input tokens representing a single input text
# might want to create a sequence pool
# add a single request/sequence/input text at a time and record its length
# In other words, store the actual length of input tokens representing a single input text
# E.g. "Introduce landmarks in Beijing"
# => add request
# => record token length and other necessary information to be used
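
For reference, generate() documents three accepted input formats (BatchEncoding/dict, list of token ids, torch.Tensor). The hedged sketch below exercises the two forms that clearly appear in the hunk shown above, reusing the infer_engine and tokenizer placeholders from the earlier sketch; handling of a plain Python list is documented but not visible in the code shown here.

# Hedged sketch of two of the input formats generate() accepts, reusing the
# `infer_engine` and `tokenizer` placeholders from the earlier sketch.
prompt = ["Introduce landmarks in Beijing"]

# 1. BatchEncoding / dict, e.g. the result of calling the tokenizer directly
batch = tokenizer(prompt, return_tensors="pt")
out_from_batch = infer_engine.generate(batch, do_sample=False)

# 3. torch.Tensor of input ids, e.g. tokenizer(..., return_tensors='pt').input_ids;
#    generate() wraps it into a dict with an all-ones attention mask.
ids = tokenizer(prompt, return_tensors="pt").input_ids
out_from_tensor = infer_engine.generate(ids, do_sample=False)

# A list of token ids (format 2 in the docstring) is also documented as accepted.
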
10 changes: 5 additions & 5 deletions examples/inference/bench_bloom.py
@@ -1,7 +1,6 @@
import os
import time

import pytest
import torch
from transformers import BloomForCausalLM, BloomTokenizerFast

@@ -50,9 +49,11 @@ def bench_bloom(test_config):
model = model.half()

# init TPInferEngine and shard the original model
# To benchmark torch original, comment out lines of optimizing model
infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.optimize_model(test_config)
# To benchmark the original torch model, comment out the line that optimizes the model
shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
inference_only=True)
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.optimize_model()

# prepare data for generation
batch_size = MAX_BATCH_SIZE
@@ -88,7 +89,6 @@ def check_bloom(rank, world_size, port):
bench_bloom()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom():
11 changes: 5 additions & 6 deletions examples/inference/bench_llama.py
@@ -1,8 +1,6 @@
import os
import time

import numpy as np
import pytest
import torch
import torch.distributed as dist
from torch.profiler import ProfilerActivity, profile, record_function
@@ -77,8 +75,10 @@ def run_llama_test(test_config):

model_config = model.config

infer_engine = TPInferEngine(model, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.optimize_model(test_config)
shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
inference_only=True)
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.optimize_model()

batch_size = 2
max_new_tokens = 128
@@ -111,7 +111,7 @@ def run_llama_test(test_config):
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
with record_function("model_inference"):
torch.cuda.synchronize()
outputs = infer_engine.generate(input_tokens, generate_kwargs)
outputs = infer_engine.generate(input_tokens, **generate_kwargs)
torch.cuda.synchronize()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

@@ -122,7 +122,6 @@ def check_llama(rank, world_size, port):
run_llama_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
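
Both benchmarks now build the ShardConfig inline from test_config['tp_size']. A small hypothetical helper capturing that shared pattern (not part of this commit) could look like the sketch below.

# Hypothetical helper capturing the ShardConfig pattern now repeated in
# bench_bloom.py and bench_llama.py (not part of this commit).
from colossalai.shardformer import ShardConfig

def build_inference_shard_config(tp_size: int) -> ShardConfig:
    # Tensor parallelism is only enabled when more than one TP rank is requested.
    return ShardConfig(enable_tensor_parallelism=tp_size > 1, inference_only=True)

# Usage mirroring the benchmarks:
# shard_config = build_inference_shard_config(test_config['tp_size'])
# infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
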
