Provide necessary data for replacement
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
DarkLight1337 committed Nov 21, 2024
1 parent 4b23817 commit 9a47b6e
Showing 2 changed files with 96 additions and 29 deletions.
2 changes: 1 addition & 1 deletion tests/multimodal/test_processing.py
@@ -104,7 +104,7 @@ def test_token_match_by_text(
or match_str in tokenizer.decode(token_ids,
skip_special_tokens=False)):
assert match is not None
match_start_idx, match_end_idx = match
match_start_idx, match_end_idx, *_ = match

assert match_str in tokenizer.decode(
token_ids[match_start_idx:match_end_idx],
123 changes: 95 additions & 28 deletions vllm/multimodal/processing.py
@@ -1,6 +1,5 @@
from dataclasses import dataclass
from functools import lru_cache
from heapq import nsmallest
from itertools import groupby
from typing import (Any, Callable, Generic, List, Mapping, NamedTuple,
Optional, TypeVar, Union, final)
@@ -147,6 +146,9 @@ def _encode(
return tokenizer.encode(text, add_special_tokens=add_special_tokens)


_cached_encode = lru_cache(_encode)


@lru_cache
def _max_vocab_token_len(tokenizer: AnyTokenizer) -> int:
return max(len(token_text) for token_text in tokenizer.get_vocab())
@@ -157,7 +159,10 @@ class _TokenMatch(NamedTuple):
end_idx: int


def find_token_match(token_ids: List[int], match_ids: List[int]):
def find_token_match(
token_ids: List[int],
match_ids: List[int],
) -> Optional[_TokenMatch]:
"""
Find the first occurrence of :code:`match_ids` in :code:`token_ids`.
"""
@@ -171,25 +176,49 @@ def find_token_match(token_ids: List[int], match_ids: List[int]):
return None


class _Candidate(NamedTuple):
class _TokenMatchFromTextCandidate(NamedTuple):
start_idx: int
end_idx: int
distance: int

match_text_prefix: str
match_text_suffix: str

@property
def distance(self) -> int:
return len(self.match_text_prefix) + len(self.match_text_suffix)


class _TokenMatchFromText(NamedTuple):
start_idx: int
end_idx: int

match_prefix: List[int]
match_suffix: List[int]

match_text_prefix: str
match_text_suffix: str


def find_token_match_by_text(
tokenizer: AnyTokenizer,
token_ids: List[int],
token_text: str,
match_text: str,
):
) -> Optional[_TokenMatchFromText]:
"""
Find the first occurrence of the tokenized :code:`match_text` in
:code:`token_ids`.
"""
match_ids = _encode(tokenizer, match_text, add_special_tokens=False)
match_ids = _cached_encode(tokenizer, match_text, add_special_tokens=False)
if (match := find_token_match(token_ids, match_ids)):
return match
return _TokenMatchFromText(
match.start_idx,
match.end_idx,
match_prefix=[],
match_suffix=[],
match_text_prefix="",
match_text_suffix="",
)

# When `match_text` is not mapped to a special token ID,
# it may be tokenized differently based on the surrounding tokens
@@ -202,37 +231,42 @@ def find_token_match_by_text(
text_end_idx = text_start_idx + len(match_text)

# In case the left/right side of `match_text` is fused with the
# string immediately before/after it during tokenization
# string immediately before/after it as a single token
text_buffer = _max_vocab_token_len(tokenizer) - 1
left_text = token_text[:max(0, text_start_idx - text_buffer)]
right_text = token_text[:text_end_idx + text_buffer]

left_idx = len(_encode(tokenizer, left_text, add_special_tokens=False))
right_idx = len(_encode(tokenizer, right_text, add_special_tokens=True))
avg_idx = (left_idx + right_idx) // 2
window_size = len(match_ids)

valid_candidates = list[_Candidate]()
for start_idx in sorted(range(left_idx, right_idx - window_size + 1),
key=lambda x: abs(x - avg_idx)):
best_distance = len(token_text)
best_candidate = None

for start_idx in range(left_idx, right_idx - window_size + 1):
end_idx = start_idx + window_size
candidate_text = tokenizer.decode(
token_ids[start_idx:end_idx],
# In case match_text is a special token
skip_special_tokens=False,
)

if match_text in candidate_text:
candidate = _Candidate(
start_idx=start_idx,
end_idx=end_idx,
distance=len(candidate_text) - len(match_text),
candidate = _TokenMatchFromTextCandidate(
start_idx,
end_idx,
*candidate_text.split(match_text, 1),
)
valid_candidates.append(candidate)

if candidate.distance == 0:
if candidate.distance < best_distance:
best_candidate = candidate
best_distance = candidate.distance

if best_distance == 0:
break

assert len(valid_candidates) > 0, dict(
assert best_candidate is not None, dict(
# To facilitate debugging
token_ids=token_ids,
match_ids=match_ids,
@@ -242,8 +276,25 @@ def find_token_match_by_text(
right_idx=right_idx,
)

best_candidate, = nsmallest(1, valid_candidates, key=lambda x: x.distance)
return best_candidate.start_idx, best_candidate.end_idx
match_token_prefix = _cached_encode(
tokenizer,
best_candidate.match_text_prefix,
add_special_tokens=False,
)
match_token_suffix = _cached_encode(
tokenizer,
best_candidate.match_text_suffix,
add_special_tokens=False,
)

return _TokenMatchFromText(
start_idx=best_candidate.start_idx,
end_idx=best_candidate.end_idx,
match_prefix=match_token_prefix,
match_suffix=match_token_suffix,
match_text_prefix=best_candidate.match_text_prefix,
match_text_suffix=best_candidate.match_text_suffix,
)


def apply_placeholders(
@@ -253,7 +304,7 @@ def apply_placeholders(
match_text: str,
replacement_id: int,
replacement_count: int,
) -> Optional[PlaceholderRange]:
) -> tuple[List[int], str, Optional[PlaceholderRange]]:
"""
Find the first occurrence of the tokenized :code:`match_text` in
:code:`token_ids`, and replace it with
@@ -269,13 +320,25 @@ def apply_placeholders(
)

if match is None:
return None
return token_ids, token_text, None

start_idx, end_idx, prefix_ids, suffix_ids, prefix_str, suffix_str = match

replacement_ids = (prefix_ids + [replacement_id] * replacement_count +
suffix_ids)
replacement_text = tokenizer.decode(
replacement_ids,
# In case match_text is a special token
skip_special_tokens=False,
)

# TODO(youkaichao): Don't update new_token_ids
start_idx, end_idx = match
token_ids[start_idx:end_idx] = [replacement_id] * replacement_count
token_ids[start_idx:end_idx] = replacement_ids
token_text = token_text.replace(prefix_str + match_text + suffix_str,
replacement_text, 1)

return PlaceholderRange(offset=start_idx, length=replacement_count)
return (token_ids, token_text,
PlaceholderRange(offset=start_idx + len(prefix_ids),
length=replacement_count))


class MultiModalProcessor:
@@ -349,7 +412,11 @@ def apply(
item_idx,
)

placeholders = apply_placeholders(
(
new_token_ids,
prompt,
placeholders,
) = apply_placeholders(
tokenizer,
new_token_ids,
prompt,
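For context on what the new return values enable: the matcher now reports any token prefix/suffix that was fused onto the match during tokenization, so the caller splices prefix + [replacement_id] * replacement_count + suffix back into the prompt and offsets the placeholder range by the prefix length. The standalone sketch below is not vLLM code; the integer token IDs, the helper name splice_placeholders, and the local PlaceholderRange stand-in are invented for illustration of that arithmetic.

# Standalone sketch of the replacement bookkeeping introduced in this commit.
from typing import List, NamedTuple


class PlaceholderRange(NamedTuple):
    # Stand-in for vLLM's PlaceholderRange (offset + length of the placeholder block).
    offset: int
    length: int


def splice_placeholders(
    token_ids: List[int],
    start_idx: int,
    end_idx: int,
    match_prefix: List[int],
    match_suffix: List[int],
    replacement_id: int,
    replacement_count: int,
) -> tuple[List[int], PlaceholderRange]:
    """Replace token_ids[start_idx:end_idx], keeping fused prefix/suffix tokens."""
    replacement_ids = (match_prefix + [replacement_id] * replacement_count +
                       match_suffix)
    new_token_ids = token_ids[:start_idx] + replacement_ids + token_ids[end_idx:]
    # The placeholder block starts after any prefix tokens carried over from the match.
    placeholder = PlaceholderRange(offset=start_idx + len(match_prefix),
                                   length=replacement_count)
    return new_token_ids, placeholder


if __name__ == "__main__":
    # IDs 1..5 stand in for a tokenized prompt; positions 2..4 matched the target text,
    # with a one-token prefix (30) fused onto the match that must be preserved.
    new_ids, placeholder = splice_placeholders(
        [1, 2, 3, 4, 5], start_idx=2, end_idx=4,
        match_prefix=[30], match_suffix=[],
        replacement_id=99, replacement_count=3,
    )
    assert new_ids == [1, 2, 30, 99, 99, 99, 5]
    assert placeholder == PlaceholderRange(offset=3, length=3)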
