[Inference] Update inference config and fix test #5178

Merged · 9 commits · Dec 12, 2023
@@ -4,6 +4,8 @@
import torch
import torch.nn as nn

GibiByte = 1024**3


@dataclass
class InferenceConfig:
@@ -43,8 +45,24 @@ class InferenceConfig:
max_seq_len: Optional[int] = None
quant_mode: Optional[str] = None
revision: Optional[str] = None
# TODO: beam search is not supported for now
beam_width: int = 1
# TODO: beam search is not supported for now
ratio: Optional[float] = 1.2
# The ratio of prefill sequences to decoding sequences; a prefill step is performed once the actual ratio exceeds this value.

def __init_batch_size__(self):
"""
MAX_BATCH_SIZE is set to accurately utilize the GPU memory.
We use a simple heuristic based on the total GPU memory size; the user can still set it manually.
"""
device = torch.device("cuda")
total_mem = torch.cuda.get_device_properties(device).total_memory // GibiByte
self.max_batch_size = 8

if 40 < total_mem <= 60:
self.max_batch_size = 16
elif 60 < total_mem <= 80:
self.max_batch_size = 32

def __post_init__(self):
self._verify_args()
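For reference, a minimal, self-contained sketch of the memory-based default introduced above. The helper name `default_max_batch_size` is hypothetical; it simply mirrors the thresholds in `__init_batch_size__` and assumes a CUDA device is available:

```python
import torch

GibiByte = 1024**3


def default_max_batch_size() -> int:
    # Hypothetical helper mirroring InferenceConfig.__init_batch_size__:
    # derive a default batch size from the total GPU memory, in GiB.
    total_mem = torch.cuda.get_device_properties(torch.device("cuda")).total_memory // GibiByte
    max_batch_size = 8  # fallback default, also used outside the 40-80 GiB ranges below
    if 40 < total_mem <= 60:
        max_batch_size = 16
    elif 60 < total_mem <= 80:
        max_batch_size = 32
    return max_batch_size


# Example: an 80 GiB A100 yields 32, while a 24 GiB consumer GPU yields 8.
```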
colossalai/inference/core/engine.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@

from transformers import AutoConfig

from .config import InferenceConfig
from colossalai.inference.config import InferenceConfig


class InferenceEngine:
@@ -2,6 +2,10 @@
from dataclasses import dataclass
from typing import Dict, List, Set

"""
The abstractions of request and sequence are defined here.
"""


class RequsetStatus(enum.Enum):
"""The status of Sentences"""
@@ -95,16 +99,16 @@ def __repr__(self) -> str:


@dataclass
class BatchHandler:
class BatchInfo:
"""
Information to be passed and used for a batch of sequences.
"""

sequences_set: Set[Sequence]
block_table: Dict[int, int]
block_table: Dict[int, int] = None

@classmethod
def init_batch(cls, seqs: List[Sequence]) -> "BatchHandler":
def init_batch(cls, seqs: List[Sequence]) -> "BatchInfo":
"""
Initializes inference batches from an input sentence list.

@@ -115,13 +119,13 @@ def init_batch(cls, seqs: List[Sequence]) -> "BatchHandler":
block_table = {}
for seq in seqs:
if seq in sequences_set:
print("The sequence is already in sequences_set.")
assert (
seq.request_id in block_table
seq.request_id in block_table.keys()
), "The sequence has been added to sequences_set, but it has not been added to block_table."
continue

assert (
seq.request_id not in block_table
seq.request_id not in block_table.keys()
), "The sequence has not been added to sequences_set, but it is already in block_table."

sequences_set.add(seq)
@@ -143,9 +147,9 @@ def fliter_batch(self) -> None:
"""
Remove completed sentences from a batch.
"""
for seq in self.sequences_set:
for seq in self.sequences_set.copy():
if seq.check_finish():
self.sequences_set.reomve(seq)
self.sequences_set.remove(seq)
del self.block_table[seq.request_id]

def add_seqs(self, seqs: List[Sequence]) -> None:
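One detail worth noting in fliter_batch: the loop now iterates over self.sequences_set.copy(), because removing elements from a set while iterating it directly raises a RuntimeError. A standalone illustration, independent of the PR's classes:

```python
# Removing from a set during direct iteration raises
# "RuntimeError: Set changed size during iteration", so iterate over a copy.
finished = {"seq-1", "seq-2", "seq-3"}

for item in finished.copy():
    if item.endswith("2"):  # stand-in for seq.check_finish()
        finished.remove(item)

assert finished == {"seq-1", "seq-3"}
```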
colossalai/inference/kv_cache/kvcache_manager.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
import torch
from transformers.configuration_utils import PretrainedConfig

from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.config import InferenceConfig
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device

colossalai/inference/sequence.py (0 additions, 3 deletions)

This file was deleted.

@@ -1,8 +1,8 @@
from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.core.inference_struct import BatchHandler, Sequence
from colossalai.inference.config import InferenceConfig
from colossalai.inference.inferenceData import BatchInfo, RequsetStatus, Sequence


def test_config_and_struct():
def test_config_and_inferenceData():
InferenceConfig("/llama")
sequence = Sequence(
request_id=1,
@@ -27,11 +27,16 @@ def test_config_and_struct():
assert sequence.get_output_len() == 0
assert sequence.check_finish() == False

batch = BatchHandler.init_batch([sequence])
batch = BatchInfo.init_batch([sequence])
assert batch.block_table[sequence.request_id] == sequence.block_table_index
sequence.status = RequsetStatus.COMPLETED
batch.fliter_batch()
assert batch.block_table == {}
batch.add_seqs([sequence2])
assert batch.block_table[sequence2.request_id] == sequence2.block_table_index
batch.clear_batch()
assert batch.block_table == {}


if __name__ == "__main__":
test_config_and_struct()
test_config_and_inferenceData()
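For readers following the assertions above, a plain-container stand-in (no ColossalAI imports; request ids and block indices are made up) sketches the block_table bookkeeping that BatchInfo is expected to maintain through init_batch, fliter_batch, add_seqs, and clear_batch:

```python
# block_table maps request_id -> block_table_index and must stay in sync
# with the set of live sequences. Values here are illustrative only.
block_table = {}

# init_batch: register a sequence.
block_table[1] = 0
assert block_table[1] == 0  # mirrors: batch.block_table[sequence.request_id] == sequence.block_table_index

# fliter_batch: drop finished sequences and their block-table entries.
finished_ids = {1}
for request_id in list(block_table):
    if request_id in finished_ids:
        del block_table[request_id]
assert block_table == {}

# add_seqs: register another sequence; clear_batch then wipes everything.
block_table[2] = 3
assert block_table[2] == 3
block_table.clear()
assert block_table == {}
```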
tests/test_infer/test_kvcache_manager.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
import torch
from transformers.models.llama import LlamaConfig

from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize