Skip to content

Commit

Permalink
[Inference]Update inference config and fix test (#5178)
Browse files Browse the repository at this point in the history
* unify the config setting

* fix test

* fix import

* fix test

* fix

* fix

* add logger

* revise log info

---------

Co-authored-by: CjhHa1 <cjh18671720497outlook.com>
  • Loading branch information
CjhHa1 authored and FrankLeeeee committed Jan 11, 2024
1 parent 3de2e62 commit 93aeacc
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import logging
from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn as nn

# Number of bytes in one gibibyte; used to convert raw device memory sizes to GiB.
GibiByte = 1024**3

# Obtain the logger through the logging registry instead of instantiating
# `logging.Logger` directly: a directly constructed logger is detached from the
# logger hierarchy, so handlers/levels configured by the application never apply
# to it and its INFO records are silently dropped.
logger = logging.getLogger(__name__)


@dataclass
class InferenceConfig:
Expand All @@ -18,7 +23,6 @@ class InferenceConfig:
max_output_len: Maximum output length.
max_input_len: Maximum input length.
block_size: The number of blocks in a logical block.
gpu_utilization_rate: Maximum GPU memory usage ratio.
dtype: The data type for weights and activations.
tp_size: Tensor parallel size.
pp_size: Pipeline parallel size.
Expand All @@ -27,13 +31,15 @@ class InferenceConfig:
revision: The specific version(a branch, name, a commit id, or a tag name) of model to use.
beam_width: The maximum beam width used to initialize KV Cache.
During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
prefill_ratio: A controlling ratio for prefill and decoding in running list, we will do a step of prefill
when the actual value exceeds this ratio.
"""

model: Union[str, nn.Module]
tokenizer: str = None
tokenizer_mode: str = "auto"
trust_remote_code: bool = False
max_batch_size: int = 8
max_batch_size: int = None
max_output_len: int = 256
max_input_len: int = 256
block_size: int = 16
Expand All @@ -43,10 +49,34 @@ class InferenceConfig:
max_seq_len: Optional[int] = None
quant_mode: Optional[str] = None
revision: Optional[str] = None
# TODO: beam search is not supported for now
beam_width: int = 1
# TODO: beam search is not supported for now
prefill_ratio: Optional[float] = 1.2
# the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio

def _init_batch_size(self):
    """
    Choose a default for `max_batch_size` when the user did not provide one.

    MAX_BATCH_SIZE is set to accurately utilize the memory of the GPU. A simple
    heuristic based on the total memory (in whole GiB) of the current CUDA device
    is used; the user can still set the value manually, in which case this
    method leaves it untouched.

    Mapping (total memory in GiB -> max_batch_size):
        <= 40 (or no CUDA device available) -> 8
        (40, 60]                            -> 16
        > 60                                -> 32
    """
    if self.max_batch_size is not None:
        # already set by user
        return

    # Conservative default; also used when no CUDA device can be queried so
    # that building a config does not crash on CPU-only hosts.
    self.max_batch_size = 8

    if torch.cuda.is_available():
        device = torch.device("cuda")
        # Floor division: e.g. a nominal "80 GB" card typically reports 79 GiB here.
        total_mem = torch.cuda.get_device_properties(device).total_memory // GibiByte

        # No upper bound on the largest tier: devices with more than 80 GiB
        # previously fell through both branches and kept the smallest batch size.
        if total_mem > 60:
            self.max_batch_size = 32
        elif total_mem > 40:
            self.max_batch_size = 16

    logger.info(
        f"The maximum batch size is automatically set to {self.max_batch_size} as no value is provided by the user."
    )

def __post_init__(self):
    """Finish dataclass construction: resolve the batch size, then validate the arguments."""
    # Order matters: the batch size is resolved before the arguments are verified.
    self._init_batch_size()
    self._verify_args()

def _verify_args(self):
Expand Down
Empty file.
2 changes: 1 addition & 1 deletion colossalai/inference/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from transformers import AutoConfig

from .config import InferenceConfig
from colossalai.inference.config import InferenceConfig


class InferenceEngine:
Expand Down
2 changes: 1 addition & 1 deletion colossalai/inference/kv_cache/kvcache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from transformers.configuration_utils import PretrainedConfig

from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.config import InferenceConfig
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device

Expand Down
3 changes: 1 addition & 2 deletions colossalai/inference/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ Colossal-Infer is a library for inference of LLMs and MLMs. It is built on top o

## Structures
### Overview
https://n4fyd3ptax.feishu.cn/docx/MhlmdHsGkoeoslx9fqucPO17n9b?openbrd=1&doc_app_id=501&blockId=WCGBdWI9hobOEsxkW5uc8HM6n3b&blockType=whiteboard&blockToken=Cca3wKWk7hPnJxbkCX6cMxPQnqd#WCGBdWI9hobOEsxkW5uc8HM6n3b

The main design will be released later on.
## Roadmap
- [ ] Design of structures
- [ ] Core components
Expand Down
3 changes: 0 additions & 3 deletions colossalai/inference/sequence.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from dataclasses import dataclass
from typing import Dict, List, Set

"""
The abstraction of request and sequence are defined here.
"""


class RequsetStatus(enum.Enum):
"""The status of Sentences"""
Expand Down Expand Up @@ -95,16 +99,16 @@ def __repr__(self) -> str:


@dataclass
class BatchHandler:
class BatchInfo:
"""
Information to be passed and used for a batch of sequences.
"""

sequences_set: Set[Sequence]
block_table: Dict[int, int]
block_table: Dict[int, int] = None

@classmethod
def init_batch(cls, seqs: List[Sequence]) -> "BatchHandler":
def init_batch(cls, seqs: List[Sequence]) -> "BatchInfo":
"""
Initializes inference batches by input sentence list.
Expand All @@ -115,13 +119,13 @@ def init_batch(cls, seqs: List[Sequence]) -> "BatchHandler":
block_table = {}
for seq in seqs:
if seq in sequences_set:
print("The sequence is already in sequences_set.")
assert (
seq.request_id in block_table
seq.request_id in block_table.keys()
), "The sequence has been added to sequences_set, but it has not been added to block_table."
continue

assert (
seq.request_id not in block_table
seq.request_id not in block_table.keys()
), "The sequence has not been added to sequences_set, but it is already in block_table."

sequences_set.add(seq)
Expand All @@ -143,9 +147,9 @@ def fliter_batch(self) -> None:
"""
Remove completed sentences from a batch.
"""
for seq in self.sequences_set:
for seq in self.sequences_set.copy():
if seq.check_finish():
self.sequences_set.reomve(seq)
self.sequences_set.remove(seq)
del self.block_table[seq.request_id]

def add_seqs(self, seqs: List[Sequence]) -> None:
Expand Down
18 changes: 12 additions & 6 deletions tests/test_infer/test_config_and_struct.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.core.inference_struct import BatchHandler, Sequence
from colossalai.inference.config import InferenceConfig
from colossalai.inference.struct import BatchInfo, RequsetStatus, Sequence


def test_config_and_struct():
InferenceConfig("/llama")
def test_config_and_inferenceData():
config = InferenceConfig("/llama")
assert config.max_batch_size
sequence = Sequence(
request_id=1,
prompt="abc",
Expand All @@ -27,11 +28,16 @@ def test_config_and_struct():
assert sequence.get_output_len() == 0
assert sequence.check_finish() == False

batch = BatchHandler.init_batch([sequence])
batch = BatchInfo.init_batch([sequence])
assert batch.block_table[sequence.request_id] == sequence.block_table_index
sequence.status = RequsetStatus.COMPLETED
batch.fliter_batch()
assert batch.block_table == {}
batch.add_seqs([sequence2])
assert batch.block_table[sequence2.request_id] == sequence2.block_table_index
batch.clear_batch()
assert batch.block_table == {}


if __name__ == "__main__":
test_config_and_struct()
test_config_and_inferenceData()
2 changes: 1 addition & 1 deletion tests/test_infer/test_kvcache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from transformers.models.llama import LlamaConfig

from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize
Expand Down

0 comments on commit 93aeacc

Please sign in to comment.