diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index 65045c4b9c6b4..5c86de08fffda 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -104,7 +104,7 @@ def test_token_match_by_text(
             or match_str in tokenizer.decode(token_ids,
                                              skip_special_tokens=False)):
         assert match is not None
-        match_start_idx, match_end_idx = match
+        match_start_idx, match_end_idx, *_ = match
 
         assert match_str in tokenizer.decode(
             token_ids[match_start_idx:match_end_idx],
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index c08db19299adc..ee2f4f0bc7b6c 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -1,6 +1,5 @@
 from dataclasses import dataclass
 from functools import lru_cache
-from heapq import nsmallest
 from itertools import groupby
 from typing import (Any, Callable, Generic, List, Mapping, NamedTuple,
                     Optional, TypeVar, Union, final)
@@ -147,6 +146,9 @@ def _encode(
     return tokenizer.encode(text, add_special_tokens=add_special_tokens)
 
 
+_cached_encode = lru_cache(_encode)
+
+
 @lru_cache
 def _max_vocab_token_len(tokenizer: AnyTokenizer) -> int:
     return max(len(token_text) for token_text in tokenizer.get_vocab())
@@ -157,7 +159,10 @@ class _TokenMatch(NamedTuple):
     end_idx: int
 
 
-def find_token_match(token_ids: List[int], match_ids: List[int]):
+def find_token_match(
+    token_ids: List[int],
+    match_ids: List[int],
+) -> Optional[_TokenMatch]:
     """
     Find the first occurrence of :code:`match_ids` in :code:`token_ids`.
     """
@@ -171,10 +176,27 @@ def find_token_match(token_ids: List[int], match_ids: List[int]):
     return None
 
 
-class _Candidate(NamedTuple):
+class _TokenMatchFromTextCandidate(NamedTuple):
     start_idx: int
     end_idx: int
-    distance: int
+
+    match_text_prefix: str
+    match_text_suffix: str
+
+    @property
+    def distance(self) -> int:
+        return len(self.match_text_prefix) + len(self.match_text_suffix)
+
+
+class _TokenMatchFromText(NamedTuple):
+    start_idx: int
+    end_idx: int
+
+    match_prefix: List[int]
+    match_suffix: List[int]
+
+    match_text_prefix: str
+    match_text_suffix: str
 
 
 def find_token_match_by_text(
@@ -182,14 +204,21 @@ def find_token_match_by_text(
     token_ids: List[int],
    token_text: str,
     match_text: str,
-):
+) -> Optional[_TokenMatchFromText]:
     """
     Find the first occurrence of the tokenized :code:`match_text`
     in :code:`token_ids`.
""" - match_ids = _encode(tokenizer, match_text, add_special_tokens=False) + match_ids = _cached_encode(tokenizer, match_text, add_special_tokens=False) if (match := find_token_match(token_ids, match_ids)): - return match + return _TokenMatchFromText( + match.start_idx, + match.end_idx, + match_prefix=[], + match_suffix=[], + match_text_prefix="", + match_text_suffix="", + ) # When `match_text` is not mapped to a special token ID, # it may be tokenized differently based on the surrounding tokens @@ -202,37 +231,42 @@ def find_token_match_by_text( text_end_idx = text_start_idx + len(match_text) # In case the left/right side of `match_text` is fused with the - # string immediately before/after it during tokenization + # string immediately before/after it as a single token text_buffer = _max_vocab_token_len(tokenizer) - 1 left_text = token_text[:max(0, text_start_idx - text_buffer)] right_text = token_text[:text_end_idx + text_buffer] left_idx = len(_encode(tokenizer, left_text, add_special_tokens=False)) right_idx = len(_encode(tokenizer, right_text, add_special_tokens=True)) - avg_idx = (left_idx + right_idx) // 2 window_size = len(match_ids) - valid_candidates = list[_Candidate]() - for start_idx in sorted(range(left_idx, right_idx - window_size + 1), - key=lambda x: abs(x - avg_idx)): + best_distance = len(token_text) + best_candidate = None + + for start_idx in range(left_idx, right_idx - window_size + 1): end_idx = start_idx + window_size candidate_text = tokenizer.decode( token_ids[start_idx:end_idx], + # In case match_text is a special token skip_special_tokens=False, ) if match_text in candidate_text: - candidate = _Candidate( - start_idx=start_idx, - end_idx=end_idx, - distance=len(candidate_text) - len(match_text), + candidate = _TokenMatchFromTextCandidate( + start_idx, + end_idx, + *candidate_text.split(match_text, 1), ) - valid_candidates.append(candidate) + print("candidate", candidate, "best", best_distance) - if candidate.distance == 0: + if candidate.distance < best_distance: + best_candidate = candidate + best_distance = candidate.distance + + if best_distance == 0: break - assert len(valid_candidates) > 0, dict( + assert best_candidate is not None, dict( # To facilitate debugging token_ids=token_ids, match_ids=match_ids, @@ -242,8 +276,25 @@ def find_token_match_by_text( right_idx=right_idx, ) - best_candidate, = nsmallest(1, valid_candidates, key=lambda x: x.distance) - return best_candidate.start_idx, best_candidate.end_idx + match_token_prefix = _cached_encode( + tokenizer, + best_candidate.match_text_prefix, + add_special_tokens=False, + ) + match_token_suffix = _cached_encode( + tokenizer, + best_candidate.match_text_suffix, + add_special_tokens=False, + ) + + return _TokenMatchFromText( + start_idx=best_candidate.start_idx, + end_idx=best_candidate.end_idx, + match_prefix=match_token_prefix, + match_suffix=match_token_suffix, + match_text_prefix=best_candidate.match_text_prefix, + match_text_suffix=best_candidate.match_text_suffix, + ) def apply_placeholders( @@ -253,7 +304,7 @@ def apply_placeholders( match_text: str, replacement_id: int, replacement_count: int, -) -> Optional[PlaceholderRange]: +) -> tuple[List[int], str, Optional[PlaceholderRange]]: """ Find the first occurrence of the tokenized :code:`match_text` in :code:`token_ids`, and replace it with @@ -269,13 +320,25 @@ def apply_placeholders( ) if match is None: - return None + return token_ids, token_text, None + + start_idx, end_idx, prefix_ids, suffix_ids, prefix_str, suffix_str = match + + 
+    replacement_ids = (prefix_ids + [replacement_id] * replacement_count +
+                       suffix_ids)
+    replacement_text = tokenizer.decode(
+        replacement_ids,
+        # In case match_text is a special token
+        skip_special_tokens=False,
+    )
 
-    # TODO(youkaichao): Don't update new_token_ids
-    start_idx, end_idx = match
-    token_ids[start_idx:end_idx] = [replacement_id] * replacement_count
+    token_ids[start_idx:end_idx] = replacement_ids
+    token_text = token_text.replace(prefix_str + match_text + suffix_str,
+                                    replacement_text, 1)
 
-    return PlaceholderRange(offset=start_idx, length=replacement_count)
+    return (token_ids, token_text,
+            PlaceholderRange(offset=start_idx + len(prefix_ids),
+                             length=replacement_count))
 
 
 class MultiModalProcessor:
@@ -349,7 +411,11 @@ def apply(
                             item_idx,
                         )
 
-                    placeholders = apply_placeholders(
+                    (
+                        new_token_ids,
+                        prompt,
+                        placeholders,
+                    ) = apply_placeholders(
                         tokenizer,
                         new_token_ids,
                         prompt,
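Usage sketch (illustrative only, not part of the diff above): it shows how the new tuple-returning `apply_placeholders` contract is meant to be consumed, with the updated token IDs and prompt text kept in sync by the caller. The model name, placeholder token ID, and replacement count below are assumed values for the example, not taken from this change.

# Illustrative sketch; the model name, placeholder ID (32000) and
# replacement count (576) are assumptions chosen for the example.
from vllm.multimodal.processing import apply_placeholders
from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer("llava-hf/llava-1.5-7b-hf")

prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
token_ids = tokenizer.encode(prompt, add_special_tokens=False)

# Replace the first "<image>" with a run of image-placeholder tokens,
# receiving back the updated token_ids, the updated prompt text, and the
# placeholder range (or None if match_text was not found).
token_ids, prompt, placeholder = apply_placeholders(
    tokenizer,
    token_ids,
    prompt,
    match_text="<image>",
    replacement_id=32000,
    replacement_count=576,
)

if placeholder is not None:
    # The returned PlaceholderRange records the offset and length of the
    # inserted replacement_id run within the updated token_ids.
    print(placeholder)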