17
17
from vllm .sampling_params import SamplingParams
18
18
from vllm .utils import random_uuid
19
19
20
+ from vllm .lora .request import LoRARequest
21
+ from vllm .entrypoints .openai .api_server import LoRAParserAction
22
+
20
23
# Keep-alive timeout (seconds) passed to the uvicorn HTTP server.
TIMEOUT_KEEP_ALIVE = 5  # seconds.

app = FastAPI()

# Filled in at startup (in the __main__ section): the async LLM engine
# instance, and the mapping of LoRA adapter name -> LoRARequest that the
# /generate endpoint resolves the request's "adapter" field against.
engine = None
adapters = {}
23
27
24
28
25
29
@app .get ("/health" )
@@ -34,19 +38,29 @@ async def generate(request: Request) -> Response:
34
38
35
39
The request should be a JSON object with the following fields:
36
40
- prompt: the prompt to use for the generation.
41
+ - adapter: name of the LoRA adapter to be used.
37
42
- stream: whether to stream the results or not.
38
43
- other fields: the sampling parameters (See `SamplingParams` for details).
39
44
"""
40
45
request_dict = await request .json ()
41
46
prompt = request_dict .pop ("prompt" )
47
+ adapter = request_dict .pop ("adapter" , None )
42
48
prefix_pos = request_dict .pop ("prefix_pos" , None )
43
49
stream = request_dict .pop ("stream" , False )
44
50
sampling_params = SamplingParams (** request_dict )
45
51
request_id = random_uuid ()
46
52
53
+ if not adapter :
54
+ lora_request = None
55
+ elif adapter not in adapters :
56
+ raise ValueError (f"{ adapter } not a valid adapter in this service" )
57
+ else :
58
+ lora_request = adapters [adapter ]
59
+
47
60
results_generator = engine .generate (prompt ,
48
61
sampling_params ,
49
62
request_id ,
63
+ lora_request = lora_request ,
50
64
prefix_pos = prefix_pos )
51
65
52
66
# Streaming case
@@ -89,11 +103,27 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
89
103
type = str ,
90
104
default = None ,
91
105
help = "FastAPI root_path when app is behind a path based routing proxy" )
106
+ parser .add_argument (
107
+ "--lora-modules" ,
108
+ type = str ,
109
+ default = None ,
110
+ nargs = '+' ,
111
+ action = LoRAParserAction ,
112
+ help =
113
+ "LoRA module configurations in the format name=path. Multiple modules can be specified."
114
+ )
92
115
parser = AsyncEngineArgs .add_cli_args (parser )
93
116
args = parser .parse_args ()
94
117
95
118
engine_args = AsyncEngineArgs .from_cli_args (args )
96
119
engine = AsyncLLMEngine .from_engine_args (engine_args )
120
+ adapters = {
121
+ lora .name : LoRARequest (
122
+ lora_name = lora .name ,
123
+ lora_int_id = i ,
124
+ lora_local_path = lora .local_path ,
125
+ ) for i , lora in enumerate (args .lora_modules , start = 1 )
126
+ } if args .enable_lora else {}
97
127
98
128
app .root_path = args .root_path
99
129
uvicorn .run (app ,
0 commit comments