@@ -4,10 +4,13 @@
 import time
 import traceback
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import aiohttp
+import huggingface_hub.constants
 from tqdm.asyncio import tqdm
+from transformers import (AutoTokenizer, PreTrainedTokenizer,
+                          PreTrainedTokenizerFast)
 
 AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
 
@@ -388,6 +391,40 @@ def remove_prefix(text: str, prefix: str) -> str:
     return text
 
 
+def get_model(pretrained_model_name_or_path: str) -> str:
+    # Resolve a model name to a local snapshot path. Weight files
+    # (.pt/.safetensors/.bin) are excluded from the download because only
+    # the tokenizer and config files are needed here.
+    if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
+        from modelscope import snapshot_download
+        model_path = snapshot_download(
+            model_id=pretrained_model_name_or_path,
+            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+            ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
+    else:
+        from huggingface_hub import snapshot_download
+        # huggingface_hub expects repo_id/ignore_patterns rather than
+        # modelscope's model_id/ignore_file_pattern.
+        model_path = snapshot_download(
+            repo_id=pretrained_model_name_or_path,
+            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+            ignore_patterns=["*.pt", "*.safetensors", "*.bin"])
+    return model_path
+
+
+def get_tokenizer(
+    pretrained_model_name_or_path: str, trust_remote_code: bool
+) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    # If the argument is not an existing local path, download the model
+    # files first, then load the tokenizer from the resolved directory.
+    if pretrained_model_name_or_path is not None and not os.path.exists(
+            pretrained_model_name_or_path):
+        pretrained_model_name_or_path = get_model(
+            pretrained_model_name_or_path)
+    return AutoTokenizer.from_pretrained(pretrained_model_name_or_path,
+                                         trust_remote_code=trust_remote_code)
+
+
 ASYNC_REQUEST_FUNCS = {
     "tgi": async_request_tgi,
     "vllm": async_request_openai_completions,
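
For reference, a minimal sketch of how the new helpers might be called from a benchmark script; the model id and prompt below are illustrative only and are not part of the diff:

```python
# Hypothetical usage of get_tokenizer/get_model defined above.
# "facebook/opt-125m" is only an example model id.
tokenizer = get_tokenizer("facebook/opt-125m", trust_remote_code=False)

# Tokenize a prompt to compute its length, as the benchmark client would.
prompt = "Hello, vLLM!"
prompt_len = len(tokenizer(prompt).input_ids)
print(prompt_len)
```

With `VLLM_USE_MODELSCOPE=True` the snapshot is fetched from ModelScope; otherwise it comes from the Hugging Face Hub, honoring `HF_HUB_OFFLINE` in both cases.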