[Inference]Update inference config and fix test #5178

Merged 9 commits on Dec 12, 2023
@@ -1,9 +1,14 @@
import logging
from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn as nn

GibiByte = 1024**3

logger = logging.Logger(__name__)


@dataclass
class InferenceConfig:
@@ -18,7 +23,6 @@ class InferenceConfig:
max_output_len: Maximum output length.
max_input_len: Maximum input length.
block_size: The number of blocks in a logical block.
gpu_utilization_rate: Maximum GPU memory usage ratio.
dtype: The data type for weights and activations.
tp_size: Tensor parallel size.
pp_size: Pipeline parallel size.
@@ -27,13 +31,15 @@ class InferenceConfig:
revision: The specific version(a branch, name, a commit id, or a tag name) of model to use.
beam_width: The maximum beam width used to initialize KV Cache.
During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
prefill_ratio: A controling ratio for prefill and decoding in running list, we will do a step of prefill
when the actual value exceeds this ratio.
"""

model: Union[str, nn.Module]
tokenizer: str = None
tokenizer_mode: str = "auto"
trust_remote_code: bool = False
max_batch_size: int = 8
max_batch_size: int = None
max_output_len: int = 256
max_input_len: int = 256
block_size: int = 16
@@ -43,10 +49,34 @@ class InferenceConfig:
max_seq_len: Optional[int] = None
quant_mode: Optional[str] = None
revision: Optional[str] = None
# TODO: beam search is not support for now
beam_width: int = 1
# TODO: beam search is not support for now
prefill_ratio: Optional[float] = 1.2
# the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio

def _init_batch_size(self):
"""
MAX_BATCH_SIZE is set to acurately utilize the memory of gpu.
We take a simple method to determine it by GPU memory size, user can still set it manually.
"""
if self.max_batch_size is not None:
# already set by user
return

device = torch.device("cuda")
total_mem = torch.cuda.get_device_properties(device).total_memory // GibiByte
self.max_batch_size = 8

if 40 < total_mem <= 60:
self.max_batch_size = 16
elif 60 < total_mem <= 80:
self.max_batch_size = 32
logger.info(
f"The maximum batch size is automatically set to {self.max_batch_size} as no value is provided by the user."
)

def __post_init__(self):
self._init_batch_size()
self._verify_args()

def _verify_args(self):
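The `_init_batch_size` hunk above derives a default `max_batch_size` from total GPU memory when the user leaves it unset. Below is a minimal, standalone sketch of that tiering; the helper name `default_max_batch_size` and the CPU fallback are ours for illustration, not part of this PR:

```python
import torch

GibiByte = 1024**3  # same unit constant as in the diff


def default_max_batch_size() -> int:
    """Mirror the memory-based tiering from _init_batch_size in the diff above."""
    if not torch.cuda.is_available():
        # Fallback for machines without a GPU; the PR itself assumes CUDA is present.
        return 8

    device = torch.device("cuda")
    total_mem = torch.cuda.get_device_properties(device).total_memory // GibiByte

    if 40 < total_mem <= 60:
        return 16
    if 60 < total_mem <= 80:
        return 32
    return 8  # default tier: <= 40 GiB (and, per the diff's logic, anything above 80 GiB)


print(default_max_batch_size())
```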
2 changes: 1 addition & 1 deletion colossalai/inference/core/engine.py
@@ -3,7 +3,7 @@

from transformers import AutoConfig

from .config import InferenceConfig
from colossalai.inference.config import InferenceConfig


class InferenceEngine:
2 changes: 1 addition & 1 deletion colossalai/inference/kv_cache/kvcache_manager.py
@@ -3,7 +3,7 @@
import torch
from transformers.configuration_utils import PretrainedConfig

from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.config import InferenceConfig
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device

3 changes: 1 addition & 2 deletions colossalai/inference/readme.md
@@ -4,8 +4,7 @@ Colossal-Infer is a library for inference of LLMs and MLMs. It is built on top o

## Structures
### Overview
https://n4fyd3ptax.feishu.cn/docx/MhlmdHsGkoeoslx9fqucPO17n9b?openbrd=1&doc_app_id=501&blockId=WCGBdWI9hobOEsxkW5uc8HM6n3b&blockType=whiteboard&blockToken=Cca3wKWk7hPnJxbkCX6cMxPQnqd#WCGBdWI9hobOEsxkW5uc8HM6n3b

The main design will be released later on.
## Roadmap
- [] design of structures
- [] Core components
3 changes: 0 additions & 3 deletions colossalai/inference/sequence.py

This file was deleted.

@@ -2,6 +2,10 @@
from dataclasses import dataclass
from typing import Dict, List, Set

"""
The abstraction of request and sequence are defined here.
"""


class RequsetStatus(enum.Enum):
"""The status of Sentences"""
@@ -95,16 +99,16 @@ def __repr__(self) -> str:


@dataclass
class BatchHandler:
class BatchInfo:
"""
Information to be passed and used for a batch of sequences.
"""

sequences_set: Set[Sequence]
block_table: Dict[int, int]
block_table: Dict[int, int] = None

@classmethod
def init_batch(cls, seqs: List[Sequence]) -> "BatchHandler":
def init_batch(cls, seqs: List[Sequence]) -> "BatchInfo":
"""
Initializes inference batches by input sentence list.

@@ -115,13 +119,13 @@ def init_batch(cls, seqs: List[Sequence]) -> "BatchHandler":
block_table = {}
for seq in seqs:
if seq in sequences_set:
print("The sequence is already in sequences_set.")
assert (
seq.request_id in block_table
seq.request_id in block_table.keys()
), "The sequence has been added to sequences_set, but it has not been added to block_table."
continue

assert (
seq.request_id not in block_table
seq.request_id not in block_table.keys()
), "The sequence has not been added to sequences_set, but it is already in block_table."

sequences_set.add(seq)
@@ -143,9 +147,9 @@ def fliter_batch(self) -> None:
"""
Remove completed sentences from a batch.
"""
for seq in self.sequences_set:
for seq in self.sequences_set.copy():
if seq.check_finish():
self.sequences_set.reomve(seq)
self.sequences_set.remove(seq)
del self.block_table[seq.request_id]

def add_seqs(self, seqs: List[Sequence]) -> None:
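One behavioral fix in this file is worth calling out: `fliter_batch` now iterates over `self.sequences_set.copy()`, so finished sequences can be removed while iterating (mutating a set during iteration raises `RuntimeError`). A small self-contained sketch of the same pattern, using plain request-id strings as stand-ins rather than the library's `Sequence` objects:

```python
# Stand-in data: request ids as plain strings instead of Sequence objects.
sequences_set = {"req-1", "req-2", "req-3"}
block_table = {"req-1": 0, "req-2": 1, "req-3": 2}
finished = {"req-1", "req-3"}  # pretend check_finish() returned True for these

# Iterate over a snapshot so removing from the live set is safe.
for req_id in sequences_set.copy():
    if req_id in finished:
        sequences_set.remove(req_id)
        del block_table[req_id]

print(sequences_set)  # {'req-2'}
print(block_table)    # {'req-2': 1}
```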
18 changes: 12 additions & 6 deletions tests/test_infer/test_config_and_struct.py
@@ -1,9 +1,10 @@
from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.core.inference_struct import BatchHandler, Sequence
from colossalai.inference.config import InferenceConfig
from colossalai.inference.struct import BatchInfo, RequsetStatus, Sequence


def test_config_and_struct():
InferenceConfig("/llama")
def test_config_and_inferenceData():
config = InferenceConfig("/llama")
assert config.max_batch_size
sequence = Sequence(
request_id=1,
prompt="abc",
@@ -27,11 +28,16 @@ def test_config_and_struct():
assert sequence.get_output_len() == 0
assert sequence.check_finish() == False

batch = BatchHandler.init_batch([sequence])
batch = BatchInfo.init_batch([sequence])
assert batch.block_table[sequence.request_id] == sequence.block_table_index
sequence.status = RequsetStatus.COMPLETED
batch.fliter_batch()
assert batch.block_table == {}
batch.add_seqs([sequence2])
assert batch.block_table[sequence2.request_id] == sequence2.block_table_index
batch.clear_batch()
assert batch.block_table == {}


if __name__ == "__main__":
test_config_and_struct()
test_config_and_inferenceData()
2 changes: 1 addition & 1 deletion tests/test_infer/test_kvcache_manager.py
@@ -3,7 +3,7 @@
import torch
from transformers.models.llama import LlamaConfig

from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize