Skip to content

Commit

Permalink
tests: Update unit test and add docstring for tensorize_lora_adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
sangstar committed Dec 11, 2024
1 parent 480404e commit f7c0f8c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 34 deletions.
55 changes: 21 additions & 34 deletions tests/tensorizer_loader/test_tensorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@
from huggingface_hub import snapshot_download
from tensorizer import EncryptionParams

from vllm import SamplingParams
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.lora.request import LoRARequest
# yapf conflicts with isort for this docstring
# yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
TensorSerializer,
is_vllm_tensorized,
load_with_tensorizer,
open_stream,
serialize_vllm_model,
tensorize_vllm_model)
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, TensorSerializer, is_vllm_tensorized,
load_with_tensorizer, open_stream, serialize_vllm_model,
tensorize_lora_adapter, tensorize_vllm_model)
# yapf: enable
from vllm.utils import import_from_path

Expand Down Expand Up @@ -347,32 +345,23 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):


def test_serialize_and_deserialize_lora(tmp_path):
import shutil

from safetensors.torch import load_file

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

sql_lora_files = snapshot_download(
repo_id="yard1/llama-2-7b-sql-lora-test")
tensor_path = os.path.join(sql_lora_files, "adapter_model.safetensors")
config_path = os.path.join(sql_lora_files, "adapter_config.json")
tensors = load_file(tensor_path)

# TODO: This will not work for non-local saving. Use `open_stream`
# to save this json. Pretty sure there are examples of this
# in the tensorizer repo
shutil.copy(config_path, tmp_path / "adapter_config.json")
model_ref = "meta-llama/Llama-2-7b-hf"
lora_path = "yard1/llama-2-7b-sql-lora-test"
model_uri = tmp_path / (model_ref + ".tensors")
tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri))
args = EngineArgs(model=model_ref)

tensorizer_uri = tmp_path / "adapter_model.tensors"
serializer = TensorSerializer(tensorizer_uri)
serializer.write_state_dict(tensors)
serializer.close()
tensorize_lora_adapter(lora_path, tensorizer_config)
tensorize_vllm_model(args, tensorizer_config)

# Now, load it
gc.collect()
torch.cuda.empty_cache()

loaded_vllm_model = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)
loaded_vllm_model = LLM(model="meta-llama/Llama-2-7b-hf",
load_format="tensorizer",
model_loader_extra_config=tensorizer_config,
enable_lora=True)
sampling_params = SamplingParams(temperature=0,
max_tokens=256,
stop=["[/assistant]"])
Expand All @@ -382,13 +371,11 @@ def test_serialize_and_deserialize_lora(tmp_path):
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
]

lora_path = tmp_path
lora_tensorizer_uri = str(tmp_path) + "/model.tensors"
lora_path = tensorizer_config.tensorizer_dir
loaded_vllm_model.generate(prompts,
sampling_params,
lora_request=LoRARequest(
"sql-lora",
1,
lora_path,
tensorizer_config=TensorizerConfig(
tensorizer_uri=lora_tensorizer_uri, )))
tensorizer_config=tensorizer_config))
8 changes: 8 additions & 0 deletions vllm/model_executor/model_loader/tensorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,14 @@ def tensorize_vllm_model(engine_args: EngineArgs,

def tensorize_lora_adapter(lora_path: str,
tensorizer_config: TensorizerConfig):
"""
Uses tensorizer to serialize a LoRA adapter. Assumes that the files
needed to load the LoRA adapter are a safetensors-format file called
adapter_model.safetensors and a JSON config file called
adapter_config.json. The serialized files are written to the same
directory as the model tensors located at
tensorizer_config.tensorizer_uri.
"""
lora_files = snapshot_download(repo_id=lora_path)

# Current LoRA loading logic in
Expand Down

0 comments on commit f7c0f8c

Please sign in to comment.