Provide necessary data for replacement
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
DarkLight1337 committed Nov 21, 2024
1 parent 4b23817 commit 4861d54
Showing 3 changed files with 102 additions and 41 deletions.
tests/multimodal/test_processing.py (2 changes: 1 addition & 1 deletion)
@@ -104,7 +104,7 @@ def test_token_match_by_text(
or match_str in tokenizer.decode(token_ids,
skip_special_tokens=False)):
assert match is not None
match_start_idx, match_end_idx = match
match_start_idx, match_end_idx, *_ = match

assert match_str in tokenizer.decode(
token_ids[match_start_idx:match_end_idx],
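
The match result is now a richer NamedTuple (_TokenMatchFromText, added in vllm/multimodal/processing.py below), so the test keeps only the span and discards the extra fields via extended unpacking, roughly:

    # Shape of the new result; the concrete values here are illustrative only.
    match = (3, 5, [], [], "", "")
    match_start_idx, match_end_idx, *_ = match
    assert (match_start_idx, match_end_idx) == (3, 5)
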
vllm/multimodal/inputs.py (9 changes: 1 addition & 8 deletions)
@@ -203,14 +203,7 @@ class MultiModalInputsV2(TypedDict):
"""The type of inputs."""

prompt: str
"""
The original, unprocessed prompt text.
Note:
Since prompt text is not required by vLLM internals, we leave this
unprocessed to save CPU computation. You can still call
:code:`tokenizer.decode(prompt_token_ids)` to get the processed text.
"""
"""The processed prompt text."""

prompt_token_ids: List[int]
"""The processed token IDs which includes placeholder tokens."""
vllm/multimodal/processing.py (132 changes: 100 additions & 32 deletions)
@@ -1,6 +1,5 @@
from dataclasses import dataclass
from functools import lru_cache
from heapq import nsmallest
from itertools import groupby
from typing import (Any, Callable, Generic, List, Mapping, NamedTuple,
Optional, TypeVar, Union, final)
@@ -147,6 +146,9 @@ def _encode(
return tokenizer.encode(text, add_special_tokens=add_special_tokens)


_cached_encode = lru_cache(_encode)


@lru_cache
def _max_vocab_token_len(tokenizer: AnyTokenizer) -> int:
return max(len(token_text) for token_text in tokenizer.get_vocab())
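
Wrapping the module-level helper with functools.lru_cache, instead of redefining it, keeps the uncached _encode available while memoizing repeated encodings of the same match_text for the same tokenizer. A minimal sketch of that pattern, using a toy stand-in rather than a real AnyTokenizer:

    from functools import lru_cache

    def encode_demo(tokenizer_id: int, text: str, add_special_tokens: bool = False):
        # Toy stand-in for tokenizer.encode(); prints so cache hits are visible.
        print(f"encoding {text!r}")
        return [ord(c) for c in text]

    cached_encode_demo = lru_cache(encode_demo)

    cached_encode_demo(0, "<image>", add_special_tokens=False)  # computes (prints once)
    cached_encode_demo(0, "<image>", add_special_tokens=False)  # served from the cache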
@@ -157,7 +159,10 @@ class _TokenMatch(NamedTuple):
end_idx: int


def find_token_match(token_ids: List[int], match_ids: List[int]):
def find_token_match(
token_ids: List[int],
match_ids: List[int],
) -> Optional[_TokenMatch]:
"""
Find the first occurrence of :code:`match_ids` in :code:`token_ids`.
"""
@@ -171,25 +176,49 @@ def find_token_match(token_ids: List[int], match_ids: List[int]):
return None
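
The body of find_token_match is collapsed in this diff; only its signature changed to spell out the Optional[_TokenMatch] return type. Conceptually it is a sliding-window sublist search, roughly like the sketch below (an approximation, not necessarily the exact vLLM implementation):

    from typing import List, NamedTuple, Optional

    class _TokenMatchSketch(NamedTuple):
        start_idx: int
        end_idx: int

    def find_token_match_sketch(
        token_ids: List[int],
        match_ids: List[int],
    ) -> Optional[_TokenMatchSketch]:
        """Return the first window of token_ids equal to match_ids, if any."""
        window = len(match_ids)
        for start_idx in range(len(token_ids) - window + 1):
            if token_ids[start_idx:start_idx + window] == match_ids:
                return _TokenMatchSketch(start_idx, start_idx + window)
        return None

    assert find_token_match_sketch([1, 2, 3, 4], [2, 3]) == (1, 3)
    assert find_token_match_sketch([1, 2, 3, 4], [4, 5]) is None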


class _Candidate(NamedTuple):
class _TokenMatchFromTextCandidate(NamedTuple):
start_idx: int
end_idx: int
distance: int

match_text_prefix: str
match_text_suffix: str

@property
def distance(self) -> int:
return len(self.match_text_prefix) + len(self.match_text_suffix)


class _TokenMatchFromText(NamedTuple):
start_idx: int
end_idx: int

match_prefix: List[int]
match_suffix: List[int]

match_text_prefix: str
match_text_suffix: str


def find_token_match_by_text(
tokenizer: AnyTokenizer,
token_ids: List[int],
token_text: str,
match_text: str,
):
) -> Optional[_TokenMatchFromText]:
"""
Find the first occurrence of the tokenized :code:`match_text` in
:code:`token_ids`.
"""
match_ids = _encode(tokenizer, match_text, add_special_tokens=False)
match_ids = _cached_encode(tokenizer, match_text, add_special_tokens=False)
if (match := find_token_match(token_ids, match_ids)):
return match
return _TokenMatchFromText(
match.start_idx,
match.end_idx,
match_prefix=[],
match_suffix=[],
match_text_prefix="",
match_text_suffix="",
)

# When `match_text` is not mapped to a special token ID,
# it may be tokenized differently based on the surrounding tokens
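
This is the subtlety the rest of the function works around: with a subword tokenizer, the IDs for match_text encoded in isolation can differ from the IDs it gets inside a longer prompt, because neighbouring characters may be merged into the boundary tokens. A toy illustration of the effect (not vLLM code, and not a real tokenizer):

    def toy_encode(text: str) -> list[str]:
        # Fuse a leading space into the following word, the way BPE vocabularies
        # often contain " cat" as a single token distinct from "cat".
        tokens, current = [], ""
        for ch in text:
            if ch == " " and current:
                tokens.append(current)
                current = ch
            else:
                current += ch
        if current:
            tokens.append(current)
        return tokens

    assert toy_encode("cat") == ["cat"]
    assert toy_encode("a cat") == ["a", " cat"]  # "cat" never appears as its own token

A naive search for toy_encode("cat") inside toy_encode("a cat") would therefore miss the occurrence, which is why the code falls back to matching on the decoded text.
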
@@ -202,37 +231,41 @@ def find_token_match_by_text(
text_end_idx = text_start_idx + len(match_text)

# In case the left/right side of `match_text` is fused with the
# string immediately before/after it during tokenization
# string immediately before/after it as a single token
text_buffer = _max_vocab_token_len(tokenizer) - 1
left_text = token_text[:max(0, text_start_idx - text_buffer)]
right_text = token_text[:text_end_idx + text_buffer]

left_idx = len(_encode(tokenizer, left_text, add_special_tokens=False))
right_idx = len(_encode(tokenizer, right_text, add_special_tokens=True))
avg_idx = (left_idx + right_idx) // 2
window_size = len(match_ids)

valid_candidates = list[_Candidate]()
for start_idx in sorted(range(left_idx, right_idx - window_size + 1),
key=lambda x: abs(x - avg_idx)):
best_distance = len(token_text)
best_candidate = None

for start_idx in range(left_idx, right_idx - window_size + 1):
end_idx = start_idx + window_size
candidate_text = tokenizer.decode(
token_ids[start_idx:end_idx],
# In case match_text is a special token
skip_special_tokens=False,
)

if match_text in candidate_text:
candidate = _Candidate(
start_idx=start_idx,
end_idx=end_idx,
distance=len(candidate_text) - len(match_text),
candidate = _TokenMatchFromTextCandidate(
start_idx,
end_idx,
*candidate_text.split(match_text, 1),
)
valid_candidates.append(candidate)

if candidate.distance == 0:
if candidate.distance < best_distance:
best_candidate = candidate
best_distance = candidate.distance

if best_distance == 0:
break

assert len(valid_candidates) > 0, dict(
assert best_candidate is not None, dict(
# To facilitate debugging
token_ids=token_ids,
match_ids=match_ids,
@@ -242,8 +275,25 @@ def find_token_match_by_text(
right_idx=right_idx,
)

best_candidate, = nsmallest(1, valid_candidates, key=lambda x: x.distance)
return best_candidate.start_idx, best_candidate.end_idx
match_token_prefix = _cached_encode(
tokenizer,
best_candidate.match_text_prefix,
add_special_tokens=False,
)
match_token_suffix = _cached_encode(
tokenizer,
best_candidate.match_text_suffix,
add_special_tokens=False,
)

return _TokenMatchFromText(
start_idx=best_candidate.start_idx,
end_idx=best_candidate.end_idx,
match_prefix=match_token_prefix,
match_suffix=match_token_suffix,
match_text_prefix=best_candidate.match_text_prefix,
match_text_suffix=best_candidate.match_text_suffix,
)
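
The text prefix and suffix of the winning candidate come from splitting the decoded window on the first occurrence of match_text, and distance counts the extra characters fused onto either side. A quick worked example of that step with made-up strings (not real tokenizer output):

    match_text = "<image>"
    candidate_text = "a<image>!"  # decoded from a token window that fused its neighbours

    prefix, suffix = candidate_text.split(match_text, 1)
    assert (prefix, suffix) == ("a", "!")

    # distance as defined by _TokenMatchFromTextCandidate: extra characters
    # picked up on either side of match_text; 0 means an exact window.
    distance = len(prefix) + len(suffix)
    assert distance == 2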


def apply_placeholders(
@@ -253,7 +303,7 @@ def apply_placeholders(
match_text: str,
replacement_id: int,
replacement_count: int,
) -> Optional[PlaceholderRange]:
) -> tuple[List[int], str, Optional[PlaceholderRange]]:
"""
Find the first occurrence of the tokenized :code:`match_text` in
:code:`token_ids`, and replace it with
@@ -269,13 +319,25 @@ def apply_placeholders(
)

if match is None:
return None
return token_ids, token_text, None

start_idx, end_idx, prefix_ids, suffix_ids, prefix_str, suffix_str = match

# TODO(youkaichao): Don't update new_token_ids
start_idx, end_idx = match
token_ids[start_idx:end_idx] = [replacement_id] * replacement_count
replacement_ids = (prefix_ids + [replacement_id] * replacement_count +
suffix_ids)
replacement_text = tokenizer.decode(
replacement_ids,
# In case match_text is a special token
skip_special_tokens=False,
)

token_ids[start_idx:end_idx] = replacement_ids
token_text = token_text.replace(prefix_str + match_text + suffix_str,
replacement_text, 1)

return PlaceholderRange(offset=start_idx, length=replacement_count)
return (token_ids, token_text,
PlaceholderRange(offset=start_idx + len(prefix_ids),
length=replacement_count))
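
Because any fused prefix tokens are re-inserted before the placeholder tokens, the reported placeholder range now starts at start_idx + len(prefix_ids) rather than at start_idx. A small worked example with made-up token IDs:

    # Matched window token_ids[4:7] decoded to "a<image>", i.e. the prefix "a"
    # was fused into the window; nothing was fused on the right.
    start_idx, end_idx = 4, 7
    prefix_ids, suffix_ids = [101], []
    replacement_id, replacement_count = 32000, 3

    replacement_ids = prefix_ids + [replacement_id] * replacement_count + suffix_ids
    token_ids = [1, 2, 3, 4, 9, 9, 9, 5]
    token_ids[start_idx:end_idx] = replacement_ids
    # token_ids is now [1, 2, 3, 4, 101, 32000, 32000, 32000, 5]

    offset = start_idx + len(prefix_ids)  # 5: the placeholder starts after the prefix
    assert token_ids[offset:offset + replacement_count] == [replacement_id] * replacement_count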


class MultiModalProcessor:
@@ -318,6 +380,7 @@ def apply(
new_token_ids, = processed_inputs.pop("input_ids").tolist()
mm_kwargs = MultiModalKwargs(processed_inputs)

new_prompt = prompt
mm_placeholders: Mapping[str, List[PlaceholderRange]] = {}

for modality, orig_inputs in to_multi_format(mm_data).items():
@@ -337,8 +400,9 @@ def apply(
if new_token_id in repl_token_ids:
modality_placeholders.append(run_info)

# Otherwise, we insert them ourselves
if not modality_placeholders:
if modality_placeholders:
new_prompt = tokenizer.decode(new_token_ids)
else: # Otherwise, we insert them ourselves
for item_idx, orig_item in enumerate(orig_inputs):
for match_str, replacement in placeholder_repls.items():
replacement_count = replacement["count"]
@@ -349,10 +413,14 @@ def apply(
item_idx,
)

placeholders = apply_placeholders(
(
new_token_ids,
new_prompt,
placeholders,
) = apply_placeholders(
tokenizer,
new_token_ids,
prompt,
new_prompt,
match_str,
replacement["token_id"],
replacement_count,
@@ -365,7 +433,7 @@ def apply(

return MultiModalInputsV2(
type="multimodal",
prompt=prompt,
prompt=new_prompt,
prompt_token_ids=new_token_ids,
mm_kwargs=mm_kwargs,
mm_placeholders=mm_placeholders,
