Skip to content

Commit

Permalink
[Core][Model] Support loading weights by ID within models (vllm-proje…
Browse files Browse the repository at this point in the history
  • Loading branch information
petersalas authored Sep 24, 2024
1 parent b8747e8 commit 3f06bae
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 17 deletions.
60 changes: 47 additions & 13 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
# ruff: noqa: SIM117
import collections
import copy
import dataclasses
import fnmatch
import glob
import json
import math
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
from typing import (Any, Dict, Generator, Iterable, List, Optional, Tuple,
Type, cast)

import gguf
import huggingface_hub
Expand Down Expand Up @@ -207,6 +209,22 @@ def load_model(self, *, model_config: ModelConfig,
class DefaultModelLoader(BaseModelLoader):
"""Model loader that can load different file types from disk."""

@dataclasses.dataclass
class Source:
    """A source for weights.

    Attributes:
        model_or_path: The model ID or path.
        revision: The optional model revision.
        prefix: A prefix to prepend to all weights; defaults to the
            empty string.
        fall_back_to_pt: Whether .pt weights can be used; defaults to
            ``True``.
    """

    model_or_path: str
    revision: Optional[str]
    prefix: str = ""
    fall_back_to_pt: bool = True

def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
Expand Down Expand Up @@ -313,17 +331,16 @@ def _prepare_weights(self, model_name_or_path: str,
return hf_folder, hf_weights_files, use_safetensors

def _get_weights_iterator(
self, model_name_or_path: str, revision: Optional[str],
fall_back_to_pt: bool
self, source: "Source"
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
model_name_or_path, revision, fall_back_to_pt)
source.model_or_path, source.revision, source.fall_back_to_pt)
if self.load_config.load_format == LoadFormat.NPCACHE:
# Currently np_cache only support *.bin checkpoints
assert use_safetensors is False
weights_iterator = np_cache_weights_iterator(
model_name_or_path, self.load_config.download_dir, hf_folder,
source.model_or_path, self.load_config.download_dir, hf_folder,
hf_weights_files)
elif use_safetensors:
weights_iterator = safetensors_weights_iterator(hf_weights_files)
Expand All @@ -341,7 +358,29 @@ def _xla_weights_iterator(iterator: Generator):
xm.mark_step()

weights_iterator = _xla_weights_iterator(weights_iterator)
return weights_iterator

# Apply the prefix.
return ((source.prefix + name, tensor)
for (name, tensor) in weights_iterator)

def _get_all_weights(
    self,
    model_config: ModelConfig,
    model: nn.Module,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Yield every (name, tensor) pair the model should load.

    The primary checkpoint, identified by ``model_config.model`` and
    ``model_config.revision``, is streamed first with an empty prefix.
    Whether .pt files may be used for it is read from the model's
    ``fall_back_to_pt_during_load`` attribute (``True`` when absent).

    Afterwards, each entry of the model's optional ``secondary_weights``
    attribute (treated as an iterable of ``DefaultModelLoader.Source``;
    empty when absent) is streamed in order.
    """
    allow_pt = getattr(model, "fall_back_to_pt_during_load", True)
    main_source = DefaultModelLoader.Source(
        model_or_path=model_config.model,
        revision=model_config.revision,
        prefix="",
        fall_back_to_pt=allow_pt,
    )
    yield from self._get_weights_iterator(main_source)

    # Looked up only after the primary weights are exhausted, so the
    # iterators are created in the same order as they are consumed.
    extra_sources = cast(
        Iterable[DefaultModelLoader.Source],
        getattr(model, "secondary_weights", ()),
    )
    for extra in extra_sources:
        yield from self._get_weights_iterator(extra)

def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model,
Expand All @@ -360,13 +399,8 @@ def load_model(self, *, model_config: ModelConfig,
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config,
scheduler_config)
model.load_weights(
self._get_weights_iterator(model_config.model,
model_config.revision,
fall_back_to_pt=getattr(
model,
"fall_back_to_pt_during_load",
True)), )

model.load_weights(self._get_all_weights(model_config, model))

for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
Expand Down
30 changes: 26 additions & 4 deletions vllm/model_executor/models/ultravox.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.utils import (flatten_bn,
Expand Down Expand Up @@ -334,14 +335,23 @@ def __init__(self,
self.multi_modal_config = multimodal_config
assert self.multi_modal_config

self.secondary_weights = []
self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
if config.audio_model_id is not None:
self.audio_tower = ModifiedWhisperEncoder.from_pretrained(
config.audio_model_id)
else:
self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
self.secondary_weights.append(
DefaultModelLoader.Source(
model_or_path=config.audio_model_id,
revision=None,
prefix="audio_tower.",
))
self.multi_modal_projector = UltravoxProjector(config)
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
if config.text_model_id is not None:
self.secondary_weights.append(
DefaultModelLoader.Source(model_or_path=config.text_model_id,
revision=None,
prefix="language_model."))

def _audio_features_to_embeddings(
self, input_features: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -466,6 +476,18 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
weights_group = group_weights_with_prefix(weights)

# load audio tower weights
audio_tower_weights = weights_group["audio_tower"]
audio_tower_params_dict = dict(
self.audio_tower.named_parameters(
prefix=self.audio_tower.base_model_prefix))
for name, loaded_weight in audio_tower_weights:
if name in audio_tower_params_dict:
param = audio_tower_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

# load projector weights
projector_weights = weights_group["multi_modal_projector"]
projector_params_dict = dict(
Expand Down

0 comments on commit 3f06bae

Please sign in to comment.