Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Inference] User Experience: update the logic of default tokenizer and generation config. #5337

Merged
19 changes: 9 additions & 10 deletions colossalai/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ colossalai.launch_from_torch(config={})
# Step 1: create a model in "transformers" way
model_path = "lmsys/vicuna-7b-v1.3"
model = transformers.LlamaForCausalLM.from_pretrained(model_path).cuda()
tokenizer = transformers.LlamaTokenizer.from_pretrained(model_path)
# tokenizer = transformers.LlamaTokenizer.from_pretrained(model_path)
# You can pre-define a tokenizer or use our default one
FrankLeeeee marked this conversation as resolved.
Show resolved Hide resolved

# Step 2: create an inference_config
inference_config = InferenceConfig(
Expand All @@ -100,14 +101,9 @@ inference_config = InferenceConfig(
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

# Step 4: try inference
generation_config = transformers.GenerationConfig(
pad_token_id=tokenizer.pad_token_id,
max_new_tokens=512,
)
prompts = ['Who is the best player in the history of NBA?']
engine.add_request(prompts=prompts)
response = engine.generate(generation_config)
pprint(response)
response = engine.generate(prompts)
print(response)
FrankLeeeee marked this conversation as resolved.
Show resolved Hide resolved
```

### :bookmark: Customize your inference engine
Expand Down Expand Up @@ -150,13 +146,16 @@ Notations:
- [x] Paged Attention
- [x] High-Performance Kernels
- [x] Llama Modelling
- [x] User Documentation
- [ ] Speculative Decoding
- [ ] Tensor Parallelism
- [ ] Beam Search
- [ ] Speculative Decoding
- [ ] Early stopping
- [ ] Logger system
- [ ] SplitFuse
- [ ] Continuous Batching
- [ ] Online Inference
- [ ] Benchmarking
- [ ] User Documentation

## 🌟 Acknowledgement

Expand Down
33 changes: 33 additions & 0 deletions colossalai/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import torch
import torch.distributed as dist
from transformers.generation import GenerationConfig

GibiByte = 1024**3

Expand Down Expand Up @@ -46,6 +47,8 @@ class InferenceConfig:
revision (Optional[str]): The specific version(a branch, name, a commit id, or a tag name) of model to use.
"""

model: str = "Llama"
tokenizer: str = None
micro_batch_size: int = 1
micro_batch_buffer_size: int = None
max_batch_size: int = 8
Expand All @@ -56,15 +59,21 @@ class InferenceConfig:
tp_size: int = 1
pp_size: int = 1
# TODO: beam search is not support for now
do_sample: bool = False
beam_width: int = 1
# the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
prefill_ratio: Optional[float] = 1.2
pad_input: bool = False
quant_mode: Optional[str] = None
revision: Optional[str] = None
early_stopping: Optional[bool] = False
trust_remote_code = False
tokenizer_revision = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove these two configs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

tokenizer_mode = "auto"

def __post_init__(self):
self._verify_config()
self._get_tokenizer()

def _verify_config(self) -> None:
"""
Expand All @@ -85,3 +94,27 @@ def _verify_config(self) -> None:
assert (
self.tp_size * self.pp_size == dist.get_world_size()
), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"

def _get_tokenizer(self) -> None:
if self.tokenizer is not None:
return
model_name = self.model.lower()
if "llama" in model_name:
self.tokenizer = "hf-internal-testing/llama-tokenizer"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need to deal with tokenizer in InferenceConfig, always obey the single-resposibility principle. As I can see, the engine will receive a tokenizer, let the engine alone handle that, inference config should not handle tokenizer anymore to reduce coupling.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, the model argument seems unnecessary to me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed


def _to_generation_config(self, model_config) -> GenerationConfig:
FrankLeeeee marked this conversation as resolved.
Show resolved Hide resolved
meta_config = {
"max_length": self.max_input_len + self.max_output_len,
"max_new_tokens": self.max_output_len,
"early_stopping": self.early_stopping,
"do_sample": self.do_sample,
"num_beams": self.beam_width,
}
for type in ["top_k", "top_p", "min_p"]:
if hasattr(self, type):
meta_config[type] = getattr(self, type)
FrankLeeeee marked this conversation as resolved.
Show resolved Hide resolved
for type in ["pad_token_id", "bos_token_id", "eos_token_id"]:
if hasattr(self, type):
meta_config[type] = getattr(self, type)

return GenerationConfig.from_dict(meta_config)
31 changes: 22 additions & 9 deletions colossalai/inference/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from colossalai.inference.config import InferenceConfig
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import Sequence
from colossalai.inference.tokenizer import get_tokenizer
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
Expand All @@ -33,7 +34,7 @@ class InferenceEngine:

Args:
model (nn.Module): Path or nn.Module of this model.
tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Path of the tokenizer to use.
tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
verbose (bool): Determine whether or not to log the generation process.
model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided.
Expand All @@ -42,19 +43,26 @@ class InferenceEngine:
def __init__(
self,
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
inference_config: Optional["InferenceConfig"] = None,
inference_config: InferenceConfig,
tokenizer: Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]] = None,
verbose: bool = False,
model_policy: Policy = None,
) -> None:
assert inference_config, "Please provide inference_config."
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.inference_config = inference_config
self.model_config = model.config
self.device = torch.device("cuda")
self.dtype = inference_config.dtype

if tokenizer is None:
tokenizer = get_tokenizer(
inference_config.tokenizer,
trust_remote_code=inference_config.trust_remote_code,
tokenizer_revision=inference_config.tokenizer_revision,
revision=inference_config.revision,
)
FrankLeeeee marked this conversation as resolved.
Show resolved Hide resolved
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.generation_config = inference_config._to_generation_config(self.model_config)
model = model.eval()
model.to(self.dtype)

Expand All @@ -80,6 +88,8 @@ def __init__(

self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.k_cahce, self.v_cache = self.request_handler.get_kvcache()
# DISCUSS maybe move this into batch info?

self.counter = count()

def _verify_config(self) -> None:
Expand Down Expand Up @@ -136,7 +146,7 @@ def generate(
self,
prompts: List[str] = None,
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
generation_config: GenerationConfig = None,
generation_config: Optional[GenerationConfig] = None,
) -> List[str]:
"""
Executing the inference step.
Expand All @@ -150,7 +160,10 @@ def generate(
List[str]: Inference result returned by one generation.
"""

self.generation_config = generation_config
# intuition: If user provide a generation config, we should replace the existing one.
if generation_config is not None:
self.generation_config = generation_config

if prompts is not None or prompts_token_ids is not None:
self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)

Expand Down Expand Up @@ -260,8 +273,8 @@ def step(self) -> List[str]:

if self.inference_config.pad_input:
logits = logits[:, -1, :]

self.request_handler.search_tokens(self.generation_config, logits)

finished_sequences = self.request_handler.update()

return finished_sequences
11 changes: 8 additions & 3 deletions colossalai/inference/core/request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationConfig

from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
Expand Down Expand Up @@ -94,6 +95,10 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo
head_dim = model_config.hidden_size // model_config.num_attention_heads

fd_inter_tensor = FDIntermTensors()

if fd_inter_tensor._tensors_initialized:
fd_inter_tensor._re_initialize()
FrankLeeeee marked this conversation as resolved.
Show resolved Hide resolved

fd_inter_tensor.initialize(
max_batch_size=self.max_batch_size,
num_attn_heads=model_config.num_attention_heads,
Expand Down Expand Up @@ -229,7 +234,7 @@ def _find_sequence(self, request_id: str) -> Sequence:

return None

def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config):
def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig):
if generation_config.num_beams == 1:
if generation_config.do_sample:
sample_tokens = multinomial_sample(generation_config, probs)
Expand All @@ -240,7 +245,7 @@ def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config

return sample_tokens

def mark_finished(self, sequence: Sequence, generation_config):
def mark_finished(self, sequence: Sequence, generation_config: GenerationConfig):
if (
sequence.output_token_id[-1] == generation_config.eos_id
or sequence.output_len >= generation_config.max_output_len
Expand All @@ -250,7 +255,7 @@ def mark_finished(self, sequence: Sequence, generation_config):
def check_unfinished_seqs(self) -> bool:
return self._has_waiting() or not self.running_list.is_empty()

def search_tokens(self, generation_config, logits):
def search_tokens(self, generation_config: GenerationConfig, logits):
"""
Sample tokens for finished requests.
"""
Expand Down
5 changes: 5 additions & 0 deletions colossalai/inference/flash_decoding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ class FDIntermTensors(metaclass=SingletonMeta):
def __init__(self):
self._tensors_initialized = False

def _re_initialize(self):
self._tensors_initialized = False
del self._mid_output
del self._mid_output_lse
FrankLeeeee marked this conversation as resolved.
Show resolved Hide resolved

@property
def is_initialized(self):
return self._tensors_initialized
Expand Down
1 change: 0 additions & 1 deletion colossalai/inference/modeling/models/nopadding_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def llama_model_forward(
"""
input_ids = batch.get_1D_inputs()
block_tables = batch.get_block_table_tensor()

sequence_lengths = batch.get_sequence_lengths()
batch_size = len(sequence_lengths)
kv_seq_len = sequence_lengths.max().item()
Expand Down
53 changes: 53 additions & 0 deletions colossalai/inference/tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""
For default tokenizer creation
Adapted from vllm/transformers_utils/tokenizer.py
"""
from typing import Optional, Union
FrankLeeeee marked this conversation as resolved.
Show resolved Hide resolved

from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast


def get_tokenizer(
tokenizer_name: str,
*args,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tokenizer_revision: Optional[str] = None,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""
Returns a default tokenizer if not provided by user.
"""
if tokenizer_mode == "slow":
if kwargs.get("use_fast", False):
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False

try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name, *args, trust_remote_code=trust_remote_code, tokenizer_revision=tokenizer_revision, **kwargs
)
except ValueError as e:
# If the error pertains to the tokenizer class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
if not trust_remote_code and (
"does not exist or is not currently imported." in str(e)
or "requires you to execute the tokenizer file" in str(e)
):
err_msg = (
"Failed to load the tokenizer. If the tokenizer is a custom "
"tokenizer not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI."
)
raise RuntimeError(err_msg) from e
else:
raise e

if not isinstance(tokenizer, PreTrainedTokenizerFast):
print(
"Using a slow tokenizer. This might cause a significant "
"slowdown. Consider using a fast tokenizer instead."
)
# TODO: set a loggger and divide levels.
return tokenizer
2 changes: 1 addition & 1 deletion examples/inference/benchmark_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def benchmark_inference(args):
max_output_len=args.output_len,
prefill_ratio=1.2,
)
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
engine = InferenceEngine(model, inference_config, tokenizer, verbose=True)
FrankLeeeee marked this conversation as resolved.
Show resolved Hide resolved
else:
engine = model

Expand Down
15 changes: 13 additions & 2 deletions tests/test_infer/test_inference_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import random
from copy import deepcopy

import numpy as np
import pytest
Expand All @@ -18,6 +19,15 @@ def setup_seed(seed):
random.seed(seed)


def check_config_tokenizer(n_model, output_len):
drat_config = InferenceConfig(n_model.__class__.__name__, max_output_len=output_len, dtype=torch.float32)
draf_engine = InferenceEngine(n_model, inference_config=drat_config)

assert "transformers.models.llama" in str(draf_engine.tokenizer.__class__)
assert draf_engine.generation_config.max_new_tokens == output_len
del draf_engine


def check_inference_engine(test_cai=False):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
Expand All @@ -30,7 +40,6 @@ def check_inference_engine(test_cai=False):
.cuda()
.half()
)

model = model.eval()

inputs = [
Expand All @@ -44,8 +53,10 @@ def check_inference_engine(test_cai=False):
top_k = 50

if test_cai:
n_model = deepcopy(model)
check_config_tokenizer(n_model, output_len)
inference_config = InferenceConfig(max_output_len=output_len)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
inference_engine = InferenceEngine(model, inference_config, tokenizer, verbose=True)
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
Expand Down
Loading