Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support serializing and deserializing LoRA adapters #3

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ steps:
- python3 llm_engine_example.py
- python3 offline_inference_vision_language.py
- python3 offline_inference_vision_language_multi_image.py
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 tensorize_vllm_model.py --model meta-llama/Llama-2-7b-hf --lora-path yard1/llama-2-7b-sql-lora-test serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model meta-llama/Llama-2-7b-hf --lora-path yard1/llama-2-7b-sql-lora-test deserialize --path-to-tensors /tmp/vllm/meta-llama/Llama-2-7b-hf/v1/model.tensors
- python3 offline_inference_encoder_decoder.py
- python3 offline_profile.py --model facebook/opt-125m

Expand Down
67 changes: 59 additions & 8 deletions examples/tensorize_vllm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

from vllm import LLM
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs,
TensorizerConfig,
tensorize_vllm_model)
from vllm.model_executor.model_loader.tensorizer import (
TensorizerArgs, TensorizerConfig, tensorize_lora_adapter,
tensorize_vllm_model)
from vllm.utils import FlexibleArgumentParser

# yapf conflicts with isort for this docstring
Expand Down Expand Up @@ -105,6 +105,17 @@ def parse_args():
"also supported, although libsodium must be installed to "
"use it.")
parser = EngineArgs.add_cli_args(parser)

parser.add_argument(
"--lora-path",
type=str,
required=False,
help="Path to a LoRA adapter to "
"serialize along with model tensors. This can then be deserialized "
"along with the model by passing a tensorizer_config kwarg to "
"LoRARequest with type TensorizerConfig."
)

subparsers = parser.add_subparsers(dest='command')

serialize_parser = subparsers.add_parser(
Expand Down Expand Up @@ -167,11 +178,44 @@ def parse_args():


def deserialize():
    """Deserialize a tensorized vLLM model from ``tensorizer_config``.

    If ``--lora-path`` was given, the engine is started with LoRA support and
    a short generation is run with a LoRARequest pointing at the tensorized
    adapter, confirming the adapter round-trips through tensorizer.

    Returns:
        The constructed ``LLM`` instance.
    """
    # NOTE(review): reads module-level `args`, `tensorizer_config` and
    # `lora_path` set up by the surrounding script — not visible in this hunk.
    if args.lora_path:
        from vllm import SamplingParams
        from vllm.lora.request import LoRARequest

        # enable_lora is required so the engine can attach the adapter.
        llm = LLM(model=args.model,
                  load_format="tensorizer",
                  tensor_parallel_size=args.tensor_parallel_size,
                  model_loader_extra_config=tensorizer_config,
                  enable_lora=True)
        sampling_params = SamplingParams(temperature=0,
                                         max_tokens=256,
                                         stop=["[/assistant]"])

        # Truncating this as the extra text isn't necessary
        prompts = [
            "[user] Write a SQL query to answer the question based on ..."
        ]

        # Test LoRA load
        print(
            llm.generate(prompts,
                         sampling_params,
                         lora_request=LoRARequest(
                             "sql-lora",
                             1,
                             lora_path,
                             tensorizer_config=tensorizer_config)))
    else:
        llm = LLM(model=args.model,
                  load_format="tensorizer",
                  tensor_parallel_size=args.tensor_parallel_size,
                  model_loader_extra_config=tensorizer_config)
    return llm


Expand All @@ -185,6 +229,8 @@ def deserialize():
s3_endpoint = (getattr(args, 's3_endpoint', None)
or os.environ.get("S3_ENDPOINT_URL", None))

lora_path = args.lora_path

credentials = {
"s3_access_key_id": s3_access_key_id,
"s3_secret_access_key": s3_secret_access_key,
Expand Down Expand Up @@ -221,11 +267,16 @@ def deserialize():
else:
model_path = f"{base_path}/model.tensors"

os.makedirs(base_path, exist_ok=True)

tensorizer_config = TensorizerConfig(
tensorizer_uri=model_path,
encryption_keyfile=keyfile,
**credentials)

if lora_path:
tensorize_lora_adapter(lora_path, tensorizer_config)

tensorize_vllm_model(engine_args, tensorizer_config)

elif args.command == "deserialize":
Expand Down
113 changes: 113 additions & 0 deletions tests/entrypoints/openai/test_tensorizer_entrypoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import json
import tempfile
import weakref

import openai
import pytest
import pytest_asyncio
import torch.cuda

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.lora.request import LoRARequest
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, tensorize_lora_adapter, tensorize_vllm_model)

from ...utils import RemoteOpenAIServer

MODEL_NAME = "meta-llama/Llama-2-7b-hf"
LORA_PATH = "yard1/llama-2-7b-sql-lora-test"


@pytest.fixture(scope="module")
def tmp_dir():
    """Module-scoped temporary directory for serialized model artifacts.

    Yields the ``TemporaryDirectory`` object; the directory is removed once
    all tests in the module have run, even if one of them fails.
    """
    tmp_dir = tempfile.TemporaryDirectory()
    try:
        yield tmp_dir
    finally:
        # The original used weakref.finalize(tmp_dir, cleanup) with a
        # callback that closed over `tmp_dir`; that strong reference kept the
        # object alive, so the finalizer could only fire at interpreter
        # shutdown — after the explicit cleanup() had already run. A plain
        # try/finally performs exactly one guaranteed cleanup.
        tmp_dir.cleanup()


@pytest.fixture(scope="module")
def tensorize_model_and_lora(tmp_dir):
    """Serialize the LoRA adapter and base model into the shared tmp dir."""
    serialization_config = TensorizerConfig(
        tensorizer_uri=f"{tmp_dir.name}/model.tensors")
    engine_args = EngineArgs(model=MODEL_NAME)

    # Adapter first, then the model weights themselves.
    tensorize_lora_adapter(LORA_PATH, serialization_config)
    tensorize_vllm_model(engine_args, serialization_config)

    torch.cuda.empty_cache()


@pytest.fixture(scope="module")
def server(tmp_dir, tensorize_model_and_lora):
    """Launch an OpenAI-compatible server backed by the tensorized model."""
    extra_config = {
        "tensorizer_uri": f"{tmp_dir.name}/model.tensors",
    }

    ## Start OpenAI API server
    cli_args = [
        "--load-format",
        "tensorizer",
        "--model-loader-extra-config",
        json.dumps(extra_config),
        "--enable-lora",
    ]

    with RemoteOpenAIServer(MODEL_NAME, cli_args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    """Async OpenAI client bound to the module's test server."""
    async with server.get_async_client() as api_client:
        yield api_client


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
    """Smoke-test that the tensorizer-loaded model serves a basic completion."""
    completion = await client.completions.create(model=model_name,
                                                 prompt="Hello, my name is",
                                                 max_tokens=5,
                                                 temperature=0.0)

    assert completion.id is not None
    # Exactly one choice comes back (duplicate len() assertion removed).
    assert completion.choices is not None and len(completion.choices) == 1
    assert completion.model == MODEL_NAME
    assert len(completion.choices[0].text) >= 5
    assert completion.choices[0].finish_reason == "length"
    assert completion.usage == openai.types.CompletionUsage(
        completion_tokens=5, prompt_tokens=6, total_tokens=11)


def test_confirm_deserialize_and_serve(tmp_dir, tensorize_model_and_lora):
    """Deserialize the tensorized model in-process and generate with LoRA."""
    model_uri = f"{tmp_dir.name}/model.tensors"
    adapter_uri = f"{tmp_dir.name}/adapter_model.tensors"

    llm = LLM(
        MODEL_NAME,
        load_format="tensorizer",
        model_loader_extra_config=TensorizerConfig(tensorizer_uri=model_uri),
        enable_lora=True)

    sampling_params = SamplingParams(temperature=0,
                                     max_tokens=256,
                                     stop=["[/assistant]"])

    prompts = [
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",  # noqa: E501
    ]

    lora_request = LoRARequest(
        "sql-lora",
        1,
        tmp_dir.name,
        tensorizer_config=TensorizerConfig(tensorizer_uri=adapter_uri))
    llm.generate(prompts, sampling_params, lora_request=lora_request)
51 changes: 43 additions & 8 deletions tests/tensorizer_loader/test_tensorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@
from huggingface_hub import snapshot_download
from tensorizer import EncryptionParams

from vllm import SamplingParams
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.lora.request import LoRARequest
# yapf conflicts with isort for this docstring
# yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
TensorSerializer,
is_vllm_tensorized,
load_with_tensorizer,
open_stream,
serialize_vllm_model,
tensorize_vllm_model)
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, TensorSerializer, is_vllm_tensorized,
load_with_tensorizer, open_stream, serialize_vllm_model,
tensorize_lora_adapter, tensorize_vllm_model)
# yapf: enable
from vllm.utils import import_from_path

Expand Down Expand Up @@ -344,3 +342,40 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
# noqa: E501

assert outputs == deserialized_outputs


def test_serialize_and_deserialize_lora(tmp_path):
    """Round-trip a LoRA adapter and base model through tensorizer.

    Serializes both to ``tmp_path``, reloads them with
    ``load_format="tensorizer"``, and generates with a LoRARequest that reads
    the adapter from its tensorized form.
    """
    model_ref = "meta-llama/Llama-2-7b-hf"
    # HF repo of the adapter to serialize; distinct name from the local
    # serialized directory below (the original rebound `lora_path`, shadowing
    # the repo id with the on-disk path).
    lora_repo = "yard1/llama-2-7b-sql-lora-test"
    model_uri = tmp_path / (model_ref + ".tensors")
    tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri))
    args = EngineArgs(model=model_ref)

    tensorize_lora_adapter(lora_repo, tensorizer_config)
    tensorize_vllm_model(args, tensorizer_config)

    gc.collect()
    torch.cuda.empty_cache()

    # Reuse model_ref instead of repeating the literal model name.
    loaded_vllm_model = LLM(model=model_ref,
                            load_format="tensorizer",
                            model_loader_extra_config=tensorizer_config,
                            enable_lora=True)
    sampling_params = SamplingParams(temperature=0,
                                     max_tokens=256,
                                     stop=["[/assistant]"])

    prompts = [
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",  # noqa: E501
    ]

    serialized_lora_dir = tensorizer_config.tensorizer_dir
    outputs = loaded_vllm_model.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest("sql-lora",
                                 1,
                                 serialized_lora_dir,
                                 tensorizer_config=tensorizer_config))
    # The original made no assertion at all; at minimum, one output per prompt.
    assert len(outputs) == len(prompts)
Loading
Loading