diff --git a/outlines/text/sequences/__init__.py b/outlines/text/sequences/__init__.py new file mode 100644 index 000000000..a9bbd59ca --- /dev/null +++ b/outlines/text/sequences/__init__.py @@ -0,0 +1 @@ +from .completion import completion diff --git a/outlines/text/sequences/completion.py b/outlines/text/sequences/completion.py new file mode 100644 index 000000000..c105324dd --- /dev/null +++ b/outlines/text/sequences/completion.py @@ -0,0 +1,52 @@ +from typing import List, Optional + +import numpy as np +from numpy.typing import NDArray + +from outlines.text.sequences.sequence import Sequence + + +class Completion(Sequence): + """Represents a completion generation model. + + `Completion` instances are unconstrained generation models that stop when an EOS token + has been found or when the maximum number of tokens has been reached. + + >> import outlines.text as text + >> sequence = text.sequence(model)("Say something") + + """ + + def __init__(self, model, max_tokens: Optional[int]): + super().__init__(model, max_tokens) + + def is_finished(self, token_ids: NDArray[np.int64]) -> NDArray[np.bool_]: + """Determine whether the sequences reached maximum length or end with + an EOS token. + + In practice, `Sequence`'s `__call__` method only passes the `token_ids` + of the sequences that haven't been marked as finished already, which is + why we only need to look for the EOS token in the last element rather + than in the whole sequence. + + Parameters + ---------- + token_ids + The input sequences. 
+ + """ + is_finished = np.zeros((token_ids.shape[0],), dtype=np.bool_) + is_finished[token_ids[:, -1] == self.model.tokenizer.eos_token_id] = True + + return is_finished + + def postprocess_completions(self, completions: List[str]) -> List[str]: + """Remove the EOS token from the completion.""" + return [ + completion.replace(self.model.tokenizer.eos_token, "") + for completion in completions + ] + + +def completion(model, max_tokens: Optional[int] = None): + return Completion(model, max_tokens) diff --git a/outlines/text/sequences/sequence.py b/outlines/text/sequences/sequence.py index a5f975cc1..d42e97c00 100644 --- a/outlines/text/sequences/sequence.py +++ b/outlines/text/sequences/sequence.py @@ -29,6 +29,9 @@ def is_finished(self, token_ids: NDArray[np.int64]) -> NDArray[np.bool_]: "`Sequence.is_finished` must be implemented by subclasses." ) + def postprocess_completions(self, completions: List[str]) -> List[str]: + return completions + def step( self, rng: Generator, @@ -204,6 +207,7 @@ def __call__( is_finished[~is_finished] = self.is_finished(token_ids_unfinished).flatten() result = self.model.tokenizer.decode(token_ids) + result = self.postprocess_completions(result) if len(result) == 1: return result[0] diff --git a/tests/text/sequences/test_completion.py b/tests/text/sequences/test_completion.py new file mode 100644 index 000000000..56acd4d4f --- /dev/null +++ b/tests/text/sequences/test_completion.py @@ -0,0 +1,42 @@ +import numpy as np +from numpy.testing import assert_array_equal + +from outlines.text.sequences.completion import Completion, completion + + +class Tokenizer: + eos_token = "" + eos_token_id = 0 + pad_token_ids = -1 + + +class Model: + tokenizer = Tokenizer() + + +def test_completion_is_finished(): + model = completion(Model(), 10) + assert isinstance(model, Completion) + + token_ids = np.array([[3, 2]]) + result = model.is_finished(token_ids) + assert_array_equal(result, [False]) + + token_ids = np.array([[3, 2, 0]]) + result = 
model.is_finished(token_ids) + assert_array_equal(result, [True]) + + token_ids = np.array([[3, 2, 1], [3, 2, 0]]) + result = model.is_finished(token_ids) + assert_array_equal(result, [False, True]) + + token_ids = np.array([[3, 2, 1, 0], [3, 2, 0, -1]]) + result = model.is_finished(token_ids) + assert_array_equal(result, [True, False]) + + +def test_completion_postprocess(): + model = completion(Model()) + result = model.postprocess_completions(["Here"]) + assert len(result) == 1 + assert result[0] == "Here" diff --git a/tests/text/sequences/test_integration_transfomers.py b/tests/text/sequences/test_integration_transfomers.py new file mode 100644 index 000000000..b608fe0a8 --- /dev/null +++ b/tests/text/sequences/test_integration_transfomers.py @@ -0,0 +1,18 @@ +import numpy as np + +import outlines.models as models +from outlines.text.sequences.completion import completion + +TEST_MODEL = "hf-internal-testing/tiny-random-GPTJForCausalLM" + + +def test_transformers_integration_completion(): + rng = np.random.default_rng(0) + + model = models.transformers(TEST_MODEL, device="cpu") + sequence = completion(model)("prompt", rng=rng) + assert isinstance(sequence, str) + assert model.tokenizer.eos_token not in sequence + + sequence = completion(model, max_tokens=10)("prompt", rng=rng) + assert isinstance(sequence, str)