Skip to content

Commit

Permalink
add rejection sampling to CFGLogitsProcessor
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed Jul 26, 2024
1 parent 21d61d1 commit 19129e5
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 17 deletions.
8 changes: 6 additions & 2 deletions benchmarks/bench_cfg_guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,13 @@ def setup(self, grammar_name):
@staticmethod
def _run_random_cfg(guide):
    """Drive ``guide`` through 40 generation steps with random token orderings.

    Simulates sampling: the vocabulary's token ids are shuffled to mimic a
    logits ordering, the first guide-approved id is taken via rejection
    sampling, and the guide state is advanced with it.
    """
    state = guide.initial_state

    for _ in range(40):
        # simulate ordering of logits top prob to lowest prob
        # NOTE: vocabulary is a mapping (see its .items()/.values() uses in
        # guide.py), so range() needs its size, not the mapping itself.
        token_ids = list(range(len(guide.tokenizer.vocabulary)))
        random.shuffle(token_ids)
        # simulate sampling and state update
        next_token_id = next(guide.iter_valid_token_ids(state, token_ids))
        state = guide.get_next_state(state, next_token_id)

@cache_disabled()
Expand Down
49 changes: 39 additions & 10 deletions outlines/fsm/guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Any,
Callable,
Dict,
Generator,
List,
Optional,
Protocol,
Expand Down Expand Up @@ -351,26 +352,54 @@ def get_next_instruction(
if parser_state is None:
return Write(torch.tensor([self.eos_token_id]))

valid_tokens = []
for test_token, token_id in self.tokenizer.vocabulary.items():
valid_tokens = list(
self.iter_valid_token_ids(parser_state, self.tokenizer.vocabulary.values())
)
if len(valid_tokens) == 1:
return Write(torch.tensor(valid_tokens))
return Generate(torch.tensor(valid_tokens))

def iter_valid_token_ids(
    self, parser_state: Optional["PartialParserState"], token_ids: list
) -> Generator[int, None, None]:
    """
    Iterate over the given token_ids and yield those that are valid for the
    current parser state.

    Parameters
    ----------
    parser_state
        The current state of the parser, or None if generation is complete.
    token_ids
        The token ids to check for validity, in caller-chosen priority order.

    Yields
    ------
    int
        Valid token ids.
    """
    # A finished parse accepts only EOS.
    if parser_state is None:
        yield self.eos_token_id
        return

    for token_id in token_ids:
        if token_id == self.eos_token_id:
            # EOS is legal only where the grammar can terminate.
            if self.can_terminate_state(parser_state):
                yield token_id
        else:
            # Probe the token against a copy so the caller's parser state
            # is left untouched when the token is rejected.
            ps = copy.copy(parser_state)
            ls = ps.lexer.state
            token_str = self.tokenizer.convert_token_to_string(
                self.tokenizer.decode([token_id])[0]
            )
            if token_str == "":
                # Tokens that decode to the empty string (e.g. special
                # tokens) cannot advance the parse; skip them.
                continue
            ls.text += token_str
            try:
                self.parser.parse_from_state(ps, is_end=False)
                yield token_id
            except (EOFError, UnexpectedToken, UnexpectedCharacters, DedentError):
                pass

def get_next_state(
self, parser_state: Optional[PartialParserState], token_id: int
) -> Optional[PartialParserState]:
Expand Down
46 changes: 42 additions & 4 deletions outlines/processors/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
limitations under the License.
"""
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union

import torch
from pydantic import BaseModel
Expand All @@ -50,6 +50,11 @@ class GuideLogitsProcessor(OutlinesLogitsProcessor):
The `outlines.fsm.Guide` which is used to bias the logits.
"""

# Class-level annotations so subclasses (e.g. CFGLogitsProcessor) can narrow them.
tokenizer: "Tokenizer"  # tokenizer used to map tokens <-> ids
guide: Guide  # the outlines.fsm Guide that biases the logits
_guide_states: Dict[int, Any]  # cache: hash(tuple of generated ids) -> guide state
_seq_start_idx: Optional[int]  # prompt length; set lazily on first process_logits call

def __init__(self, tokenizer: "Tokenizer", guide: Guide):
    """A Guide-based logits processor.

    Parameters
    ----------
    tokenizer
        The tokenizer used to convert tokens to ids.
    guide
        The `outlines.fsm.Guide` which is used to bias the logits.
    """
    self.tokenizer = tokenizer
    self.guide = guide
    # State cache keyed by hash of the generated-token tuple, seeded with
    # the guide's initial state for the empty sequence.
    self._guide_states = {hash(tuple([])): self.guide.initial_state}
    # Unknown until the first process_logits call supplies the prompt length.
    self._seq_start_idx = None

def process_logits(
self, input_ids: List[List[int]], logits: torch.Tensor
Expand Down Expand Up @@ -181,6 +186,8 @@ class CFGLogitsProcessor(GuideLogitsProcessor):
The `outlines.fsm.CFGGuide. which is used to bias the logits.
"""

# Narrow the inherited annotation: this processor always wraps a CFGGuide.
guide: CFGGuide

def __init__(self, cfg_str: str, tokenizer: "Tokenizer"):
    """Compile the CFGGuide that drives the CFG-guided generation.

    Parameters
    ----------
    cfg_str
        The context-free grammar string used to constrain generation.
    tokenizer
        The tokenizer used to convert tokens to ids.
    """
    cfg_guide = CFGGuide(cfg_string=cfg_str, tokenizer=tokenizer)
    super().__init__(tokenizer=tokenizer, guide=cfg_guide)

def process_logits(
    self, input_ids: List[List[int]], logits: torch.Tensor
) -> torch.Tensor:
    """Same behavior as GuideLogitsProcessor, but uses rejection sampling.

    For each sequence, candidate tokens are tried in descending-logit order
    and the first id the guide accepts keeps its logit; every other position
    is masked to -inf.

    Parameters
    ----------
    input_ids
        The token ids of each sequence (prompt plus generated tokens).
    logits
        Logits tensor with one row per sequence.

    Returns
    -------
    torch.Tensor
        The masked logits.

    Raises
    ------
    RuntimeError
        If the guide accepts none of the candidate token ids for a sequence
        (previously an unguarded next() leaked a bare StopIteration here).
    """
    if self._seq_start_idx is None:
        # First call: everything in input_ids is prompt.
        self._seq_start_idx = len(input_ids[0])

    sequence_states: List = []  # vector of states corresponding to `input_ids`

    for seq_ids in input_ids:
        gen_ids = seq_ids[self._seq_start_idx :]
        curr_state_key = hash(tuple(gen_ids))

        if curr_state_key not in self._guide_states:
            # Advance the cached state of the prefix by the newest token.
            prev_state = self._guide_states[hash(tuple(gen_ids[:-1]))]
            curr_state = self.guide.get_next_state(prev_state, gen_ids[-1])
            self._guide_states[curr_state_key] = curr_state

        sequence_states.append(self._guide_states[curr_state_key])

    mask = torch.full_like(logits, -math.inf)
    for i, guide_state in enumerate(sequence_states):
        # Rejection sampling: highest-logit token the guide accepts.
        first_legal_token = next(
            self.guide.iter_valid_token_ids(
                guide_state, torch.argsort(logits[i], descending=True)
            ),
            None,
        )
        if first_legal_token is None:
            raise RuntimeError(
                "CFGLogitsProcessor found no valid token for the current state"
            )
        mask[i, [first_legal_token]] = logits[i, [first_legal_token]]

    return mask
2 changes: 2 additions & 0 deletions tests/fsm/test_cfg_guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,8 @@ def test_cfg_grammar_sample(request, sample_name, tokenizer_name, cleanup_lark_i

state = cfg_guide.initial_state
for i, token_id in enumerate(sample_token_ids):
if tokenizer.decode(token_id)[0] == "":
continue
next_instruction = cfg_guide.get_next_instruction(state)
if token_id not in next_instruction.tokens:
processed_str = tokenizer.decode([sample_token_ids[:i]])[0]
Expand Down
2 changes: 1 addition & 1 deletion tests/generate/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def sample_choices():
def sample_lark_grammar():
# from https://github.com/lark-parser/lark/blob/master/docs/grammar.md
return """
?start: (hello_world | number)
?start: hello_world "!" number
hello_world: ("hello" | "world") ~ 3
number: ("0".."9") ~ 5
thanks: "Thank"i " for testing!"
Expand Down

0 comments on commit 19129e5

Please sign in to comment.