This repository was archived by the owner on Oct 11, 2024. It is now read-only.
forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy path: arg_utils.py
321 lines (309 loc) · 14.5 KB
/
arg_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
import argparse
import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, LoRAConfig)
@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    Each field has a matching command-line flag registered by
    :meth:`add_cli_args`. :meth:`from_cli_args` builds an instance from a
    parsed ``argparse.Namespace`` by reading every dataclass field by name,
    and :meth:`create_engine_configs` turns the fields into the engine's
    config objects.
    """
    model: str
    # None means "use the model's tokenizer" (filled in by __post_init__).
    tokenizer: Optional[str] = None
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    seed: int = 0
    max_model_len: Optional[int] = None
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: int = 16
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_paddings: int = 256
    disable_log_stats: bool = False
    revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    sparsity: Optional[str] = None
    enforce_eager: bool = False
    max_context_len_to_capture: int = 8192
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    lora_extra_vocab_size: int = 256
    # Must carry a type annotation to be a dataclass field. Without it this
    # was a plain class attribute: dataclasses.fields() skipped it, so
    # from_cli_args() silently dropped the parsed --lora-dtype value.
    lora_dtype: str = 'auto'
    # None or a non-positive value is passed to LoRAConfig as None
    # (see create_engine_configs).
    max_cpu_loras: Optional[int] = None

    def __post_init__(self):
        # Default the tokenizer to the model name/path when none is given.
        if self.tokenizer is None:
            self.tokenizer = self.model

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Registers one flag per EngineArgs field on ``parser`` and returns
        the same parser. Flag defaults are read from the dataclass defaults
        (e.g. ``EngineArgs.block_size``), except ``--model``, whose default
        is ``'facebook/opt-125m'`` because the field itself is required.
        """
        # NOTE: If you update any of the arguments below, please also
        # make sure to update docs/source/models/engine_args.rst

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='name or path of the huggingface model to use')
        parser.add_argument(
            '--tokenizer',
            type=str,
            default=EngineArgs.tokenizer,
            help='name or path of the huggingface tokenizer to use')
        parser.add_argument(
            '--revision',
            type=str,
            default=None,
            help='the specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=str,
            default=None,
            help='the specific tokenizer version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument('--tokenizer-mode',
                            type=str,
                            default=EngineArgs.tokenizer_mode,
                            choices=['auto', 'slow'],
                            help='tokenizer mode. "auto" will use the fast '
                            'tokenizer if available, and "slow" will '
                            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='trust remote code from huggingface')
        parser.add_argument('--download-dir',
                            type=str,
                            default=EngineArgs.download_dir,
                            help='directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
            help='The format of the model weights to load. '
            '"auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available. '
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading. '
            '"dummy" will initialize the weights with random values, '
            'which is mainly for profiling.')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='data type for model weights and activations. '
            'The "auto" option will use FP16 precision '
            'for FP32 and FP16 models, and BF16 precision '
            'for BF16 models.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=None,
                            help='model context length. If unspecified, '
                            'will be automatically derived from the model.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
                            help='use Ray for distributed serving, will be '
                            'automatically set when using more than 1 GPU')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='number of pipeline stages')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='number of tensor parallel replicas')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help='load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models')
        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='token block size')
        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='random seed')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU')
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            # Trailing space is required: the pieces are concatenated.
            help='the fraction of GPU memory to be used for '
            'the model executor, which can range from 0 to 1. '
            'If unspecified, will use the default value of 0.9.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='maximum number of batched tokens per '
                            'iteration')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='maximum number of sequences per iteration')
        parser.add_argument('--max-paddings',
                            type=int,
                            default=EngineArgs.max_paddings,
                            help='maximum number of paddings in a batch')
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='disable logging statistics')
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=str,
                            choices=['awq', 'gptq', 'squeezellm', None],
                            default=None,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument(
            '--sparsity',
            '-s',
            type=str,
            choices=['sparse_w16a16', None],
            default=None,
            help='Method used to compress sparse weights. If '
            'None, we first check the `sparsity_config` attribute '
            'in the model config file. If that is None we assume '
            'the model weights are dense')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
        """Build an instance from a namespace produced by add_cli_args().

        Every dataclass field of ``cls`` (subclass fields included) is read
        from the attribute of the same name on ``args``. A missing attribute
        raises AttributeError.
        """
        field_names = [field.name for field in dataclasses.fields(cls)]
        return cls(**{name: getattr(args, name) for name in field_names})

    def create_engine_configs(
        self,
    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
               Optional[LoRAConfig]]:
        """Convert these arguments into the engine's config objects.

        Returns:
            ``(model_config, cache_config, parallel_config,
            scheduler_config, lora_config)``. ``lora_config`` is None unless
            ``enable_lora`` is set.
        """
        model_config = ModelConfig(
            self.model, self.tokenizer, self.tokenizer_mode,
            self.trust_remote_code, self.download_dir, self.load_format,
            self.dtype, self.seed, self.revision, self.tokenizer_revision,
            self.max_model_len, self.quantization, self.sparsity,
            self.enforce_eager, self.max_context_len_to_capture)
        # The cache config takes the sliding window reported by the model.
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space,
                                   model_config.get_sliding_window())
        parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                         self.tensor_parallel_size,
                                         self.worker_use_ray,
                                         self.max_parallel_loading_workers)
        # Use model_config.max_model_len rather than self.max_model_len: the
        # latter may be None (per --max-model-len help, it is then derived
        # from the model).
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                           self.max_num_seqs,
                                           model_config.max_model_len,
                                           self.max_paddings)
        lora_config = None
        if self.enable_lora:
            # Treat a missing or non-positive max_cpu_loras as "unset".
            max_cpu_loras = (self.max_cpu_loras
                             if self.max_cpu_loras and self.max_cpu_loras > 0
                             else None)
            lora_config = LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                lora_dtype=self.lora_dtype,
                max_cpu_loras=max_cpu_loras)
        return (model_config, cache_config, parallel_config,
                scheduler_config, lora_config)
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Adds three async-only options on top of the base engine arguments, each
    with a matching command-line flag registered by :meth:`add_cli_args`.
    """
    # --engine-use-ray: start the LLM engine in a separate process via Ray.
    engine_use_ray: bool = False
    # --disable-log-requests: turn off request logging.
    disable_log_requests: bool = False
    # --max-log-len: cap on prompt characters / IDs printed in logs;
    # the flag's help text describes the default (None) as unlimited.
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the base engine flags, then the async-only flags.

        Returns the same ``parser`` object after the flags are added.
        """
        parser = EngineArgs.add_cli_args(parser)

        # On/off switches, registered in this order.
        switches = (
            ('--engine-use-ray',
             'use Ray to start the LLM engine in a '
             'separate process as the server process.'),
            ('--disable-log-requests', 'disable logging requests'),
        )
        for flag, description in switches:
            parser.add_argument(flag, action='store_true', help=description)

        parser.add_argument(
            '--max-log-len',
            type=int,
            default=None,
            help=('max number of prompt characters or prompt '
                  'ID numbers being printed in log. '
                  'Default: unlimited.'))
        return parser