Commit 5166259

fix: multi-lora with sample api-server
1 parent 71bcaf9 commit 5166259

File tree

2 files changed: +57 -4 lines changed

docs/source/models/lora.rst

+27-4
@@ -51,14 +51,36 @@ the third parameter is the path to the LoRA adapter.
 Check out `examples/multilora_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/multilora_inference.py>`_
 for an example of how to use LoRA adapters with the async engine and how to use more advanced configuration options.
 
-Serving LoRA Adapters
+Serving LoRA Adapters (Sample Service)
+--------------------------------------
+The sample service entrypoint can be used to serve LoRA modules. To do so, we use
+``--lora-modules {name}={path} {name}={path}`` to specify each LoRA module when we start the server:
+
+.. code-block:: bash
+
+    python -m vllm.entrypoints.api_server \
+        --model meta-llama/Llama-2-7b-hf \
+        --enable-lora \
+        --lora-modules sql-lora=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/
+
+This will start a FastAPI server that accepts requests. An example is as follows:
+
+.. code-block:: bash
+
+    curl http://localhost:8000/generate -H "Content-Type: application/json" -d '{
+        "prompt": "San Francisco is a",
+        "max_tokens": 7,
+        "temperature": 1,
+        "adapter": "sql-lora"
+    }'
+
+Note that if the ``adapter`` parameter is not included, the response will come from the base model only.
+The ``adapter`` value is expected to be one of the adapter names passed with ``--lora-modules``.
+
+Serving LoRA Adapters
 ---------------------
-LoRA adapted models can also be served with the Open-AI compatible vLLM server. To do so, we use
-``--lora-modules {name}={path} {name}={path}`` to specify each LoRA module when we kickoff the server:
+LoRA adapted models can also be served with the OpenAI-compatible vLLM server:
 
 .. code-block:: bash
 
-    python -m vllm.entrypoints.api_server \
+    python -m vllm.entrypoints.openai.api_server \
         --model meta-llama/Llama-2-7b-hf \
         --enable-lora \
         --lora-modules sql-lora=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/
@@ -89,3 +111,4 @@ with its base model:
 Requests can specify the LoRA adapter as if it were any other model via the ``model`` request parameter. The requests will be
 processed according to the server-wide LoRA configuration (i.e. in parallel with base model requests, and potentially other
 LoRA adapter requests if they were provided and ``max_loras`` is set high enough).
+
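For reference, here is a minimal Python client equivalent to the curl example in the new docs section. It assumes the sample server from this commit is running on localhost:8000 and was started with the sql-lora module; the use of the third-party `requests` library and the way the response is printed are illustrative choices, not part of the commit.

# Illustrative client for the sample api_server's /generate endpoint (sketch, not part of the commit).
import json

import requests  # assumed to be installed; any HTTP client works

payload = {
    "prompt": "San Francisco is a",
    "max_tokens": 7,
    "temperature": 1,
    # Drop "adapter" to get a completion from the base model instead.
    "adapter": "sql-lora",
}

resp = requests.post("http://localhost:8000/generate", json=payload)
resp.raise_for_status()
# Print whatever JSON body the server returns without assuming its exact schema.
print(json.dumps(resp.json(), indent=2))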

vllm/entrypoints/api_server.py

+30
@@ -17,9 +17,13 @@
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
 
+from vllm.lora.request import LoRARequest
+from vllm.entrypoints.openai.api_server import LoRAParserAction
+
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
 app = FastAPI()
 engine = None
+adapters = {}
 
 
 @app.get("/health")
@@ -34,19 +38,29 @@ async def generate(request: Request) -> Response:
 
     The request should be a JSON object with the following fields:
     - prompt: the prompt to use for the generation.
+    - adapter: name of the LoRA adapter to be used.
     - stream: whether to stream the results or not.
     - other fields: the sampling parameters (See `SamplingParams` for details).
     """
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
+    adapter = request_dict.pop("adapter", None)
     prefix_pos = request_dict.pop("prefix_pos", None)
     stream = request_dict.pop("stream", False)
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
 
+    if not adapter:
+        lora_request = None
+    elif adapter not in adapters:
+        raise ValueError(f"{adapter} not a valid adapter in this service")
+    else:
+        lora_request = adapters[adapter]
+
     results_generator = engine.generate(prompt,
                                         sampling_params,
                                         request_id,
+                                        lora_request=lora_request,
                                         prefix_pos=prefix_pos)
 
     # Streaming case
@@ -89,11 +103,27 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
         type=str,
         default=None,
         help="FastAPI root_path when app is behind a path based routing proxy")
+    parser.add_argument(
+        "--lora-modules",
+        type=str,
+        default=None,
+        nargs='+',
+        action=LoRAParserAction,
+        help=
+        "LoRA module configurations in the format name=path. Multiple modules can be specified."
+    )
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
+    adapters = {
+        lora.name: LoRARequest(
+            lora_name=lora.name,
+            lora_int_id=i,
+            lora_local_path=lora.local_path,
+        ) for i, lora in enumerate(args.lora_modules, start=1)
+    } if args.enable_lora and args.lora_modules else {}
 
     app.root_path = args.root_path
     uvicorn.run(app,
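As a note on how the new flag feeds the handler: LoRAParserAction (imported from vllm.entrypoints.openai.api_server) turns each name=path argument into an object with name and local_path attributes, which the __main__ block then maps to LoRARequest objects keyed by adapter name. The standalone sketch below mimics that flow with a hypothetical parse_lora_modules helper and plain dicts in place of LoRARequest, so it runs without vLLM installed; it illustrates the name=path format and the id assignment, not vLLM's actual implementation.

# Standalone sketch (no vLLM required) of how --lora-modules name=path arguments become
# the adapter lookup table used by the /generate handler above. LoRAModule and
# parse_lora_modules are hypothetical stand-ins for vLLM's LoRAParserAction machinery.
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class LoRAModule:
    name: str
    local_path: str


def parse_lora_modules(values: List[str]) -> List[LoRAModule]:
    """Split each 'name=path' argument into a LoRAModule record."""
    modules = []
    for value in values:
        name, path = value.split("=", 1)
        modules.append(LoRAModule(name=name, local_path=path))
    return modules


def build_adapters(lora_modules: List[LoRAModule]) -> Dict[str, dict]:
    """Mirror the diff: one entry per module, with distinct integer ids starting at 1."""
    return {
        lora.name: dict(
            lora_name=lora.name,
            lora_int_id=i,  # each adapter gets its own positive integer id
            lora_local_path=lora.local_path,
        )
        for i, lora in enumerate(lora_modules, start=1)
    }


if __name__ == "__main__":
    modules = parse_lora_modules(
        ["sql-lora=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/"])
    # Prints the mapping for 'sql-lora' with lora_int_id == 1.
    print(build_adapters(modules))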
