From 02a22e90e3e33514ab2d68b8617a668bb68fbf96 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Fri, 4 Oct 2024 10:38:25 -0700 Subject: [PATCH] [Core][VLM] Test registration for OOT multimodal models (#8717) Co-authored-by: DarkLight1337 --- docs/source/models/adding_model.rst | 18 +++++- find_cuda_init.py | 33 ++++++++++ tests/conftest.py | 30 ++++++++-- tests/entrypoints/openai/test_audio.py | 4 +- tests/entrypoints/openai/test_vision.py | 13 +++- tests/models/test_oot_registration.py | 38 ++++++++++++ .../vllm_add_dummy_model/__init__.py | 28 +++------ .../vllm_add_dummy_model/my_llava.py | 28 +++++++++ .../vllm_add_dummy_model/my_opt.py | 19 ++++++ vllm/engine/arg_utils.py | 2 + vllm/engine/llm_engine.py | 3 - vllm/model_executor/models/registry.py | 60 ++++++++++++++----- 12 files changed, 227 insertions(+), 49 deletions(-) create mode 100644 find_cuda_init.py create mode 100644 tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py create mode 100644 tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index 1f220b723cacd..fa1003874033e 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -85,16 +85,16 @@ When it comes to the linear layers, we provide the following options to parallel * :code:`ReplicatedLinear`: Replicates the inputs and weights across multiple GPUs. No memory saving. * :code:`RowParallelLinear`: The input tensor is partitioned along the hidden dimension. The weight matrix is partitioned along the rows (input dimension). An *all-reduce* operation is performed after the matrix multiplication to reduce the results. Typically used for the second FFN layer and the output linear transformation of the attention layer. * :code:`ColumnParallelLinear`: The input tensor is replicated. The weight matrix is partitioned along the columns (output dimension). The result is partitioned along the column dimension. Typically used for the first FFN layer and the separated QKV transformation of the attention layer in the original Transformer. -* :code:`MergedColumnParallelLinear`: Column-parallel linear that merges multiple `ColumnParallelLinear` operators. Typically used for the first FFN layer with weighted activation functions (e.g., SiLU). This class handles the sharded weight loading logic of multiple weight matrices. +* :code:`MergedColumnParallelLinear`: Column-parallel linear that merges multiple :code:`ColumnParallelLinear` operators. Typically used for the first FFN layer with weighted activation functions (e.g., SiLU). This class handles the sharded weight loading logic of multiple weight matrices. * :code:`QKVParallelLinear`: Parallel linear layer for the query, key, and value projections of the multi-head and grouped-query attention mechanisms. When number of key/value heads are less than the world size, this class replicates the key/value heads properly. This class handles the weight loading and replication of the weight matrices. -Note that all the linear layers above take `linear_method` as an input. vLLM will set this parameter according to different quantization schemes to support weight quantization. +Note that all the linear layers above take :code:`linear_method` as an input. vLLM will set this parameter according to different quantization schemes to support weight quantization. 4. Implement the weight loading logic ------------------------------------- You now need to implement the :code:`load_weights` method in your :code:`*ForCausalLM` class. -This method should load the weights from the HuggingFace's checkpoint file and assign them to the corresponding layers in your model. Specifically, for `MergedColumnParallelLinear` and `QKVParallelLinear` layers, if the original model has separated weight matrices, you need to load the different parts separately. +This method should load the weights from the HuggingFace's checkpoint file and assign them to the corresponding layers in your model. Specifically, for :code:`MergedColumnParallelLinear` and :code:`QKVParallelLinear` layers, if the original model has separated weight matrices, you need to load the different parts separately. 5. Register your model ---------------------- @@ -114,6 +114,18 @@ Just add the following lines in your code: from your_code import YourModelForCausalLM ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM) +If your model imports modules that initialize CUDA, consider instead lazy-importing it to avoid an error like :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`: + +.. code-block:: python + + from vllm import ModelRegistry + + ModelRegistry.register_model("YourModelForCausalLM", "your_code:YourModelForCausalLM") + +.. important:: + If your model is a multimodal model, make sure the model class implements the :class:`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface. + Read more about that :ref:`here `. + If you are running api server with :code:`vllm serve `, you can wrap the entrypoint with the following code: .. code-block:: python diff --git a/find_cuda_init.py b/find_cuda_init.py new file mode 100644 index 0000000000000..51db23102f9ac --- /dev/null +++ b/find_cuda_init.py @@ -0,0 +1,33 @@ +import importlib +import traceback +from typing import Callable +from unittest.mock import patch + + +def find_cuda_init(fn: Callable[[], object]) -> None: + """ + Helper function to debug CUDA re-initialization errors. + + If `fn` initializes CUDA, prints the stack trace of how this happens. + """ + from torch.cuda import _lazy_init + + stack = None + + def wrapper(): + nonlocal stack + stack = traceback.extract_stack() + return _lazy_init() + + with patch("torch.cuda._lazy_init", wrapper): + fn() + + if stack is not None: + print("==== CUDA Initialized ====") + print("".join(traceback.format_list(stack)).strip()) + print("==========================") + + +if __name__ == "__main__": + find_cuda_init( + lambda: importlib.import_module("vllm.model_executor.models.llava")) diff --git a/tests/conftest.py b/tests/conftest.py index 45dc5e8323ca4..b1833fdae5347 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -879,15 +879,16 @@ def num_gpus_available(): temp_dir = tempfile.gettempdir() -_dummy_path = os.path.join(temp_dir, "dummy_opt") +_dummy_opt_path = os.path.join(temp_dir, "dummy_opt") +_dummy_llava_path = os.path.join(temp_dir, "dummy_llava") @pytest.fixture def dummy_opt_path(): - json_path = os.path.join(_dummy_path, "config.json") - if not os.path.exists(_dummy_path): + json_path = os.path.join(_dummy_opt_path, "config.json") + if not os.path.exists(_dummy_opt_path): snapshot_download(repo_id="facebook/opt-125m", - local_dir=_dummy_path, + local_dir=_dummy_opt_path, ignore_patterns=[ "*.bin", "*.bin.index.json", "*.pt", "*.h5", "*.msgpack" @@ -898,4 +899,23 @@ def dummy_opt_path(): config["architectures"] = ["MyOPTForCausalLM"] with open(json_path, "w") as f: json.dump(config, f) - return _dummy_path + return _dummy_opt_path + + +@pytest.fixture +def dummy_llava_path(): + json_path = os.path.join(_dummy_llava_path, "config.json") + if not os.path.exists(_dummy_llava_path): + snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf", + local_dir=_dummy_llava_path, + ignore_patterns=[ + "*.bin", "*.bin.index.json", "*.pt", "*.h5", + "*.msgpack" + ]) + assert os.path.exists(json_path) + with open(json_path, "r") as f: + config = json.load(f) + config["architectures"] = ["MyLlava"] + with open(json_path, "w") as f: + json.dump(config, f) + return _dummy_llava_path diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index a9a0ac012c8ff..df8a140283fbb 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -21,7 +21,9 @@ def server(): "--dtype", "bfloat16", "--max-model-len", - "4096", + "2048", + "--max-num-seqs", + "5", "--enforce-eager", ] diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index f61fa127b7d06..81d79601124a7 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -23,9 +23,16 @@ @pytest.fixture(scope="module") def server(): args = [ - "--dtype", "bfloat16", "--max-model-len", "4096", "--max-num-seqs", - "5", "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", - f"image={MAXIMUM_IMAGES}" + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "5", + "--enforce-eager", + "--trust-remote-code", + "--limit-mm-per-prompt", + f"image={MAXIMUM_IMAGES}", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index 5cb82a5ac4c7d..ee3f8911f318c 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -3,6 +3,7 @@ import pytest from vllm import LLM, SamplingParams +from vllm.assets.image import ImageAsset from ..utils import fork_new_process_for_each_test @@ -29,3 +30,40 @@ def test_oot_registration(dummy_opt_path): # make sure only the first token is generated rest = generated_text.replace(first_token, "") assert rest == "" + + +image = ImageAsset("cherry_blossom").pil_image.convert("RGB") + + +@fork_new_process_for_each_test +def test_oot_multimodal_registration(dummy_llava_path): + os.environ["VLLM_PLUGINS"] = "register_dummy_model" + prompts = [{ + "prompt": "What's in the image?", + "multi_modal_data": { + "image": image + }, + }, { + "prompt": "Describe the image", + "multi_modal_data": { + "image": image + }, + }] + + sampling_params = SamplingParams(temperature=0) + llm = LLM(model=dummy_llava_path, + load_format="dummy", + max_num_seqs=1, + trust_remote_code=True, + gpu_memory_utilization=0.98, + max_model_len=4096, + enforce_eager=True, + limit_mm_per_prompt={"image": 1}) + first_token = llm.get_tokenizer().decode(0) + outputs = llm.generate(prompts, sampling_params) + + for output in outputs: + generated_text = output.outputs[0].text + # make sure only the first token is generated + rest = generated_text.replace(first_token, "") + assert rest == "" diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py index dcc0305e657ab..022ba66e38cc3 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py @@ -1,26 +1,14 @@ -from typing import Optional - -import torch - from vllm import ModelRegistry -from vllm.model_executor.models.opt import OPTForCausalLM -from vllm.model_executor.sampling_metadata import SamplingMetadata - - -class MyOPTForCausalLM(OPTForCausalLM): - - def compute_logits( - self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: - # this dummy model always predicts the first token - logits = super().compute_logits(hidden_states, sampling_metadata) - if logits is not None: - logits.zero_() - logits[:, 0] += 1.0 - return logits def register(): - # register our dummy model + # Test directly passing the model + from .my_opt import MyOPTForCausalLM + if "MyOPTForCausalLM" not in ModelRegistry.get_supported_archs(): ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM) + + # Test passing lazy model + if "MyLlava" not in ModelRegistry.get_supported_archs(): + ModelRegistry.register_model("MyLlava", + "vllm_add_dummy_model.my_llava:MyLlava") diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py new file mode 100644 index 0000000000000..3ebd7864b8fc8 --- /dev/null +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -0,0 +1,28 @@ +from typing import Optional + +import torch + +from vllm.inputs import INPUT_REGISTRY +from vllm.model_executor.models.llava import (LlavaForConditionalGeneration, + dummy_data_for_llava, + get_max_llava_image_tokens, + input_processor_for_llava) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY + + +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) +@INPUT_REGISTRY.register_input_processor(input_processor_for_llava) +class MyLlava(LlavaForConditionalGeneration): + + def compute_logits( + self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + # this dummy model always predicts the first token + logits = super().compute_logits(hidden_states, sampling_metadata) + if logits is not None: + logits.zero_() + logits[:, 0] += 1.0 + return logits diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py new file mode 100644 index 0000000000000..569ef216c9f0a --- /dev/null +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py @@ -0,0 +1,19 @@ +from typing import Optional + +import torch + +from vllm.model_executor.models.opt import OPTForCausalLM +from vllm.model_executor.sampling_metadata import SamplingMetadata + + +class MyOPTForCausalLM(OPTForCausalLM): + + def compute_logits( + self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + # this dummy model always predicts the first token + logits = super().compute_logits(hidden_states, sampling_metadata) + if logits is not None: + logits.zero_() + logits[:, 0] += 1.0 + return logits diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3f0a8d3df8b32..cae95d20ca23d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -183,6 +183,8 @@ class EngineArgs: def __post_init__(self): if self.tokenizer is None: self.tokenizer = self.model + from vllm.plugins import load_general_plugins + load_general_plugins() @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d6258c6413d87..adf5d0df72887 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -290,9 +290,6 @@ def __init__( model_config.mm_processor_kwargs, ) # TODO(woosuk): Print more configs in debug mode. - from vllm.plugins import load_general_plugins - load_general_plugins() - self.model_config = model_config self.cache_config = cache_config self.lora_config = lora_config diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index aa5736e7cd517..a72b9e8909db2 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -125,9 +125,10 @@ **_CONDITIONAL_GENERATION_MODELS, } -# Architecture -> type. +# Architecture -> type or (module, class). # out of tree models _OOT_MODELS: Dict[str, Type[nn.Module]] = {} +_OOT_MODELS_LAZY: Dict[str, Tuple[str, str]] = {} # Models not supported by ROCm. _ROCM_UNSUPPORTED_MODELS: List[str] = [] @@ -159,17 +160,24 @@ class ModelRegistry: @staticmethod def _get_module_cls_name(model_arch: str) -> Tuple[str, str]: - module_relname, cls_name = _MODELS[model_arch] - return f"vllm.model_executor.models.{module_relname}", cls_name + if model_arch in _MODELS: + module_relname, cls_name = _MODELS[model_arch] + return f"vllm.model_executor.models.{module_relname}", cls_name + + if model_arch in _OOT_MODELS_LAZY: + return _OOT_MODELS_LAZY[model_arch] + + raise KeyError(model_arch) @staticmethod @lru_cache(maxsize=128) def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]: - if model_arch not in _MODELS: + try: + mod_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) + except KeyError: return None - module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) - module = importlib.import_module(module_name) + module = importlib.import_module(mod_name) return getattr(module, cls_name, None) @staticmethod @@ -219,14 +227,35 @@ def get_supported_archs() -> List[str]: return list(_MODELS.keys()) + list(_OOT_MODELS.keys()) @staticmethod - def register_model(model_arch: str, model_cls: Type[nn.Module]): + def register_model(model_arch: str, model_cls: Union[Type[nn.Module], + str]): + """ + Register an external model to be used in vLLM. + + :code:`model_cls` can be either: + + - A :class:`torch.nn.Module` class directly referencing the model. + - A string in the format :code:`:` which can be used to + lazily import the model. This is useful to avoid initializing CUDA + when importing the model and thus the related error + :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`. + """ if model_arch in _MODELS: logger.warning( "Model architecture %s is already registered, and will be " "overwritten by the new model class %s.", model_arch, - model_cls.__name__) + model_cls) + + if isinstance(model_cls, str): + split_str = model_cls.split(":") + if len(split_str) != 2: + msg = "Expected a string in the format `:`" + raise ValueError(msg) - _OOT_MODELS[model_arch] = model_cls + module_name, cls_name = split_str + _OOT_MODELS_LAZY[model_arch] = module_name, cls_name + else: + _OOT_MODELS[model_arch] = model_cls @staticmethod @lru_cache(maxsize=128) @@ -248,13 +277,16 @@ def _check_stateless( if model is not None: return func(model) - if model_arch not in _MODELS and default is not None: - return default + try: + mod_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) + except KeyError: + if default is not None: + return default - module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) + raise valid_name_characters = string.ascii_letters + string.digits + "._" - if any(s not in valid_name_characters for s in module_name): + if any(s not in valid_name_characters for s in mod_name): raise ValueError(f"Unsafe module name detected for {model_arch}") if any(s not in valid_name_characters for s in cls_name): raise ValueError(f"Unsafe class name detected for {model_arch}") @@ -266,7 +298,7 @@ def _check_stateless( err_id = uuid.uuid4() stmts = ";".join([ - f"from {module_name} import {cls_name}", + f"from {mod_name} import {cls_name}", f"from {func.__module__} import {func.__name__}", f"assert {func.__name__}({cls_name}), '{err_id}'", ])