From 4861d54c7c641aeba2bed0eaaf40999b2a9e4a0c Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Thu, 21 Nov 2024 03:32:37 +0000
Subject: [PATCH] Provide necessary data for replacement

Signed-off-by: DarkLight1337
---
 tests/multimodal/test_processing.py |   2 +-
 vllm/multimodal/inputs.py           |   9 +-
 vllm/multimodal/processing.py       | 132 +++++++++++++++++++++-------
 3 files changed, 102 insertions(+), 41 deletions(-)

diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index 65045c4b9c6b4..5c86de08fffda 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -104,7 +104,7 @@ def test_token_match_by_text(
             or match_str in tokenizer.decode(token_ids,
                                              skip_special_tokens=False)):
         assert match is not None
-        match_start_idx, match_end_idx = match
+        match_start_idx, match_end_idx, *_ = match
 
         assert match_str in tokenizer.decode(
             token_ids[match_start_idx:match_end_idx],
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index 64a4c58d5509c..8e67a552afe12 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -203,14 +203,7 @@ class MultiModalInputsV2(TypedDict):
     """The type of inputs."""
 
     prompt: str
-    """
-    The original, unprocessed prompt text.
-
-    Note:
-        Since prompt text is not required by vLLM internals, we leave this
-        unprocessed to save CPU computation. You can still call
-        :code:`tokenizer.decode(prompt_token_ids)` to get the processed text.
-    """
+    """The processed prompt text."""
 
     prompt_token_ids: List[int]
     """The processed token IDs which includes placeholder tokens."""
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index c08db19299adc..57a9b6ed8b113 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -1,6 +1,5 @@
 from dataclasses import dataclass
 from functools import lru_cache
-from heapq import nsmallest
 from itertools import groupby
 from typing import (Any, Callable, Generic, List, Mapping, NamedTuple,
                     Optional, TypeVar, Union, final)
@@ -147,6 +146,9 @@ def _encode(
     return tokenizer.encode(text, add_special_tokens=add_special_tokens)
 
 
+_cached_encode = lru_cache(_encode)
+
+
 @lru_cache
 def _max_vocab_token_len(tokenizer: AnyTokenizer) -> int:
     return max(len(token_text) for token_text in tokenizer.get_vocab())
@@ -157,7 +159,10 @@ class _TokenMatch(NamedTuple):
     end_idx: int
 
 
-def find_token_match(token_ids: List[int], match_ids: List[int]):
+def find_token_match(
+    token_ids: List[int],
+    match_ids: List[int],
+) -> Optional[_TokenMatch]:
     """
     Find the first occurrence of :code:`match_ids` in :code:`token_ids`.
     """
@@ -171,10 +176,27 @@ def find_token_match(token_ids: List[int], match_ids: List[int]):
     return None
 
 
-class _Candidate(NamedTuple):
+class _TokenMatchFromTextCandidate(NamedTuple):
     start_idx: int
     end_idx: int
-    distance: int
+
+    match_text_prefix: str
+    match_text_suffix: str
+
+    @property
+    def distance(self) -> int:
+        return len(self.match_text_prefix) + len(self.match_text_suffix)
+
+
+class _TokenMatchFromText(NamedTuple):
+    start_idx: int
+    end_idx: int
+
+    match_prefix: List[int]
+    match_suffix: List[int]
+
+    match_text_prefix: str
+    match_text_suffix: str
 
 
 def find_token_match_by_text(
@@ -182,14 +204,21 @@ def find_token_match_by_text(
     token_ids: List[int],
     token_text: str,
     match_text: str,
-):
+) -> Optional[_TokenMatchFromText]:
     """
     Find the first occurrence of the tokenized :code:`match_text` in
     :code:`token_ids`.
""" - match_ids = _encode(tokenizer, match_text, add_special_tokens=False) + match_ids = _cached_encode(tokenizer, match_text, add_special_tokens=False) if (match := find_token_match(token_ids, match_ids)): - return match + return _TokenMatchFromText( + match.start_idx, + match.end_idx, + match_prefix=[], + match_suffix=[], + match_text_prefix="", + match_text_suffix="", + ) # When `match_text` is not mapped to a special token ID, # it may be tokenized differently based on the surrounding tokens @@ -202,37 +231,41 @@ def find_token_match_by_text( text_end_idx = text_start_idx + len(match_text) # In case the left/right side of `match_text` is fused with the - # string immediately before/after it during tokenization + # string immediately before/after it as a single token text_buffer = _max_vocab_token_len(tokenizer) - 1 left_text = token_text[:max(0, text_start_idx - text_buffer)] right_text = token_text[:text_end_idx + text_buffer] left_idx = len(_encode(tokenizer, left_text, add_special_tokens=False)) right_idx = len(_encode(tokenizer, right_text, add_special_tokens=True)) - avg_idx = (left_idx + right_idx) // 2 window_size = len(match_ids) - valid_candidates = list[_Candidate]() - for start_idx in sorted(range(left_idx, right_idx - window_size + 1), - key=lambda x: abs(x - avg_idx)): + best_distance = len(token_text) + best_candidate = None + + for start_idx in range(left_idx, right_idx - window_size + 1): end_idx = start_idx + window_size candidate_text = tokenizer.decode( token_ids[start_idx:end_idx], + # In case match_text is a special token skip_special_tokens=False, ) if match_text in candidate_text: - candidate = _Candidate( - start_idx=start_idx, - end_idx=end_idx, - distance=len(candidate_text) - len(match_text), + candidate = _TokenMatchFromTextCandidate( + start_idx, + end_idx, + *candidate_text.split(match_text, 1), ) - valid_candidates.append(candidate) - if candidate.distance == 0: + if candidate.distance < best_distance: + best_candidate = candidate + best_distance = candidate.distance + + if best_distance == 0: break - assert len(valid_candidates) > 0, dict( + assert best_candidate is not None, dict( # To facilitate debugging token_ids=token_ids, match_ids=match_ids, @@ -242,8 +275,25 @@ def find_token_match_by_text( right_idx=right_idx, ) - best_candidate, = nsmallest(1, valid_candidates, key=lambda x: x.distance) - return best_candidate.start_idx, best_candidate.end_idx + match_token_prefix = _cached_encode( + tokenizer, + best_candidate.match_text_prefix, + add_special_tokens=False, + ) + match_token_suffix = _cached_encode( + tokenizer, + best_candidate.match_text_suffix, + add_special_tokens=False, + ) + + return _TokenMatchFromText( + start_idx=best_candidate.start_idx, + end_idx=best_candidate.end_idx, + match_prefix=match_token_prefix, + match_suffix=match_token_suffix, + match_text_prefix=best_candidate.match_text_prefix, + match_text_suffix=best_candidate.match_text_suffix, + ) def apply_placeholders( @@ -253,7 +303,7 @@ def apply_placeholders( match_text: str, replacement_id: int, replacement_count: int, -) -> Optional[PlaceholderRange]: +) -> tuple[List[int], str, Optional[PlaceholderRange]]: """ Find the first occurrence of the tokenized :code:`match_text` in :code:`token_ids`, and replace it with @@ -269,13 +319,25 @@ def apply_placeholders( ) if match is None: - return None + return token_ids, token_text, None + + start_idx, end_idx, prefix_ids, suffix_ids, prefix_str, suffix_str = match - # TODO(youkaichao): Don't update new_token_ids - start_idx, end_idx = 
-    token_ids[start_idx:end_idx] = [replacement_id] * replacement_count
+    replacement_ids = (prefix_ids + [replacement_id] * replacement_count +
+                       suffix_ids)
+    replacement_text = tokenizer.decode(
+        replacement_ids,
+        # In case match_text is a special token
+        skip_special_tokens=False,
+    )
+
+    token_ids[start_idx:end_idx] = replacement_ids
+    token_text = token_text.replace(prefix_str + match_text + suffix_str,
+                                    replacement_text, 1)
 
-    return PlaceholderRange(offset=start_idx, length=replacement_count)
+    return (token_ids, token_text,
+            PlaceholderRange(offset=start_idx + len(prefix_ids),
+                             length=replacement_count))
 
 
 class MultiModalProcessor:
@@ -318,6 +380,7 @@ def apply(
         new_token_ids, = processed_inputs.pop("input_ids").tolist()
         mm_kwargs = MultiModalKwargs(processed_inputs)
 
+        new_prompt = prompt
         mm_placeholders: Mapping[str, List[PlaceholderRange]] = {}
 
         for modality, orig_inputs in to_multi_format(mm_data).items():
@@ -337,8 +400,9 @@ def apply(
                 if new_token_id in repl_token_ids:
                     modality_placeholders.append(run_info)
 
-            # Otherwise, we insert them ourselves
-            if not modality_placeholders:
+            if modality_placeholders:
+                new_prompt = tokenizer.decode(new_token_ids)
+            else:  # Otherwise, we insert them ourselves
                 for item_idx, orig_item in enumerate(orig_inputs):
                     for match_str, replacement in placeholder_repls.items():
                         replacement_count = replacement["count"]
@@ -349,10 +413,14 @@ def apply(
                                 item_idx,
                             )
 
-                        placeholders = apply_placeholders(
+                        (
+                            new_token_ids,
+                            new_prompt,
+                            placeholders,
+                        ) = apply_placeholders(
                             tokenizer,
                             new_token_ids,
-                            prompt,
+                            new_prompt,
                             match_str,
                             replacement["token_id"],
                             replacement_count,
@@ -365,7 +433,7 @@ def apply(
 
         return MultiModalInputsV2(
             type="multimodal",
-            prompt=prompt,
+            prompt=new_prompt,
            prompt_token_ids=new_token_ids,
             mm_kwargs=mm_kwargs,
             mm_placeholders=mm_placeholders,
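
Reviewer note (not part of the patch): the sketch below is illustrative only; the helper name splice_replacement and the toy token IDs are invented. It shows the bookkeeping the new prefix/suffix data enables: when the matched token window decodes with extra text fused onto either side, the re-encoded prefix/suffix tokens are spliced back in around the placeholder run, and the placeholder offset is shifted past the prefix, mirroring PlaceholderRange(offset=start_idx + len(prefix_ids), ...) above.

# Illustrative only -- not vLLM code. Assumes the match bounds and the
# prefix/suffix token IDs were already recovered, as find_token_match_by_text
# does in this patch.
from typing import List, Tuple


def splice_replacement(
    token_ids: List[int],
    start_idx: int,
    end_idx: int,
    prefix_ids: List[int],
    suffix_ids: List[int],
    replacement_id: int,
    replacement_count: int,
) -> Tuple[List[int], Tuple[int, int]]:
    """Return updated token IDs plus (offset, length) of the placeholder run."""
    replacement_ids = prefix_ids + [replacement_id] * replacement_count + suffix_ids
    new_ids = token_ids[:start_idx] + replacement_ids + token_ids[end_idx:]
    # The placeholder run starts after any prefix tokens kept from the fused text.
    return new_ids, (start_idx + len(prefix_ids), replacement_count)


ids, (offset, length) = splice_replacement(
    token_ids=[1, 2, 3, 4, 5],  # toy values
    start_idx=2,
    end_idx=4,
    prefix_ids=[9],
    suffix_ids=[],
    replacement_id=7,
    replacement_count=3,
)
assert ids == [1, 2, 9, 7, 7, 7, 5]
assert (offset, length) == (3, 3)

In the patch itself the splice is done in place (token_ids[start_idx:end_idx] = replacement_ids) and token_text is kept in sync with a single str.replace(..., 1), which is why MultiModalInputsV2.prompt can now carry the processed text containing the inserted placeholders.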