From d15299329884e74aa16d5e789da4db4c82845aef Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Thu, 27 Jun 2024 12:34:25 -0700 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Prompt=20adapter=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- src/vllm_tgis_adapter/__main__.py | 12 +++++++- src/vllm_tgis_adapter/grpc/adapters.py | 34 +++++++++++++++++------ src/vllm_tgis_adapter/grpc/grpc_server.py | 34 +++++++++++++++++++---- src/vllm_tgis_adapter/tgis_utils/args.py | 4 +++ 4 files changed, 69 insertions(+), 15 deletions(-) diff --git a/src/vllm_tgis_adapter/__main__.py b/src/vllm_tgis_adapter/__main__.py index cf54695..72fd7e3 100644 --- a/src/vllm_tgis_adapter/__main__.py +++ b/src/vllm_tgis_adapter/__main__.py @@ -202,8 +202,18 @@ async def run_http_server( args.lora_modules, args.chat_template, ) + + kwargs = {} + # prompt adapter arg required for vllm >0.5.1 + if hasattr(args, "prompt_adapters"): + kwargs = {"prompt_adapters": args.prompt_adapters} + openai_serving_completion = OpenAIServingCompletion( - engine, model_config, served_model_names, args.lora_modules + engine, + model_config, + served_model_names, + args.lora_modules, + **kwargs, ) openai_serving_embedding = OpenAIServingEmbedding( engine, model_config, served_model_names diff --git a/src/vllm_tgis_adapter/grpc/adapters.py b/src/vllm_tgis_adapter/grpc/adapters.py index 321968c..e56def4 100644 --- a/src/vllm_tgis_adapter/grpc/adapters.py +++ b/src/vllm_tgis_adapter/grpc/adapters.py @@ -14,14 +14,16 @@ from pathlib import Path from typing import TYPE_CHECKING +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest + +from .validation import TGISValidationError + if TYPE_CHECKING: from vllm.entrypoints.grpc.pb.generation_pb2 import ( BatchedGenerationRequest, SingleGenerationRequest, ) -from vllm.lora.request import LoRARequest - -from .validation import TGISValidationError global_thread_pool = None # used for loading adapter files from disk @@ -33,6 +35,7 @@ class AdapterMetadata: unique_id: int # Unique integer for vllm to identify the adapter adapter_type: str # The string name of the peft adapter type, e.g. LORA full_path: str + full_config: dict # The loaded adapter_config.json dict @dataclasses.dataclass @@ -45,7 +48,7 @@ class AdapterStore: async def validate_adapters( request: SingleGenerationRequest | BatchedGenerationRequest, adapter_store: AdapterStore | None, -) -> dict[str, LoRARequest]: +) -> dict[str, LoRARequest | PromptAdapterRequest]: """Validate the adapters. Takes the adapter name from the request and constructs a valid @@ -56,6 +59,9 @@ async def validate_adapters( """ global global_thread_pool # noqa: PLW0603 adapter_id = request.adapter_id + # Backwards compatibility for `prefix_id` arg + if not adapter_id and request.prefix_id: + adapter_id = request.prefix_id if adapter_id and not adapter_store: TGISValidationError.AdaptersDisabled.error() @@ -73,18 +79,20 @@ async def validate_adapters( if global_thread_pool is None: global_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2) - adapter_type = await loop.run_in_executor( + adapter_config = await loop.run_in_executor( global_thread_pool, - _get_adapter_type_from_file, + _load_adapter_config_from_file, adapter_id, local_adapter_path, ) + adapter_type = adapter_config.get("peft_type", None) # Add to cache adapter_metadata = AdapterMetadata( unique_id=adapter_store.next_unique_id, adapter_type=adapter_type, full_path=local_adapter_path, + full_config=adapter_config, ) adapter_store.adapters[adapter_id] = adapter_metadata @@ -96,12 +104,22 @@ async def validate_adapters( lora_local_path=adapter_metadata.full_path, ) return {"lora_request": lora_request} + if adapter_metadata.adapter_type == "PROMPT_TUNING": + prompt_adapter_request = PromptAdapterRequest( + prompt_adapter_id=adapter_metadata.unique_id, + prompt_adapter_name=adapter_id, + prompt_adapter_local_path=adapter_metadata.full_path, + prompt_adapter_num_virtual_tokens=adapter_metadata.full_config.get( + "num_virtual_tokens", 0 + ), + ) + return {"prompt_adapter_request": prompt_adapter_request} # All other types unsupported TGISValidationError.AdapterUnsupported.error(adapter_metadata.adapter_type) # noqa: RET503 -def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str: +def _load_adapter_config_from_file(adapter_id: str, adapter_path: str) -> dict: """Get adapter from file. Performs all the filesystem access required to deduce the type @@ -123,7 +141,7 @@ def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str: with open(adapter_config_path) as adapter_config_file: adapter_config = json.load(adapter_config_file) - return adapter_config.get("peft_type", None) + return adapter_config def _reject_bad_adapter_id(adapter_id: str) -> None: diff --git a/src/vllm_tgis_adapter/grpc/grpc_server.py b/src/vllm_tgis_adapter/grpc/grpc_server.py index bc5366c..af309b9 100644 --- a/src/vllm_tgis_adapter/grpc/grpc_server.py +++ b/src/vllm_tgis_adapter/grpc/grpc_server.py @@ -41,7 +41,6 @@ TGISStatLogger, ) -from .adapters import AdapterStore, validate_adapters from .pb import generation_pb2_grpc from .pb.generation_pb2 import DESCRIPTOR as _GENERATION_DESCRIPTOR from .pb.generation_pb2 import ( @@ -56,6 +55,14 @@ ) from .validation import validate_input, validate_params +try: + from .adapters import AdapterStore, validate_adapters +except ImportError: + adapters_available = False +else: + adapters_available = True + + if TYPE_CHECKING: import argparse from collections.abc import AsyncIterator, MutableSequence @@ -76,6 +83,11 @@ SingleGenerationRequest, ) + try: + from .adapters import PromptAdapterRequest + except ImportError: + pass + _T = TypeVar("_T") _F = TypeVar("_F", Callable, Coroutine) @@ -170,9 +182,11 @@ def __init__( self.skip_special_tokens = not args.output_special_tokens self.default_include_stop_seqs = args.default_include_stop_seqs + # Backwards compatibility for TGIS: PREFIX_STORE_PATH + adapter_cache_path = args.adapter_cache or args.prefix_store_path self.adapter_store = ( - AdapterStore(cache_path=args.adapter_cache, adapters={}) - if args.adapter_cache + AdapterStore(cache_path=adapter_cache_path, adapters={}) + if adapter_cache_path else None ) self.health_servicer = health_servicer @@ -213,7 +227,11 @@ async def Generate( generators = [] max_is_token_limit = [False] * request_count - adapter_kwargs = await self._validate_adapters(request, context) + adapter_kwargs = ( + await self._validate_adapters(request, context) + if adapters_available + else {} + ) for i, req in enumerate(request.requests): input_ids, max_is_token_limit[i] = await self._validate_prompt_and_tokenize( @@ -309,7 +327,11 @@ async def GenerateStream( sampling_params, truncate_input_tokens, request.request.text, context ) - adapter_kwargs = await self._validate_adapters(request, context) + adapter_kwargs = ( + await self._validate_adapters(request, context) + if adapters_available + else {} + ) inputs = TextTokensPrompt( prompt=request.request.text, prompt_token_ids=input_ids ) @@ -577,7 +599,7 @@ async def _validate_adapters( self, request: SingleGenerationRequest | BatchedGenerationRequest, context: ServicerContext, - ) -> dict[str, LoRARequest]: + ) -> dict[str, LoRARequest | PromptAdapterRequest]: try: adapters = await validate_adapters( request=request, adapter_store=self.adapter_store diff --git a/src/vllm_tgis_adapter/tgis_utils/args.py b/src/vllm_tgis_adapter/tgis_utils/args.py index 9522ac3..d0b79a8 100644 --- a/src/vllm_tgis_adapter/tgis_utils/args.py +++ b/src/vllm_tgis_adapter/tgis_utils/args.py @@ -116,6 +116,10 @@ def add_tgis_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument("--tls-client-ca-cert-path", type=str) # add a path when peft adapters will be loaded from parser.add_argument("--adapter-cache", type=str) + # backwards-compatibility support for tgis prompt tuning + parser.add_argument( + "--prefix-store-path", type=str, help="Deprecated, use --adapter-cache" + ) # TODO check/add other args here