Skip to content

Commit

Permalink
✨ pipe prompt reqs in server
Browse files Browse the repository at this point in the history
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
  • Loading branch information
joerunde committed Jun 4, 2024
1 parent 0d65cc6 commit 84b9084
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 10 deletions.
2 changes: 2 additions & 0 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,8 @@ async def generate(
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: PromptAdapterRequest to use for generation,
if any.
Yields:
The output `RequestOutput` objects from the LLMEngine
Expand Down
33 changes: 24 additions & 9 deletions vllm/entrypoints/grpc/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
SingleGenerationRequest)
from vllm.entrypoints.grpc.validation import TGISValidationError
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest

global_thread_pool = None # used for loading adapter files from disk

Expand All @@ -20,6 +21,7 @@ class AdapterMetadata:
unique_id: int # Unique integer for vllm to identify the adapter
adapter_type: str # The string name of the peft adapter type, e.g. LORA
full_path: str
full_config: Dict # The loaded adapter_config.json dict


@dataclasses.dataclass
Expand All @@ -30,8 +32,9 @@ class AdapterStore:


async def validate_adapters(
request: Union[SingleGenerationRequest, BatchedGenerationRequest],
adapter_store: Optional[AdapterStore]) -> Dict[str, LoRARequest]:
request: Union[SingleGenerationRequest, BatchedGenerationRequest],
adapter_store: Optional[AdapterStore]
) -> Dict[str, Union[LoRARequest, PromptAdapterRequest]]:
"""Takes the adapter name from the request and constructs a valid
engine request if one is set. Raises if the requested adapter
does not exist or adapter type is unsupported
Expand All @@ -40,6 +43,9 @@ async def validate_adapters(
"""
global global_thread_pool
adapter_id = request.adapter_id
# Backwards compatibility for `prefix_id` arg
if not adapter_id and request.prefix_id:
adapter_id = request.prefix_id

if adapter_id and not adapter_store:
TGISValidationError.AdaptersDisabled.error()
Expand All @@ -57,16 +63,17 @@ async def validate_adapters(
global_thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=2)

adapter_type = await loop.run_in_executor(global_thread_pool,
_get_adapter_type_from_file,
adapter_id,
local_adapter_path)
adapter_config = await (loop.run_in_executor(
global_thread_pool, _load_adapter_config_from_file, adapter_id,
local_adapter_path))
adapter_type = adapter_config.get("peft_type", None)

# Add to cache
adapter_metadata = AdapterMetadata(
unique_id=adapter_store.next_unique_id,
adapter_type=adapter_type,
full_path=local_adapter_path)
full_path=local_adapter_path,
full_config=adapter_config)
adapter_store.adapters[adapter_id] = adapter_metadata

# Build the proper vllm request object
Expand All @@ -75,12 +82,20 @@ async def validate_adapters(
lora_int_id=adapter_metadata.unique_id,
lora_local_path=adapter_metadata.full_path)
return {"lora_request": lora_request}
elif adapter_metadata.adapter_type == "PROMPT_TUNING":
prompt_adapter_request = PromptAdapterRequest(
prompt_adapter_id=adapter_metadata.unique_id,
prompt_adapter_name=adapter_id,
prompt_adapter_local_path=adapter_metadata.full_path,
prompt_adapter_num_virtual_tokens=adapter_metadata.full_config.get(
"num_virtual_tokens", 0))
return {"prompt_adapter_request": prompt_adapter_request}

# All other types unsupported
TGISValidationError.AdapterUnsupported.error(adapter_metadata.adapter_type)


def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str:
def _load_adapter_config_from_file(adapter_id: str, adapter_path: str) -> Dict:
"""This function does all the filesystem access required to deduce the type
of the adapter. It's run in a separate thread pool executor so that file
access does not block the main event loop."""
Expand All @@ -97,4 +112,4 @@ def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str:
with open(adapter_config_path) as adapter_config_file:
adapter_config = json.load(adapter_config_file)

return adapter_config.get("peft_type", None)
return adapter_config
3 changes: 2 additions & 1 deletion vllm/entrypoints/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from vllm.inputs import TextTokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import Logprob
from vllm.tgis_utils import logs
from vllm.tgis_utils.guided_decoding import (
Expand Down Expand Up @@ -459,7 +460,7 @@ async def _validate_adapters(self,
request: Union[SingleGenerationRequest,
BatchedGenerationRequest],
context: ServicerContext) \
-> Dict[str, LoRARequest]:
-> Dict[str, Union[LoRARequest, PromptAdapterRequest]]:
try:
adapters = await validate_adapters(
request=request, adapter_store=self.adapter_store)
Expand Down

0 comments on commit 84b9084

Please sign in to comment.