
[BugFix] Prevent LLM.encode for non-generation Models (vllm-project#5184)

Co-authored-by: mgoin <michael@neuralmagic.com>
2 people authored and jimpang committed Jun 27, 2024
1 parent 8615094 commit 01cc533
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions vllm/entrypoints/llm.py
@@ -276,6 +276,11 @@ def generate(
             considered legacy and may be deprecated in the future. You should
             instead pass them via the ``inputs`` parameter.
         """
+        if self.llm_engine.model_config.embedding_mode:
+            raise ValueError(
+                "LLM.generate() is only supported for generation models "
+                "(XForCausalLM).")
+
         if prompt_token_ids is not None or multi_modal_data is not None:
             inputs = self._convert_v1_inputs(
                 prompts=cast(Optional[Union[str, List[str]]], prompts),
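
For reference, here is a minimal, self-contained sketch of the guard pattern this addition follows: the entrypoint consults the engine's model config and refuses to run in the wrong mode. ModelConfig, LLMEngine and LLM below are simplified stand-ins, not vLLM's real classes; only the embedding_mode attribute is modeled.

# Simplified stand-ins for illustration only; the real classes live in vLLM itself.
class ModelConfig:
    def __init__(self, embedding_mode: bool) -> None:
        self.embedding_mode = embedding_mode

class LLMEngine:
    def __init__(self, model_config: ModelConfig) -> None:
        self.model_config = model_config

class LLM:
    def __init__(self, embedding_mode: bool) -> None:
        self.llm_engine = LLMEngine(ModelConfig(embedding_mode))

    def generate(self, prompt: str) -> str:
        # Same check as in the diff: generation only makes sense for *ForCausalLM models.
        if self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.generate() is only supported for generation models "
                "(XForCausalLM).")
        return f"(generated continuation of {prompt!r})"

# An embedding-mode model now fails fast instead of producing meaningless output.
try:
    LLM(embedding_mode=True).generate("Hello")
except ValueError as err:
    print(err)
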
@@ -420,6 +425,11 @@ def encode(
             considered legacy and may be deprecated in the future. You should
             instead pass them via the ``inputs`` parameter.
         """
+        if not self.llm_engine.model_config.embedding_mode:
+            raise ValueError(
+                "LLM.encode() is only supported for embedding models (XModel)."
+            )
+
         if prompt_token_ids is not None or multi_modal_data is not None:
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
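
Taken together, the two checks make mismatched calls fail fast with a clear error rather than misbehaving further down the pipeline. A hypothetical usage sketch against vLLM itself follows; the model name is illustrative and assumed to load in embedding mode, and an actual run requires the model weights and suitable hardware.

from vllm import LLM

# Assumption: this checkpoint loads in embedding mode, so generate() is rejected.
embedding_llm = LLM(model="intfloat/e5-mistral-7b-instruct")

try:
    embedding_llm.generate("Hello, world!")
except ValueError as err:
    print(err)  # LLM.generate() is only supported for generation models (XForCausalLM).

# Symmetrically, calling encode() on a generation model (an *ForCausalLM checkpoint)
# now raises: "LLM.encode() is only supported for embedding models (XModel)."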
