Align prompt with tokens #201

Closed
wants to merge 1 commit
107 changes: 101 additions & 6 deletions outlines/text/generate/sequence.py
@@ -1,7 +1,13 @@
-from typing import List, Optional, Tuple, Union
+import itertools
+import math
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Union
 
 import interegular
 import torch
 
+from outlines.text.parsing import find_partial_matches
+
 
 class Sequence:
     """Represents a sequence generation method."""
@@ -171,6 +177,98 @@ def update_token_ids(
 
         return new_token_ids
 
+    def find_boundary_tokens(self, prompt: str) -> Dict[int, List[int]]:
+        """Map each prompt token index to the vocabulary tokens that match
+        the prompt from that position through its end, i.e. tokens that
+        cross the prompt/completion boundary."""
+
+        vocabulary = {
+            token_id: self.model.tokenizer.decode([token_id])[0]
+            for token_id in range(len(self.model.tokenizer.vocabulary))
+        }
+        # The prompt is compiled to a character-level FSM; note that it is
+        # interpreted as a regex pattern here.
+        prompt_fsm = interegular.parse_pattern(prompt).to_fsm()
+
+        prompt_token_ids, _ = self.model.tokenizer.encode(prompt)
+        prompt_tokens = self.model.tokenizer.decode(prompt_token_ids[0])
+
+        # Character offset at which each prompt token starts.
+        token_idx_in_prompt = [0] + list(
+            itertools.accumulate([len(t) for t in prompt_tokens])
+        )[:-1]
+
+        boundary_tokens = defaultdict(list)
+        for token_id, token in vocabulary.items():
+            pmatches = find_partial_matches(prompt_fsm, token)
+            for pmatch in pmatches:
+                end_idx, states = pmatch
+                # A match that starts at a token boundary and consumes the
+                # prompt through its last character crosses the boundary.
+                if end_idx is not None and states[-1] == len(prompt):
+                    if states[0] in token_idx_in_prompt:
+                        boundary_tokens[token_idx_in_prompt.index(states[0])].append(
+                            token_id
+                        )
+
+        return boundary_tokens
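
In plain terms, a vocabulary token is a boundary token for a prompt position when it spells out the remaining prompt text from that position, and possibly continues past the end. A minimal sketch of the same idea without the FSM machinery, using a hypothetical toy vocabulary and tokenization:

import itertools

# Toy data: assume the tokenizer split the prompt as ["stack", "overf"].
vocabulary = {0: "stack", 1: "overf", 2: "overflow", 3: "low"}
prompt = "stackoverf"
prompt_tokens = ["stack", "overf"]

# Character offset at which each prompt token starts: [0, 5].
token_idx_in_prompt = [0] + list(
    itertools.accumulate(len(t) for t in prompt_tokens)
)[:-1]

# A token crosses the boundary at a position when it starts with the
# prompt suffix that begins there.
boundary_tokens = {}
for idx, start in enumerate(token_idx_in_prompt):
    matches = [
        token_id
        for token_id, token in vocabulary.items()
        if token.startswith(prompt[start:])
    ]
    if matches:
        boundary_tokens[idx] = matches

print(boundary_tokens)  # {1: [1, 2]}: both "overf" and "overflow" qualify

On this toy input, the FSM-based method above would produce the same map: only position 1, the start of "overf", has tokens that reach the end of the prompt.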

+    def align_prompt_tokens(
+        self, prompt: Union[str, List[str]], rng: torch.Generator
+    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
+        """Truncate each prompt to the first position where a vocabulary
+        token can cross the prompt boundary, then sample the next token
+        among those boundary tokens."""
+
+        prompts = prompt
+        if isinstance(prompts, str):
+            prompts = [prompts]
+
+        masks = []
+        truncated_attention_masks = []
+        truncated_token_idss = []
+        attention_masks = []
+        for prompt in prompts:
+            boundary_tokens = self.find_boundary_tokens(prompt)
+
+            token_ids, attention_mask = self.model.tokenizer.encode(prompt)
+            token_ids = token_ids.to(self.device)
+            attention_mask = attention_mask.to(self.device)
+
+            # Truncate at the earliest position with boundary tokens.
+            last_token = min(boundary_tokens.keys())
+            truncated_token_ids = token_ids[:, :last_token]
+            truncated_attention_mask = attention_mask[:, :last_token]
+
+            # Logit mask: -inf everywhere except the boundary tokens.
+            allowed_tokens = boundary_tokens[last_token]
+            mask = torch.full(
+                (len(self.model.tokenizer.vocabulary),), -math.inf, device=self.device
+            )
+            mask[allowed_tokens] = 0
+
+            masks.append(mask)
+            truncated_attention_masks.append(truncated_attention_mask.squeeze())
+            attention_masks.append(attention_mask.squeeze())
+            truncated_token_idss.append(truncated_token_ids.squeeze())
+
+        # Pad left and stack: pad_sequence right-pads, so flip, pad, flip back.
+        from torch.nn.utils.rnn import pad_sequence
+
+        mask = torch.vstack(masks)
+        truncated_attention_mask = pad_sequence(
+            [t.flip(dims=[0]) for t in truncated_attention_masks],
+            batch_first=True,
+            padding_value=0,
+        ).flip(dims=[1])
+        attention_mask = pad_sequence(
+            [a.flip(dims=[0]) for a in attention_masks],
+            batch_first=True,
+            padding_value=0,
+        ).flip(dims=[1])
+        truncated_token_ids = pad_sequence(
+            [t.flip(dims=[0]) for t in truncated_token_idss],
+            batch_first=True,
+            padding_value=self.model.tokenizer.pad_token_id,
+        ).flip(dims=[1])
+
+        # Sample the boundary-crossing token from the masked logits.
+        logits = self.model(truncated_token_ids, truncated_attention_mask)
+        probs = torch.nn.functional.softmax(logits + mask, dim=-1)
+        next_token_ids = torch.multinomial(probs, num_samples=1, generator=rng)
+        token_ids = torch.cat([truncated_token_ids, next_token_ids], dim=-1)
+
+        return token_ids, attention_mask
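
Two details here are easy to miss: pad_sequence only right-pads, so each sequence is flipped, padded, and flipped back to obtain left padding, and adding a -inf mask to the logits makes the softmax renormalize over the allowed boundary tokens alone. A small self-contained sketch of both steps, with illustrative shapes and values:

import math

import torch
from torch.nn.utils.rnn import pad_sequence

# Left padding via flip -> right-pad -> flip back.
seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
left_padded = pad_sequence(
    [s.flip(dims=[0]) for s in seqs], batch_first=True, padding_value=0
).flip(dims=[1])
print(left_padded)  # tensor([[1, 2, 3], [0, 4, 5]])

# Constrained sampling: -inf logits get zero probability after softmax,
# so only the allowed tokens can be drawn.
vocab_size = 8
allowed_tokens = [2, 5]  # hypothetical boundary token ids
logits = torch.randn(1, vocab_size)
mask = torch.full((vocab_size,), -math.inf)
mask[allowed_tokens] = 0
probs = torch.softmax(logits + mask, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
assert next_token.item() in allowed_tokens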

     @torch.inference_mode()
     def __call__(
         self,
@@ -192,14 +290,11 @@ def __call__(
         The full sequence that contains the prompts and the generated string.
 
         """
-        token_ids, attention_mask = self.model.tokenizer.encode(prompt)
-
-        token_ids = token_ids.to(self.device)
-        attention_mask = attention_mask.to(self.device)
-
         if rng is None:
             rng = torch.Generator(device=self.device)
 
+        token_ids, attention_mask = self.align_prompt_tokens(prompt, rng)
+
         num_prompt_tokens = token_ids.shape[-1]
 
         if samples > 1:
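
Net effect on __call__: encoding and device placement now happen inside align_prompt_tokens, and alignment is applied unconditionally, transparently to callers. The generated text still begins with the full prompt, since the re-sampled token must spell out the truncated prompt suffix. A hypothetical usage, assuming the models.transformers and generate.continuation entry points of the library at the time:

import outlines.models as models
from outlines.text import generate

model = models.transformers("gpt2")
# Internally, "stackoverf" may be truncated to "stack" and the next token
# constrained to ones that start with "overf", e.g. "overflow".
sequence = generate.continuation(model, max_tokens=16)("stackoverf")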
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -91,7 +91,7 @@ module = [
     "scipy.*",
     "tenacity.*",
     "tiktoken.*",
-    "torch",
+    "torch.*",
     "transformers.*",
     "lark.*",
     "regex.*",