diff --git a/outlines/text/sequences/__init__.py b/outlines/text/sequences/__init__.py new file mode 100644 index 000000000..a9bbd59ca --- /dev/null +++ b/outlines/text/sequences/__init__.py @@ -0,0 +1 @@ +from .completion import completion diff --git a/outlines/text/sequences/completion.py b/outlines/text/sequences/completion.py new file mode 100644 index 000000000..7fd427af3 --- /dev/null +++ b/outlines/text/sequences/completion.py @@ -0,0 +1,45 @@ +from typing import Optional + +import numpy as np +from numpy.typing import NDArray + +from outlines.text.sequences.sequence import Sequence + + +class Completion(Sequence): + """Represents a completion generation model. + + `Completion` instances are unconstrained generation models that stop when an EOS token + has been found or when the maximum number of tokens has been reached. + + >> import outlines.text as text + >> sequence = text.sequence(model)("Say something") + + """ + + def __init__(self, model, max_tokens: Optional[int]): + super().__init__(model, max_tokens) + + def is_finished(self, token_ids: NDArray[np.int64]) -> NDArray[np.bool_]: + """Determine whether the sequences reached maximum length of end with + and EOS token. + + In practice, `Sequence`'s `__call__` methods only passed the `token_ids` + of the sequences that haven't been marked as finished already, which is + why we only need to look for the EOS token in the last element rather + than in the whole sequence. + + Parameters + ---------- + token_ids + The input sequences. + + """ + is_finished = np.zeros((token_ids.shape[0],), dtype=np.bool_) + is_finished[token_ids[:, -1] == self.model.tokenizer.eos_token_id] = True + + return is_finished + + +def completion(model, max_tokens: Optional[int] = None): + return Completion(model, max_tokens) diff --git a/tests/text/sequences/test_completion.py b/tests/text/sequences/test_completion.py new file mode 100644 index 000000000..6fe594418 --- /dev/null +++ b/tests/text/sequences/test_completion.py @@ -0,0 +1,31 @@ +import numpy as np +from numpy.testing import assert_array_equal + +from outlines.text.sequences.completion import Completion + + +def test_completion_eos(): + class Tokenizer: + eos_token_id = 0 + pad_token_ids = -1 + + class Model: + tokenizer = Tokenizer() + + model = Completion(Model(), 10) + + token_ids = np.array([[3, 2]]) + result = model.is_finished(token_ids) + assert_array_equal(result, [False]) + + token_ids = np.array([[3, 2, 0]]) + result = model.is_finished(token_ids) + assert_array_equal(result, [True]) + + token_ids = np.array([[3, 2, 1], [3, 2, 0]]) + result = model.is_finished(token_ids) + assert_array_equal(result, [False, True]) + + token_ids = np.array([[3, 2, 1, 0], [3, 2, 0, -1]]) + result = model.is_finished(token_ids) + assert_array_equal(result, [True, False])