-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[serving] Scaffolding for llm serving. (#409)
(needs bump for IREE runtime updates)
- Loading branch information
1 parent
e955627
commit f1c3d16
Showing
13 changed files
with
1,929 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
iree-compiler==20240207.794 | ||
iree-runtime==20240207.794 | ||
iree-compiler==20240215.802 | ||
iree-runtime==20240215.802 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import pytest | ||
|
||
from turbine_serving.framework.session import ( | ||
DeviceSession, | ||
) | ||
|
||
|
||
@pytest.fixture
def local_device_session():
    """Yield a ``DeviceSession`` on the ``local-task`` URI, shut down on teardown."""
    device_session = DeviceSession(uri="local-task")
    yield device_session
    # Runs after the requesting test finishes.
    device_session.shutdown()
|
||
|
||
def test_start_shutdown_no_host_contexts(local_device_session: DeviceSession):
    """Initialize a module set and tear down without accessing a host context."""
    module_set = local_device_session.create_module_set("default")
    module_set.initialize()
|
||
|
||
def test_host_context_start_stop(local_device_session: DeviceSession):
    """Initialize a module set, access its host context, then tear down."""
    module_set = local_device_session.create_module_set("default")
    module_set.initialize()
    # The value is intentionally unused; the attribute access itself is the
    # exercise (presumably it brings the host context up -- TODO confirm).
    host_context = module_set.host_context
|
||
|
||
def test_host_context_scheduling(local_device_session: DeviceSession):
    """Chain two coroutines on the host context through a single semaphore.

    ``task1`` awaits payload 1 and then signals 2; ``task2`` awaits payload 2
    and then fails the semaphore. After both futures are consumed, waiting on
    payload 3 is expected to raise with the message set by ``task2``.
    """
    hal_device = local_device_session.device
    module_set = local_device_session.create_module_set("default")
    module_set.initialize()
    host_context = module_set.host_context

    semaphore = hal_device.create_semaphore(0)

    async def task1():
        print("[coro1] test_host_context_scheduling.task")
        await host_context.on_semaphore(semaphore, 1, True)
        print("[coro1] await completed")
        semaphore.signal(2)

    async def task2():
        print("[coro2] waiting for 2")
        await host_context.on_semaphore(semaphore, 2, True)
        semaphore.fail("Fail from task2")

    first_future = host_context.run_concurrent(task1())
    second_future = host_context.run_concurrent(task2())
    # Kick off the chain: this unblocks task1, which in turn unblocks task2.
    semaphore.signal(1)
    print("[main] Waiting for semaphore")

    # Consume both results so that any exception raised inside a coroutine
    # is propagated to the test rather than silently dropped.
    first_future.result()
    second_future.result()

    print("[main] Waiting on semaphore payload 3")
    with pytest.raises(Exception, match="Fail from task2"):
        semaphore.wait(3)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import pytest | ||
|
||
from iree.runtime import ( # type: ignore | ||
HalElementType, | ||
) | ||
|
||
from turbine_serving.framework.session import DeviceSession | ||
from turbine_serving.llm.config import ( | ||
CacheParams, | ||
ModelParams, | ||
ServiceParams, | ||
) | ||
|
||
from turbine_serving.llm.service import ( | ||
GenerateRequest, | ||
GenerateResponsePart, | ||
) | ||
|
||
from turbine_serving.llm.attn_block_cache import ( | ||
create_attn_block_cache_module, | ||
AttnBlockCache, | ||
) | ||
|
||
from turbine_serving.llm.impl.service_v1 import ( | ||
GenerateServiceV1, | ||
) | ||
|
||
from turbine_serving.llm.testing.fake_v1_module import ( | ||
create_fake_module, | ||
) | ||
|
||
|
||
@pytest.fixture
def cache_params(model_params: ModelParams) -> CacheParams:
    """Cache parameters for the test model: 128 device blocks, stride 16."""
    params = CacheParams(
        model=model_params,
        device_block_count=128,
        block_pos_stride=16,
    )
    return params
|
||
|
||
@pytest.fixture
def model_params() -> ModelParams:
    """Model parameters describing the fake ``AwesomeLLM`` module used in tests."""
    settings = dict(
        module_name="AwesomeLLM",
        module_abi_version=1,
        attn_dtype=HalElementType.FLOAT_16,
        max_seq_len=128,
        transformer_block_count=32,
        attn_head_count=32,
        attn_head_dim=128,
        prefill_batch_sizes=[1, 4, 16],
        decode_batch_sizes=[1, 4, 16],
    )
    return ModelParams(**settings)
|
||
|
||
@pytest.fixture
def uninitialized_session(model_params: ModelParams):
    """Yield a two-queue ``local-task`` session with no module sets created yet.

    The IREE runtime leak checker is disabled before the session is created.
    On teardown the session is shut down and the local reference dropped.
    """
    from iree.runtime._binding import disable_leak_checker  # type: ignore

    disable_leak_checker()
    device_session = DeviceSession(uri="local-task", queue_count=2)
    yield device_session
    device_session.shutdown()
    # Drop the fixture's own reference explicitly once shutdown has run.
    del device_session
|
||
|
||
@pytest.fixture
def attn_block_cache(
    uninitialized_session: DeviceSession, cache_params: CacheParams
) -> AttnBlockCache:
    """Build the attention block cache on the not-yet-initialized session."""
    block_cache = AttnBlockCache(uninitialized_session, cache_params)
    return block_cache
|
||
|
||
@pytest.fixture
def session(
    model_params: ModelParams,
    uninitialized_session: DeviceSession,
    attn_block_cache: AttnBlockCache,
):
    """Return the session after registering and initializing ``AwesomeLLM``.

    The module set holds the attention block cache module and the fake
    ``AwesomeLLM`` module, and is created with a single context.
    """
    module_set = uninitialized_session.create_module_set(
        "AwesomeLLM", context_count=1
    )
    cache_module = create_attn_block_cache_module(attn_block_cache)
    fake_module = create_fake_module(
        uninitialized_session.device, "AwesomeLLM", model_params=model_params
    )
    module_set.add(cache_module, fake_module)
    module_set.initialize()
    return uninitialized_session
|
||
|
||
@pytest.fixture
def service(
    session: DeviceSession,
    cache_params: CacheParams,
    model_params: ModelParams,
    attn_block_cache: AttnBlockCache,
):
    """Construct a ``GenerateServiceV1`` from the initialized session and cache."""
    service_params = ServiceParams(cache=cache_params, model=model_params)
    generate_service = GenerateServiceV1(
        session=session, params=service_params, cache=attn_block_cache
    )
    return generate_service
|
||
|
||
def test_single(service: GenerateServiceV1):
    """Run one prefill over two requests and print the resulting ids.

    The coroutine sets the sequences, prefills, resolves the guarded outputs
    on the state's host context, prints them, and finally recycles the state.
    """
    gen_state = service.start()

    async def prefill_once():
        long_prompt = GenerateRequest(
            "1",
            "hello, tell me a story",
            [3, 4, 5, 12, 23, 88, 10, 2, 5, 9, 12, 13, 99, 56, 33, 124, 73],
        )
        short_prompt = GenerateRequest("2", "goodbye", [9, 10])
        await gen_state.set_sequences(requests=[long_prompt, short_prompt])

        guarded_outputs = await gen_state.prefill()
        prefill_ids = await guarded_outputs.resolve(gen_state.host_context)

        # Map the result and view it as an array for display.
        mapped = prefill_ids.map()
        shape = prefill_ids.shape
        dtype = HalElementType.map_to_dtype(prefill_ids.element_type)
        print("PREFILL IDS:", prefill_ids, ":\n", mapped.asarray(shape, dtype))

        await gen_state.recycle()

    gen_state.host_context.run_sync(prefill_once())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import logging | ||
import os | ||
import sys | ||
|
||
|
||
# Whether debug assertions are disabled.
NDEBUG: bool = False

# Level applied to the default handler and to loggers from get_logger().
# Taken from TURBINE_LOG_LEVEL. The value is stripped and upper-cased because
# logging's setLevel() only accepts exact upper-case level names, so "debug"
# or " INFO " would otherwise raise ValueError at import time. An unset or
# empty variable falls back to "DEBUG".
_default_log_level: str = os.getenv("TURBINE_LOG_LEVEL", "").strip().upper() or "DEBUG"
|
||
|
||
class DefaultFormatter(logging.Formatter):
    """Formatter emitting ``LEVEL MM-DD HH:MM:SS [file:line] message``."""

    def __init__(self):
        message_format = (
            "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s"
        )
        timestamp_format = "%m-%d %H:%M:%S"
        super().__init__(fmt=message_format, datefmt=timestamp_format)
|
||
|
||
def _setup_logger():
    """Configure the ``turbine_serving`` logger with a stderr handler.

    Returns:
        A ``(logger, handler)`` tuple: the ``turbine_serving`` logger and the
        stream handler attached to it.
    """
    stderr_handler = logging.StreamHandler(sys.stderr)
    # Flushing the handler delegates directly to stderr's own flush.
    stderr_handler.flush = sys.stderr.flush
    stderr_handler.setLevel(_default_log_level)
    stderr_handler.setFormatter(DefaultFormatter())

    package_logger = logging.getLogger("turbine_serving")
    # The logger itself passes everything; filtering happens at the handler.
    package_logger.setLevel(logging.DEBUG)
    package_logger.addHandler(stderr_handler)
    # Records stop here and are not forwarded to ancestor loggers' handlers.
    package_logger.propagate = False
    return package_logger, stderr_handler
|
||
|
||
# Configure package logging once, at import time.
(root_logger, default_handler) = _setup_logger()

# Send asyncio's own log records through the same handler.
logging.getLogger(name="asyncio").addHandler(default_handler)
|
||
|
||
def get_logger(name: str):
    """Return the logger called ``name``, wired to the shared default handler.

    The logger's level is set to the module default, the default handler is
    attached, and propagation to ancestor loggers is switched off.
    """
    named_logger = logging.getLogger(name)
    named_logger.propagate = False
    named_logger.setLevel(_default_log_level)
    named_logger.addHandler(default_handler)
    return named_logger
Oops, something went wrong.