From 84b9084c91e15f1d839090faceb250a4294d97d6 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 3 Jun 2024 16:14:05 -0600 Subject: [PATCH] :sparkles: pipe prompt reqs in server Signed-off-by: Joe Runde --- vllm/engine/async_llm_engine.py | 2 ++ vllm/entrypoints/grpc/adapters.py | 33 ++++++++++++++++++++-------- vllm/entrypoints/grpc/grpc_server.py | 3 ++- 3 files changed, 28 insertions(+), 10 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index dd719575..9f0dea10 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -615,6 +615,8 @@ async def generate( sampling_params: The sampling parameters of the request. request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: PromptAdapterRequest to use for generation, + if any. Yields: The output `RequestOutput` objects from the LLMEngine diff --git a/vllm/entrypoints/grpc/adapters.py b/vllm/entrypoints/grpc/adapters.py index 2c2387d9..0c26de0f 100644 --- a/vllm/entrypoints/grpc/adapters.py +++ b/vllm/entrypoints/grpc/adapters.py @@ -11,6 +11,7 @@ SingleGenerationRequest) from vllm.entrypoints.grpc.validation import TGISValidationError from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest global_thread_pool = None # used for loading adapter files from disk @@ -20,6 +21,7 @@ class AdapterMetadata: unique_id: int # Unique integer for vllm to identify the adapter adapter_type: str # The string name of the peft adapter type, e.g. LORA full_path: str + full_config: Dict # The loaded adapter_config.json dict @dataclasses.dataclass @@ -30,8 +32,9 @@ class AdapterStore: async def validate_adapters( - request: Union[SingleGenerationRequest, BatchedGenerationRequest], - adapter_store: Optional[AdapterStore]) -> Dict[str, LoRARequest]: + request: Union[SingleGenerationRequest, BatchedGenerationRequest], + adapter_store: Optional[AdapterStore] +) -> Dict[str, Union[LoRARequest, PromptAdapterRequest]]: """Takes the adapter name from the request and constructs a valid engine request if one is set. Raises if the requested adapter does not exist or adapter type is unsupported @@ -40,6 +43,9 @@ async def validate_adapters( """ global global_thread_pool adapter_id = request.adapter_id + # Backwards compatibility for `prefix_id` arg + if not adapter_id and request.prefix_id: + adapter_id = request.prefix_id if adapter_id and not adapter_store: TGISValidationError.AdaptersDisabled.error() @@ -57,16 +63,17 @@ async def validate_adapters( global_thread_pool = concurrent.futures.ThreadPoolExecutor( max_workers=2) - adapter_type = await loop.run_in_executor(global_thread_pool, - _get_adapter_type_from_file, - adapter_id, - local_adapter_path) + adapter_config = await (loop.run_in_executor( + global_thread_pool, _load_adapter_config_from_file, adapter_id, + local_adapter_path)) + adapter_type = adapter_config.get("peft_type", None) # Add to cache adapter_metadata = AdapterMetadata( unique_id=adapter_store.next_unique_id, adapter_type=adapter_type, - full_path=local_adapter_path) + full_path=local_adapter_path, + full_config=adapter_config) adapter_store.adapters[adapter_id] = adapter_metadata # Build the proper vllm request object @@ -75,12 +82,20 @@ async def validate_adapters( lora_int_id=adapter_metadata.unique_id, lora_local_path=adapter_metadata.full_path) return {"lora_request": lora_request} + elif adapter_metadata.adapter_type == "PROMPT_TUNING": + prompt_adapter_request = PromptAdapterRequest( + prompt_adapter_id=adapter_metadata.unique_id, + prompt_adapter_name=adapter_id, + prompt_adapter_local_path=adapter_metadata.full_path, + prompt_adapter_num_virtual_tokens=adapter_metadata.full_config.get( + "num_virtual_tokens", 0)) + return {"prompt_adapter_request": prompt_adapter_request} # All other types unsupported TGISValidationError.AdapterUnsupported.error(adapter_metadata.adapter_type) -def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str: +def _load_adapter_config_from_file(adapter_id: str, adapter_path: str) -> Dict: """This function does all the filesystem access required to deduce the type of the adapter. It's run in a separate thread pool executor so that file access does not block the main event loop.""" @@ -97,4 +112,4 @@ def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str: with open(adapter_config_path) as adapter_config_file: adapter_config = json.load(adapter_config_file) - return adapter_config.get("peft_type", None) + return adapter_config diff --git a/vllm/entrypoints/grpc/grpc_server.py b/vllm/entrypoints/grpc/grpc_server.py index 07a49196..c2720ab2 100644 --- a/vllm/entrypoints/grpc/grpc_server.py +++ b/vllm/entrypoints/grpc/grpc_server.py @@ -35,6 +35,7 @@ from vllm.inputs import TextTokensPrompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import Logprob from vllm.tgis_utils import logs from vllm.tgis_utils.guided_decoding import ( @@ -459,7 +460,7 @@ async def _validate_adapters(self, request: Union[SingleGenerationRequest, BatchedGenerationRequest], context: ServicerContext) \ - -> Dict[str, LoRARequest]: + -> Dict[str, Union[LoRARequest, PromptAdapterRequest]]: try: adapters = await validate_adapters( request=request, adapter_store=self.adapter_store)