diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index d22ea6b79de0f..b6ea6ab396642 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -28,6 +28,7 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
+from vllm.utils import print_warning_once
 
 
 class QWenMLP(nn.Module):
@@ -288,6 +289,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue
+            # Skip loading visual weights to support Qwen-VL models
+            # in cases with text-only inputs
+            # TODO: add support for Qwen-VL
+            if (name not in params_dict
+                    and name.startswith("transformer.visual.")):
+                print_warning_once(
+                    "Only text inputs are allowed. Images won't be handled "
+                    "until Qwen-VL models are fully supported.")
+                continue
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
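For context, here is a minimal, self-contained sketch of the skip-and-warn pattern this diff adds: unmatched `transformer.visual.*` weights from a Qwen-VL checkpoint are dropped with a one-time warning instead of raising a KeyError. The toy `params_dict`, `warn_once` helper, and direct assignment are illustrative stand-ins, not vLLM's actual loader internals (`warn_once` only mirrors the intent of `vllm.utils.print_warning_once`).

```python
import functools
from typing import Iterable, Tuple


@functools.lru_cache(maxsize=None)
def warn_once(msg: str) -> None:
    # Caching on the message string means repeated calls print only once,
    # approximating the behavior of vllm.utils.print_warning_once.
    print(f"WARNING: {msg}")


def load_weights(weights: Iterable[Tuple[str, object]], params_dict: dict) -> None:
    for name, tensor in weights:
        # Skip extra bias terms that some quantized (e.g. GPTQ) checkpoints carry.
        if name.endswith(".bias") and name not in params_dict:
            continue
        # Skip visual-tower weights from Qwen-VL checkpoints so the text-only
        # model can still load them; warn the user exactly once.
        if (name not in params_dict
                and name.startswith("transformer.visual.")):
            warn_once("Only text inputs are allowed. Images won't be handled "
                      "until Qwen-VL models are fully supported.")
            continue
        # Stand-in for looking up the param and calling its weight_loader.
        params_dict[name] = tensor


# Example: a Qwen-VL checkpoint containing one visual and one text weight.
params = {"transformer.h.0.mlp.w1": None}
ckpt = [("transformer.visual.conv1.weight", "..."),
        ("transformer.h.0.mlp.w1", "...")]
load_weights(ckpt, params)  # warns once, loads only the text weight
```

The key design point is that the check fires only when the name is absent from `params_dict`, so a future Qwen-VL model that registers visual parameters would load them normally without touching this code path.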