[Infer] Fix TPInferEngine init & inference tests, benchmarks #4670

Merged
184 changes: 106 additions & 78 deletions colossalai/inference/tensor_parallel/engine.py
@@ -19,16 +19,38 @@


class TPInferEngine:
"""Engine class for tensor parallel inference.

Args:
model (Module): the original model, e.g. a huggingface CausalLM
shard_config (ShardConfig): the config used to shard the original model
max_batch_size (int): maximum batch size
max_input_len (int): maximum length of the input sequence
max_output_len (int): maximum number of output tokens to generate
dtype (torch.dtype): datatype used to initialize the KV cache space
device (str): device on which the engine's KV cache is initialized

Examples:
>>> # define model and shard config for your inference
>>> model = ...
>>> generate_kwargs = ...
>>> shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
>>> infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
>>> infer_engine.optimize_model()
>>> outputs = infer_engine.generate(input_ids, **generate_kwargs)
"""

def __init__(self,
model: nn.Module,
shard_config: ShardConfig,
max_batch_size: int,
max_input_len: int,
max_output_len: int,
dtype: torch.dtype = torch.float16,
device: str = 'cuda') -> None:
self.model = model
self.model = self.model.to(device)
self.shard_config = shard_config
self.sharded_model = None

self.max_batch_size = max_batch_size
@@ -57,19 +79,25 @@ def _init_manager(self) -> None:
self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim,
self.layer_num)

def optimize_model(self, config: Optional[Dict[Any, Any]] = None) -> None:
""" Apply shardformer to optimize the model. In future generation, use sharded model instead of original model. """
tp_size = 1 if config is None else config.get('tp_size', 1)
def optimize_model(self) -> None:
"""
Optimize the original model by sharding with ShardFormer.
In subsequent generation calls, the sharded model is used instead of the original model.
"""
# NOTE we will change to use an inference config later with additional attrs we want
# tp_size = getattr(config, 'tp_size', 1)
shard_config = ShardConfig(enable_tensor_parallelism=True if tp_size > 1 else False, inference_only=True)
shardformer = ShardFormer(shard_config=shard_config)
self._prepare_with_shard_config(shard_config=shard_config)
assert self.shard_config.inference_only is True
shardformer = ShardFormer(shard_config=self.shard_config)
self._prepare_with_shard_config(shard_config=self.shard_config)
self._shard_model_by(shardformer)
self.model = None

def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig:
""" Prepare the engine with a given ShardConfig, or create a default one with tp size 1 """
""" Prepare the engine with a given ShardConfig.

Args:
shard_config (ShardConfig): shard config given to specify settings of the engine.
If not provided, a default ShardConfig with tp size 1 will be created.
"""
self.tp_size = 1
if shard_config is None:
shard_config = ShardConfig(
@@ -92,20 +120,30 @@ def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None)
return shard_config

def _shard_model_by(self, shardformer: ShardFormer) -> None:
""" Shard the model and store the sharded model by given ShardFormer """
""" Shard original model by the given ShardFormer and store the sharded model. """
assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \
"Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = self.model.__class__.__name__
assert model_name in self._supported_models(), f"Unsupported model cls {model_name} for TP inference."
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
policy = get_autopolicy(self.model, inference_only=True)
self.sharded_model, _ = shardformer.optimize(self.model, policy)
self.sharded_model = self.sharded_model.cuda()

@staticmethod
def _supported_models() -> List[str]:
@property
def supported_models(self) -> List[str]:
return _supported_models

def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor], **generate_kwargs) -> torch.Tensor:
"""Generate token sequence.

Args:
input_tokens: could be one of the following types
1. BatchEncoding or dict (e.g. tokenizer batch_encode)
2. list of input token ids (e.g. appended result of tokenizer encode)
3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt')
Returns:
torch.Tensor: the output sequence, consisting of the given inputs followed by the generated tokens.
"""
if isinstance(input_tokens, torch.Tensor):
input_tokens = dict(input_ids=input_tokens, attention_mask=torch.ones_like(input_tokens, dtype=torch.bool))
for t in input_tokens:
@@ -115,51 +153,14 @@ def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor],
generate_kwargs.update(max_new_tokens=self.max_output_len)

if self.sharded_model is not None:
return self.generate_by_set_infer_state(input_tokens, **generate_kwargs)

return self.model.generate(input_tokens.get('input_ids'), **generate_kwargs)
return self._generate_by_set_infer_state(input_tokens, **generate_kwargs)

@torch.no_grad()
def generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch.Tensor:
"""
Generate output tokens by setting BatchInferState as an attribute to the model and calling model.generate

Args:
inputs: should be one of the following types
1. BatchEncoding or dict (e.g. tokenizer batch_encode)
2. list of input token ids (e.g. appended result of tokenizer encode)
3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt')
"""

# for testing, always use sharded model
assert self.sharded_model is not None, "sharded model does not exist"

batch_infer_state = self.prepare_batch_state(input_tokens)
assert batch_infer_state.max_len_in_batch <= self.max_input_len, "max length in batch exceeds limit"

# set BatchInferState for the current batch as attr to model
# NOTE this is not an expectable way to pass BatchInferState during inference
# we might want to rewrite generate function (e.g. generate_by_pass_infer_state)
# and pass BatchInferState via model forward
model = self.sharded_model
if isinstance(model, LlamaForCausalLM):
model = self.sharded_model.model
elif isinstance(model, BloomForCausalLM):
model = self.sharded_model.transformer
setattr(model, 'infer_state', batch_infer_state)

outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False)

# NOTE In future development, we're going to let the scheduler to handle the cache,
# instead of freeing space explicitly at the end of generation
self.cache_manager.free_all()

return outputs
return self.model.generate(**input_tokens, **generate_kwargs)

def prepare_batch_state(self, inputs) -> BatchInferState:
"""
Create and prepare BatchInferState used for inference during model forward,
by processing each sequence of the given inputs
by processing each sequence of the given inputs.

Args:
inputs: should be one of the following types
@@ -216,7 +217,7 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device='cuda')
batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device
batch_infer_state.seq_len = seq_lengths.to('cuda')
batch_infer_state.start_loc = seq_start_indexes.to('cuda')
batch_infer_state.block_loc = block_loc
batch_infer_state.decode_layer_id = 0
@@ -225,42 +226,69 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
batch_infer_state.set_cache_manager(self.cache_manager)
return batch_infer_state

@torch.no_grad()
def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch.Tensor:
"""
Generate output tokens by setting BatchInferState as an attribute on the model and calling model.generate

Args:
input_tokens: could be one of the following types
1. BatchEncoding or dict (e.g. tokenizer batch_encode)
2. list of input token ids (e.g. appended result of tokenizer encode)
3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt')
"""

# for testing, always use sharded model
assert self.sharded_model is not None, "sharded model does not exist"

batch_infer_state = self.prepare_batch_state(input_tokens)
assert batch_infer_state.max_len_in_batch <= self.max_input_len, "max length in batch exceeds limit"

# set BatchInferState for the current batch as attr to model
# NOTE this is not the preferred way to pass BatchInferState during inference
# we might want to rewrite generate function (e.g. _generate_by_pass_infer_state)
# and pass BatchInferState via model forward
model = self.sharded_model
if isinstance(model, LlamaForCausalLM):
model = self.sharded_model.model
elif isinstance(model, BloomForCausalLM):
model = self.sharded_model.transformer
setattr(model, 'infer_state', batch_infer_state)

outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False)

# NOTE In future development, we're going to let the scheduler to handle the cache,
# instead of freeing space explicitly at the end of generation
self.cache_manager.free_all()

return outputs

# TODO might want to implement the func that generates output tokens by passing BatchInferState
# as an arg into model.forward
# requires rewriting model generate and replacing model forward
# as an arg into model.forward.
# It requires rewriting model generate and replacing model forward.
@torch.no_grad()
def generate_by_pass_infer_state(self,
input_tokens,
max_out_length: int,
generation_config: Optional[GenerationConfig] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
**model_kwargs) -> torch.Tensor:
# if batch_size >= 4:
# assert self.sharded_model is not None, "sharded model does not exist"
# batch_infer_state = self.prepare_batch_state(input_tokens)
# batch_size = batch_infer_state.batch_size
# assert batch_infer_state.max_len_in_batch <= self.max_input_len
# # record sequences finish status, add early stopping, etc,
# for _ in range(min(max_out_length, self.max_output_len)):
# # ...
# self.sharded_model.forward(..., **model_kwargs)
# else:
# Use original model to generate
def _generate_by_pass_infer_state(self,
input_tokens,
max_out_length: int,
generation_config: Optional[GenerationConfig] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
**model_kwargs) -> torch.Tensor:

raise NotImplementedError("generate by passing BatchInferState is not implemented.")

# NOTE might want to use in rewritten generate method: use after model.forward
# might want to use in rewritten generate method: use after model.forward
# BatchInferState is created and kept during generation
# after each iter of model forward, we should update BatchInferState
def update_batch_state(self, infer_state: Optional[BatchInferState]) -> None:
def _update_batch_state(self, infer_state: Optional[BatchInferState]) -> None:
batch_size = infer_state.batch_size
device = infer_state.start_loc.device
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device)
infer_state.seq_len += 1

# TODO might want to create a sequence pool
# add a single request/sequence/input text at a time and record its length
# In other words, store the actual length of input tokens representing a single input text
# might want to create a sequence pool
# add a single request/sequence/input text at a time and record its length
# In other words, store the actual length of input tokens representing a single input text
# E.g. "Introduce landmarks in Beijing"
# => add request
# => record token length and other necessary information to be used
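With these changes, the caller constructs a ShardConfig up front and passes it to the TPInferEngine constructor; optimize_model() no longer takes a config dict. Below is a minimal end-to-end sketch of the updated API; the model path and size constants are placeholders, and a distributed launch via colossalai is assumed but omitted here:

    import torch
    from transformers import LlamaForCausalLM, LlamaTokenizer

    from colossalai.inference.tensor_parallel.engine import TPInferEngine
    from colossalai.shardformer import ShardConfig

    MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN = 4, 1024, 128

    tokenizer = LlamaTokenizer.from_pretrained("/path/to/llama")  # placeholder path
    model = LlamaForCausalLM.from_pretrained("/path/to/llama").half()

    # inference_only must be True; enable_tensor_parallelism turns on TP sharding
    shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
    infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
    infer_engine.optimize_model()  # shards the model; generate() now uses the sharded model

    input_tokens = tokenizer(["Introduce landmarks in Beijing"], return_tensors="pt")
    outputs = infer_engine.generate(input_tokens, do_sample=False)
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))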
10 changes: 5 additions & 5 deletions examples/inference/bench_bloom.py
@@ -1,7 +1,6 @@
import os
import time

import pytest
import torch
from transformers import BloomForCausalLM, BloomTokenizerFast

@@ -50,9 +49,11 @@ def bench_bloom(test_config):
model = model.half()

# init TPInferEngine and shard the original model
# To benchmark torch original, comment out lines of optimizing model
infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.optimize_model(test_config)
# To benchmark the original torch model, comment out the line that optimizes the model
shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
inference_only=True)
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.optimize_model()

# prepare data for generation
batch_size = MAX_BATCH_SIZE
@@ -88,7 +89,6 @@ def check_bloom(rank, world_size, port):
bench_bloom()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom():
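The "prepare data for generation" step of bench_bloom is elided from this diff view. A hypothetical sketch of what such a step can look like (the helper name and the token-id bounds are illustrative; 250880 is BLOOM's vocabulary size):

    import torch

    def prepare_random_batch(batch_size: int, input_len: int, vocab_size: int = 250880) -> dict:
        # Random token ids plus an all-ones attention mask, placed on the GPU,
        # so every sequence in the batch has the full input length.
        input_ids = torch.randint(low=10, high=vocab_size, size=(batch_size, input_len), device="cuda")
        attention_mask = torch.ones_like(input_ids)
        return dict(input_ids=input_ids, attention_mask=attention_mask)

    input_tokens = prepare_random_batch(MAX_BATCH_SIZE, MAX_INPUT_LEN)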
11 changes: 5 additions & 6 deletions examples/inference/bench_llama.py
@@ -1,8 +1,6 @@
import os
import time

import numpy as np
import pytest
import torch
import torch.distributed as dist
from torch.profiler import ProfilerActivity, profile, record_function
@@ -77,8 +75,10 @@ def run_llama_test(test_config):

model_config = model.config

infer_engine = TPInferEngine(model, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.optimize_model(test_config)
shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
inference_only=True)
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.optimize_model()

batch_size = 2
max_new_tokens = 128
@@ -111,7 +111,7 @@ def run_llama_test(test_config):
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
with record_function("model_inference"):
torch.cuda.synchronize()
outputs = infer_engine.generate(input_tokens, generate_kwargs)
outputs = infer_engine.generate(input_tokens, **generate_kwargs)
torch.cuda.synchronize()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

@@ -122,7 +122,6 @@ def check_llama(rank, world_size, port):
run_llama_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
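Outside the profiler block, throughput can be measured with the same synchronize-then-time pattern the profiled call uses. A sketch, assuming infer_engine, input_tokens, generate_kwargs, and batch_size are defined as in the benchmark above:

    import time

    import torch

    torch.cuda.synchronize()
    start = time.time()
    outputs = infer_engine.generate(input_tokens, **generate_kwargs)
    torch.cuda.synchronize()
    elapsed = time.time() - start

    # Count only newly generated tokens, not the prompt tokens echoed in the output.
    new_tokens = outputs.shape[1] - input_tokens["input_ids"].shape[1]
    total = batch_size * new_tokens
    print(f"{total} tokens in {elapsed:.3f}s -> {total / elapsed:.1f} tokens/s")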