Skip to content

Commit

Permalink
Add Completion generation model
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Jun 20, 2023
1 parent cda50ae commit d64980a
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 0 deletions.
1 change: 1 addition & 0 deletions outlines/text/sequences/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .completion import completion
45 changes: 45 additions & 0 deletions outlines/text/sequences/completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Optional

import numpy as np
from numpy.typing import NDArray

from outlines.text.sequences.sequence import Sequence


class Completion(Sequence):
"""Represents a completion generation model.
`Completion` instances are unconstrained generation models that stop when an EOS token
has been found or when the maximum number of tokens has been reached.
>> import outlines.text as text
>> sequence = text.sequence(model)("Say something")
"""

def __init__(self, model, max_tokens: Optional[int]):
super().__init__(model, max_tokens)

def is_finished(self, token_ids: NDArray[np.int64]) -> NDArray[np.bool_]:
"""Determine whether the sequences reached maximum length of end with
and EOS token.
In practice, `Sequence`'s `__call__` methods only passed the `token_ids`
of the sequences that haven't been marked as finished already, which is
why we only need to look for the EOS token in the last element rather
than in the whole sequence.
Parameters
----------
token_ids
The input sequences.
"""
is_finished = np.zeros((token_ids.shape[0],), dtype=np.bool_)
is_finished[token_ids[:, -1] == self.model.tokenizer.eos_token_id] = True

return is_finished


def completion(model, max_tokens: Optional[int] = None):
return Completion(model, max_tokens)
31 changes: 31 additions & 0 deletions tests/text/sequences/test_completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np
from numpy.testing import assert_array_equal

from outlines.text.sequences.completion import Completion


def test_completion_eos():
class Tokenizer:
eos_token_id = 0
pad_token_ids = -1

class Model:
tokenizer = Tokenizer()

model = Completion(Model(), 10)

token_ids = np.array([[3, 2]])
result = model.is_finished(token_ids)
assert_array_equal(result, [False])

token_ids = np.array([[3, 2, 0]])
result = model.is_finished(token_ids)
assert_array_equal(result, [True])

token_ids = np.array([[3, 2, 1], [3, 2, 0]])
result = model.is_finished(token_ids)
assert_array_equal(result, [False, True])

token_ids = np.array([[3, 2, 1, 0], [3, 2, 0, -1]])
result = model.is_finished(token_ids)
assert_array_equal(result, [True, False])

0 comments on commit d64980a

Please sign in to comment.