From e338df1ae6f4e7438c05e7d8d03c3e649d21fdb7 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 20 Sep 2024 15:20:20 +0800
Subject: [PATCH 1/8] add checker for components weight loading

---
 vllm/model_executor/models/llava.py |  5 +++--
 vllm/model_executor/models/utils.py | 14 ++++++++++++++
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 7a6c991fb133a..14b6fe52e8053 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -26,7 +26,7 @@
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
-from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
+from .utils import (check_filter_available, filter_weights, flatten_bn, init_vllm_registered_model,
                     merge_multimodal_embeddings)
 
 
@@ -393,7 +393,8 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+        vit_weights, mlp_weights, llm_weights, tracker = itertools.tee(weights, 4)
+        check_filter_available(tracker, ["vision_tower", "multi_modal_projector", "language_model"])
 
         # load vision encoder
         vit_weights = filter_weights(vit_weights, "vision_tower")
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 8b80dda96db49..a60d9ae6a82e7 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -30,6 +30,20 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
             yield name, loaded_weight
 
 
+def check_filter_available(weights: Iterable[Tuple[str, torch.Tensor]], prefix: List[str]):
+    """
+    Helper function to check if the filter is available for the loaded weights
+    """
+    unexpected_name = set()
+    for name, _ in weights:
+        weight_prefix = name.split(".")[0]
+        if weight_prefix not in prefix:
+            unexpected_name.add(name)
+
+    if unexpected_name:
+        raise ValueError(f"Loaded weights contents unexpected weight: {unexpected_name}")
+
+
 def init_vllm_registered_model(
     hf_config: PretrainedConfig,
     cache_config: Optional[CacheConfig],

From f54562c4ae13a3face5a12cedafe4a8995f79d3b Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Sat, 21 Sep 2024 13:41:01 +0800
Subject: [PATCH 2/8] refactor component weight loading

---
 vllm/model_executor/models/llava.py | 17 ++++++-----------
 vllm/model_executor/models/utils.py | 29 ++++++++++++++++++-----------
 2 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 14b6fe52e8053..69eb177a7dea8 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -1,4 +1,3 @@
-import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
@@ -26,8 +25,8 @@
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
-from .utils import (check_filter_available, filter_weights, flatten_bn, init_vllm_registered_model,
-                    merge_multimodal_embeddings)
+from .utils import (flatten_bn, group_weights_with_prefix,
+                    init_vllm_registered_model, merge_multimodal_embeddings)
 
 
 class LlavaImagePixelInputs(TypedDict):
@@ -393,22 +392,18 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, llm_weights, tracker = itertools.tee(weights, 4)
-        check_filter_available(tracker, ["vision_tower", "multi_modal_projector", "language_model"])
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index a60d9ae6a82e7..4ddd206d8247f 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -1,3 +1,4 @@
+import itertools
 from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
                     Union, overload)
 
@@ -16,7 +17,8 @@
 from vllm.utils import is_pin_memory_available
 
 
-def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
+def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
+                   prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
     """
     Helper function to load weights for inner vLLM models.
 
@@ -30,18 +32,23 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
             yield name, loaded_weight
 
 
-def check_filter_available(weights: Iterable[Tuple[str, torch.Tensor]], prefix: List[str]):
+def group_weights_with_prefix(
+    weights: Iterable[Tuple[str, torch.Tensor]]
+) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]:
     """
-    Helper function to check if the filter is available for the loaded weights
+    Helper function to group weights with prefix
     """
-    unexpected_name = set()
-    for name, _ in weights:
-        weight_prefix = name.split(".")[0]
-        if weight_prefix not in prefix:
-            unexpected_name.add(name)
-
-    if unexpected_name:
-        raise ValueError(f"Loaded weights contents unexpected weight: {unexpected_name}")
+    init_weights, repeated_weights = itertools.tee(weights, 2)
+
+    prefix = set(map(lambda x: x[0].split(".")[0], init_weights))
+
+    repeated_weights = itertools.tee(repeated_weights, len(prefix))
+
+    grouped_weights = {}
+    for weight_component, weight_prefix in zip(repeated_weights, prefix):
+        grouped_weights[weight_prefix] = filter_weights(
+            weight_component, weight_prefix)
+    return grouped_weights
 
 
 def init_vllm_registered_model(

From 65d1cedaef1adf7235dfa600ab74f5daa85b3081 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Sat, 21 Sep 2024 13:46:16 +0800
Subject: [PATCH 3/8] modify internvl

---
 vllm/model_executor/models/internvl.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index 507d7014714a2..005a24f10aa17 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -4,7 +4,6 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-import itertools
 import re
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
@@ -33,8 +32,8 @@
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
 from .interfaces import SupportsMultiModal
-from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
-                    merge_multimodal_embeddings)
+from .utils import (flatten_bn, group_weights_with_prefix,
+                    init_vllm_registered_model, merge_multimodal_embeddings)
 
 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -518,21 +517,18 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_model")
-        self.vision_model.load_weights(vit_weights)
+        self.vision_model.load_weights(weights_group["vision_model"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "mlp1")
         mlp_params_dict = dict(self.mlp1.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["mlp1"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])

From c56ff8f4d20d185da49bfb47b657cc1dba98588f Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Sat, 21 Sep 2024 13:58:20 +0800
Subject: [PATCH 4/8] modify llava-next

---
 vllm/model_executor/models/llava_next.py      | 20 +++++++------------
 .../model_executor/models/llava_next_video.py | 17 ++++++----------
 2 files changed, 13 insertions(+), 24 deletions(-)

diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index c6bd46dd7eda9..bbc43fda167c5 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -1,4 +1,3 @@
-import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
@@ -30,8 +29,8 @@
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
-                    merge_multimodal_embeddings)
+from .utils import (flatten_bn, group_weights_with_prefix,
+                    init_vllm_registered_model, merge_multimodal_embeddings)
 
 
 logger = init_logger(__name__)
@@ -635,25 +634,21 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
-            weights, 4)
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load newline
-        newline_weights = filter_weights(newline_weights, "image_newline")
-        for name, loaded_weight in newline_weights:
+        for name, loaded_weight in weights_group["image_newline"]:
             assert name == ""
             param = self.image_newline
             weight_loader = getattr(param, "weight_loader",
@@ -661,5 +656,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py
index 7fe85e5e4ab3d..a8b5176dc43cf 100644
--- a/vllm/model_executor/models/llava_next_video.py
+++ b/vllm/model_executor/models/llava_next_video.py
@@ -1,4 +1,3 @@
-import itertools
 import math
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
@@ -30,7 +29,7 @@
 from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip)
-from .utils import (filter_weights, init_vllm_registered_model,
+from .utils import (group_weights_with_prefix, init_vllm_registered_model,
                     merge_multimodal_embeddings)
 
 logger = init_logger(__name__)
@@ -449,23 +448,19 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators
-        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
-            weights, 4)
+        # prepare weight iterators for components
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])

From 25a57b1bf52c7db39c38509d97ca2ea2e25b34d5 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Sat, 21 Sep 2024 14:13:56 +0800
Subject: [PATCH 5/8] clean up group_weights_with_prefix

---
 vllm/model_executor/models/utils.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 4ddd206d8247f..355ef18836005 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -39,15 +39,12 @@ def group_weights_with_prefix(
     Helper function to group weights with prefix
     """
     init_weights, repeated_weights = itertools.tee(weights, 2)
-
-    prefix = set(map(lambda x: x[0].split(".")[0], init_weights))
-
-    repeated_weights = itertools.tee(repeated_weights, len(prefix))
+    weights_prefix = set(map(lambda x: x[0].split(".")[0], init_weights))
+    repeated_weights = itertools.tee(repeated_weights, len(weights_prefix))
 
     grouped_weights = {}
-    for weight_component, weight_prefix in zip(repeated_weights, prefix):
-        grouped_weights[weight_prefix] = filter_weights(
-            weight_component, weight_prefix)
+    for component, prefix in zip(repeated_weights, weights_prefix):
+        grouped_weights[prefix] = filter_weights(component, prefix)
     return grouped_weights
 
 

From 019bb4953340f0f303e16a5951e910ae7941798e Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Sat, 21 Sep 2024 14:17:46 +0800
Subject: [PATCH 6/8] modify paligemma

---
 vllm/model_executor/models/paligemma.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
index 5fd39b5e35be6..68b6d0cf808e1 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -1,4 +1,3 @@
-import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
@@ -23,7 +22,7 @@
 from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
-from .utils import filter_weights, merge_multimodal_embeddings
+from .utils import group_weights_with_prefix, merge_multimodal_embeddings
 
 logger = init_logger(__name__)
 
@@ -286,21 +285,18 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision tower
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])

From 8371dec4d4eb4a57fa53d35629866118a1857215 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Sat, 21 Sep 2024 18:49:08 +0800
Subject: [PATCH 7/8] modify ultravox

---
 vllm/model_executor/models/ultravox.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index 416fabda831a2..b60e9a03683e9 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -1,7 +1,6 @@
 # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
 """PyTorch Ultravox model."""
 
-import itertools
 import math
 from array import array
 from functools import lru_cache
@@ -29,7 +28,8 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import SupportsMultiModal
-from vllm.model_executor.models.utils import (filter_weights, flatten_bn,
+from vllm.model_executor.models.utils import (flatten_bn,
+                                              group_weights_with_prefix,
                                               init_vllm_registered_model,
                                               merge_multimodal_embeddings)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -453,11 +453,10 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        projector_weights, llm_weights = itertools.tee(weights, 2)
+        weights_group = group_weights_with_prefix(weights)
 
         # load projector weights
-        projector_weights = filter_weights(projector_weights,
-                                           "multi_modal_projector")
+        projector_weights = weights_group["multi_modal_projector"]
         projector_params_dict = dict(
             self.multi_modal_projector.named_parameters())
         for name, loaded_weight in projector_weights:
@@ -467,5 +466,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])

From 036c9252734861156d16fb1a6ceb823fd165eb03 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Sat, 21 Sep 2024 21:16:54 +0800
Subject: [PATCH 8/8] add WeightsGroup and clean up

---
 vllm/model_executor/models/utils.py | 26 +++++++++++++++++++++-----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 355ef18836005..38d6a4653ebd6 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -1,4 +1,5 @@
 import itertools
+from collections import UserDict
 from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
                     Union, overload)
 
@@ -17,6 +18,21 @@
 from vllm.utils import is_pin_memory_available
 
 
+class WeightsGroup(UserDict):
+    """
+    Wraps grouped weights dictionary for a more informative error message
+    when attempting to access a weight component that does not exist.
+    """
+
+    def __getitem__(self, key: str) -> int:
+        try:
+            return super().__getitem__(key)
+        except KeyError as exc:
+            msg = (f"There is no weights named with the prefix: {key}. "
+                   f"Available prefix: {set(self.keys())}")
+            raise KeyError(msg) from exc
+
+
 def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
                    prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
     """
@@ -39,13 +55,13 @@ def group_weights_with_prefix(
     Helper function to group weights with prefix
     """
     init_weights, repeated_weights = itertools.tee(weights, 2)
-    weights_prefix = set(map(lambda x: x[0].split(".")[0], init_weights))
+    weights_prefix = {name.split(".")[0] for name, _ in init_weights}
     repeated_weights = itertools.tee(repeated_weights, len(weights_prefix))
 
-    grouped_weights = {}
-    for component, prefix in zip(repeated_weights, weights_prefix):
-        grouped_weights[prefix] = filter_weights(component, prefix)
-    return grouped_weights
+    return WeightsGroup({
+        prefix: filter_weights(component, prefix)
+        for component, prefix in zip(repeated_weights, weights_prefix)
+    })
 
 
 def init_vllm_registered_model(
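
Usage sketch (illustrative, not part of the patch series above): the snippet below
mirrors the final form of filter_weights, group_weights_with_prefix, and WeightsGroup
from vllm/model_executor/models/utils.py, and shows how a model's load_weights is
meant to consume the grouped iterators as well as the KeyError raised for a missing
prefix. The dummy weight names are hypothetical stand-ins for a real checkpoint
stream; only the three helpers mirror code from the patches.

import itertools
from collections import UserDict
from typing import Dict, Iterable, Tuple

import torch


class WeightsGroup(UserDict):
    """Dict wrapper that reports the available prefixes when a lookup fails."""

    def __getitem__(self, key: str):
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            msg = (f"There is no weights named with the prefix: {key}. "
                   f"Available prefix: {set(self.keys())}")
            raise KeyError(msg) from exc


def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
    # Yield only the weights under `prefix`, with the prefix stripped from the name.
    for name, loaded_weight in weights:
        parts = name.split(".")
        if prefix == parts.pop(0):
            yield ".".join(parts), loaded_weight


def group_weights_with_prefix(
    weights: Iterable[Tuple[str, torch.Tensor]]
) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]:
    # One tee'd copy discovers the top-level prefixes; the remaining copies are
    # wrapped in lazy per-prefix filters, as in PATCH 5/8 and PATCH 8/8.
    init_weights, repeated_weights = itertools.tee(weights, 2)
    weights_prefix = {name.split(".")[0] for name, _ in init_weights}
    repeated_weights = itertools.tee(repeated_weights, len(weights_prefix))

    return WeightsGroup({
        prefix: filter_weights(component, prefix)
        for component, prefix in zip(repeated_weights, weights_prefix)
    })


# Hypothetical checkpoint stream with the component prefixes used by the
# LLaVA-style models touched in this series.
dummy_weights = [
    ("vision_tower.patch_embed.weight", torch.zeros(8, 3)),
    ("multi_modal_projector.linear_1.weight", torch.zeros(4, 8)),
    ("language_model.embed_tokens.weight", torch.zeros(16, 4)),
]

weights_group = group_weights_with_prefix(iter(dummy_weights))

# Each component consumes only its own, prefix-stripped weights.
for name, tensor in weights_group["multi_modal_projector"]:
    print(name, tuple(tensor.shape))  # -> linear_1.weight (4, 8)

# A missing component surfaces the informative error added in PATCH 8/8.
try:
    weights_group["image_newline"]
except KeyError as err:
    print(err)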