[Core] Support load and unload LoRA in api server (vllm-project#6566)
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Jeffwan and jeejeelee authored Sep 6, 2024
1 parent 2febcf2 commit db3bf7c
Showing 10 changed files with 336 additions and 6 deletions.
1 change: 0 additions & 1 deletion docs/requirements-docs.txt
@@ -11,6 +11,5 @@ pydantic >= 2.8
torch
py-cpuinfo
transformers
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
mistral_common >= 1.3.4
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
52 changes: 52 additions & 0 deletions docs/source/models/lora.rst
@@ -107,3 +107,55 @@ The following is an example request
"max_tokens": 7,
"temperature": 0
}' | jq
Dynamically serving LoRA Adapters
---------------------------------

In addition to serving LoRA adapters at server startup, the vLLM server now supports dynamically loading and unloading
LoRA adapters at runtime through dedicated API endpoints. This feature is particularly useful when the flexibility
to swap adapters on the fly is needed.

Note: Enabling this feature in production environments is risky, as it allows users to manage model adapters directly.

To enable dynamic LoRA loading and unloading, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING`
is set to `True`. When this option is enabled, the API server will log a warning to indicate that dynamic loading is active.

.. code-block:: bash

    export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True

Loading a LoRA Adapter:

To dynamically load a LoRA adapter, send a POST request to the `/v1/load_lora_adapter` endpoint with the necessary
details of the adapter to be loaded. The request payload should include the name and path to the LoRA adapter.

Example request to load a LoRA adapter:

.. code-block:: bash

    curl -X POST http://localhost:8000/v1/load_lora_adapter \
    -H "Content-Type: application/json" \
    -d '{
        "lora_name": "sql_adapter",
        "lora_path": "/path/to/sql-lora-adapter"
    }'

Upon a successful request, the API will respond with a 200 OK status code. If an error occurs, such as if the adapter
cannot be found or loaded, an appropriate error message will be returned.
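
The same call can be made from Python. Below is a minimal sketch using the third-party `requests` library (not part
of vLLM); the server address and adapter path are placeholders for illustration:

.. code-block:: python

    # Sketch: load a LoRA adapter at runtime over HTTP. Assumes the server
    # runs on localhost:8000 with VLLM_ALLOW_RUNTIME_LORA_UPDATING=True and
    # that the adapter path exists on the machine serving the model.
    import requests

    resp = requests.post(
        "http://localhost:8000/v1/load_lora_adapter",
        json={
            "lora_name": "sql_adapter",
            "lora_path": "/path/to/sql-lora-adapter",
        },
    )
    # A 200 status code means the adapter was registered; otherwise the
    # response body carries the error details.
    print(resp.status_code, resp.text)

Once loaded, the adapter can be referenced by its `lora_name` in the `model` field of completion requests, just like
adapters supplied at startup.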

Unloading a LoRA Adapter:

To unload a LoRA adapter that has been previously loaded, send a POST request to the `/v1/unload_lora_adapter` endpoint
with the name or ID of the adapter to be unloaded.

Example request to unload a LoRA adapter:

.. code-block:: bash

    curl -X POST http://localhost:8000/v1/unload_lora_adapter \
    -H "Content-Type: application/json" \
    -d '{
        "lora_name": "sql_adapter"
    }'
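
Unloading can be scripted in the same way; the following is a minimal sketch with the `requests` library, assuming the
`sql_adapter` adapter from the previous example is currently loaded:

.. code-block:: python

    # Sketch: unload a previously loaded LoRA adapter by name.
    import requests

    resp = requests.post(
        "http://localhost:8000/v1/unload_lora_adapter",
        json={"lora_name": "sql_adapter"},
    )
    print(resp.status_code, resp.text)  # 200 on success

After a successful unload, requests that reference the removed adapter name will be rejected until it is loaded again.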
2 changes: 1 addition & 1 deletion tests/entrypoints/llm/test_generate_multiple_loras.py
@@ -50,7 +50,7 @@ def zephyr_lora_files():
@pytest.mark.skip_global_cleanup
def test_multiple_lora_requests(llm: LLM, zephyr_lora_files):
    lora_request = [
        LoRARequest(LORA_NAME, idx + 1, zephyr_lora_files)
        LoRARequest(LORA_NAME + str(idx), idx + 1, zephyr_lora_files)
        for idx in range(len(PROMPTS))
    ]
    # Multiple SamplingParams should be matched with each prompt
107 changes: 107 additions & 0 deletions tests/entrypoints/openai/test_serving_engine.py
@@ -0,0 +1,107 @@
from http import HTTPStatus
from unittest.mock import MagicMock

import pytest

from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                              LoadLoraAdapterRequest,
                                              UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.serving_engine import OpenAIServing

MODEL_NAME = "meta-llama/Llama-2-7b"
LORA_LOADING_SUCCESS_MESSAGE = (
    "Success: LoRA adapter '{lora_name}' added successfully.")
LORA_UNLOADING_SUCCESS_MESSAGE = (
    "Success: LoRA adapter '{lora_name}' removed successfully.")


async def _async_serving_engine_init():
    mock_engine_client = MagicMock(spec=AsyncEngineClient)
    mock_model_config = MagicMock(spec=ModelConfig)
    # Set the max_model_len attribute to avoid missing attribute
    mock_model_config.max_model_len = 2048

    serving_engine = OpenAIServing(mock_engine_client,
                                   mock_model_config,
                                   served_model_names=[MODEL_NAME],
                                   lora_modules=None,
                                   prompt_adapters=None,
                                   request_logger=None)
    return serving_engine


@pytest.mark.asyncio
async def test_load_lora_adapter_success():
    serving_engine = await _async_serving_engine_init()
    request = LoadLoraAdapterRequest(lora_name="adapter",
                                     lora_path="/path/to/adapter2")
    response = await serving_engine.load_lora_adapter(request)
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
    assert len(serving_engine.lora_requests) == 1
    assert serving_engine.lora_requests[0].lora_name == "adapter"


@pytest.mark.asyncio
async def test_load_lora_adapter_missing_fields():
    serving_engine = await _async_serving_engine_init()
    request = LoadLoraAdapterRequest(lora_name="", lora_path="")
    response = await serving_engine.load_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST


@pytest.mark.asyncio
async def test_load_lora_adapter_duplicate():
    serving_engine = await _async_serving_engine_init()
    request = LoadLoraAdapterRequest(lora_name="adapter1",
                                     lora_path="/path/to/adapter1")
    response = await serving_engine.load_lora_adapter(request)
    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(serving_engine.lora_requests) == 1

    request = LoadLoraAdapterRequest(lora_name="adapter1",
                                     lora_path="/path/to/adapter1")
    response = await serving_engine.load_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST
    assert len(serving_engine.lora_requests) == 1


@pytest.mark.asyncio
async def test_unload_lora_adapter_success():
    serving_engine = await _async_serving_engine_init()
    request = LoadLoraAdapterRequest(lora_name="adapter1",
                                     lora_path="/path/to/adapter1")
    response = await serving_engine.load_lora_adapter(request)
    assert len(serving_engine.lora_requests) == 1

    request = UnloadLoraAdapterRequest(lora_name="adapter1")
    response = await serving_engine.unload_lora_adapter(request)
    assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
        lora_name='adapter1')
    assert len(serving_engine.lora_requests) == 0


@pytest.mark.asyncio
async def test_unload_lora_adapter_missing_fields():
    serving_engine = await _async_serving_engine_init()
    request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None)
    response = await serving_engine.unload_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST


@pytest.mark.asyncio
async def test_unload_lora_adapter_not_found():
    serving_engine = await _async_serving_engine_init()
    request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter")
    response = await serving_engine.unload_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
    assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST
40 changes: 38 additions & 2 deletions vllm/entrypoints/openai/api_server.py
@@ -35,11 +35,13 @@
                                              DetokenizeResponse,
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse,
                                              LoadLoraAdapterRequest,
                                              TokenizeRequest,
                                              TokenizeResponse)
# yapf: enable
                                              TokenizeResponse,
                                              UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
@@ -343,6 +345,40 @@ async def stop_profile():
    return Response(status_code=200)


if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
    logger.warning(
        "Lora dynamic loading & unloading is enabled in the API server. "
        "This should ONLY be used for local development!")

    @router.post("/v1/load_lora_adapter")
    async def load_lora_adapter(request: LoadLoraAdapterRequest):
        response = await openai_serving_chat.load_lora_adapter(request)
        if isinstance(response, ErrorResponse):
            return JSONResponse(content=response.model_dump(),
                                status_code=response.code)

        response = await openai_serving_completion.load_lora_adapter(request)
        if isinstance(response, ErrorResponse):
            return JSONResponse(content=response.model_dump(),
                                status_code=response.code)

        return Response(status_code=200, content=response)

    @router.post("/v1/unload_lora_adapter")
    async def unload_lora_adapter(request: UnloadLoraAdapterRequest):
        response = await openai_serving_chat.unload_lora_adapter(request)
        if isinstance(response, ErrorResponse):
            return JSONResponse(content=response.model_dump(),
                                status_code=response.code)

        response = await openai_serving_completion.unload_lora_adapter(request)
        if isinstance(response, ErrorResponse):
            return JSONResponse(content=response.model_dump(),
                                status_code=response.code)

        return Response(status_code=200, content=response)


def build_app(args: Namespace) -> FastAPI:
    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
10 changes: 10 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -878,3 +878,13 @@ class DetokenizeRequest(OpenAIBaseModel):

class DetokenizeResponse(OpenAIBaseModel):
    prompt: str


class LoadLoraAdapterRequest(BaseModel):
    lora_name: str
    lora_path: str


class UnloadLoraAdapterRequest(BaseModel):
    lora_name: str
    lora_int_id: Optional[int] = Field(default=None)
79 changes: 78 additions & 1 deletion vllm/entrypoints/openai/serving_engine.py
@@ -16,11 +16,13 @@
                                              CompletionRequest,
                                              DetokenizeRequest,
                                              EmbeddingRequest, ErrorResponse,
                                              LoadLoraAdapterRequest,
                                              ModelCard, ModelList,
                                              ModelPermission,
                                              TokenizeChatRequest,
                                              TokenizeCompletionRequest,
                                              TokenizeRequest)
                                              TokenizeRequest,
                                              UnloadLoraAdapterRequest)
# yapf: enable
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
@@ -32,6 +34,7 @@
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AtomicCounter

logger = init_logger(__name__)

@@ -78,6 +81,7 @@ def __init__(

        self.served_model_names = served_model_names

        self.lora_id_counter = AtomicCounter(0)
        self.lora_requests = []
        if lora_modules is not None:
            self.lora_requests = [
@@ -403,3 +407,76 @@ def _get_decoded_token(logprob: Logprob,
        if logprob.decoded_token is not None:
            return logprob.decoded_token
        return tokenizer.decode(token_id)

    async def _check_load_lora_adapter_request(
            self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
        # Check if both 'lora_name' and 'lora_path' are provided
        if not request.lora_name or not request.lora_path:
            return self.create_error_response(
                message="Both 'lora_name' and 'lora_path' must be provided.",
                err_type="InvalidUserInput",
                status_code=HTTPStatus.BAD_REQUEST)

        # Check if the lora adapter with the given name already exists
        if any(lora_request.lora_name == request.lora_name
               for lora_request in self.lora_requests):
            return self.create_error_response(
                message=
                f"The lora adapter '{request.lora_name}' has already been "
                "loaded.",
                err_type="InvalidUserInput",
                status_code=HTTPStatus.BAD_REQUEST)

        return None

    async def _check_unload_lora_adapter_request(
            self,
            request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
        # Check if either 'lora_name' or 'lora_int_id' is provided
        if not request.lora_name and not request.lora_int_id:
            return self.create_error_response(
                message=
                "Either 'lora_name' or 'lora_int_id' needs to be provided.",
                err_type="InvalidUserInput",
                status_code=HTTPStatus.BAD_REQUEST)

        # Check if the lora adapter with the given name exists
        if not any(lora_request.lora_name == request.lora_name
                   for lora_request in self.lora_requests):
            return self.create_error_response(
                message=
                f"The lora adapter '{request.lora_name}' cannot be found.",
                err_type="InvalidUserInput",
                status_code=HTTPStatus.BAD_REQUEST)

        return None

    async def load_lora_adapter(
            self,
            request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
        error_check_ret = await self._check_load_lora_adapter_request(request)
        if error_check_ret is not None:
            return error_check_ret

        lora_name, lora_path = request.lora_name, request.lora_path
        unique_id = self.lora_id_counter.inc(1)
        self.lora_requests.append(
            LoRARequest(lora_name=lora_name,
                        lora_int_id=unique_id,
                        lora_path=lora_path))
        return f"Success: LoRA adapter '{lora_name}' added successfully."

    async def unload_lora_adapter(
            self,
            request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
        error_check_ret = await self._check_unload_lora_adapter_request(
            request)
        if error_check_ret is not None:
            return error_check_ret

        lora_name = request.lora_name
        self.lora_requests = [
            lora_request for lora_request in self.lora_requests
            if lora_request.lora_name != lora_name
        ]
        return f"Success: LoRA adapter '{lora_name}' removed successfully."
7 changes: 7 additions & 0 deletions vllm/envs.py
@@ -61,6 +61,7 @@
    VLLM_ALLOW_ENGINE_USE_RAY: bool = False
    VLLM_PLUGINS: Optional[List[str]] = None
    VLLM_TORCH_PROFILER_DIR: Optional[str] = None
    VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False


def get_default_cache_root():
@@ -409,6 +410,12 @@ def get_default_config_root():
    # If set, vLLM will use Triton implementations of AWQ.
    "VLLM_USE_TRITON_AWQ":
    lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),

    # If set, allow loading or unloading LoRA adapters at runtime.
    "VLLM_ALLOW_RUNTIME_LORA_UPDATING":
    lambda:
    (os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in
     ("1", "true")),
}

# end-env-vars-definition