Add Transformers model and Completion sequence generation #139

Merged · 5 commits · Jun 29, 2023
1 change: 1 addition & 0 deletions outlines/models/__init__.py
@@ -9,3 +9,4 @@
from .hf_diffusers import HuggingFaceDiffuser
from .hf_transformers import HuggingFaceCompletion
from .openai import OpenAICompletion, OpenAIEmbeddings, OpenAIImageGeneration
from .transformers import transformers
23 changes: 23 additions & 0 deletions outlines/models/tokenizer.py
@@ -0,0 +1,23 @@
from abc import abstractmethod
from typing import List, Protocol, Tuple, Union

import numpy as np
from numpy.typing import NDArray


class Tokenizer(Protocol):
    eos_token: str
    eos_token_id: int
    pad_token_id: int

    @abstractmethod
    def encode(
        self, prompt: Union[str, List[str]]
    ) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
        """Translate the input prompts into NumPy arrays of token ids and attention mask."""
        ...

    @abstractmethod
    def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
        """Translate an array of token ids into a list of strings."""
        ...
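
For illustration only (not part of this diff), here is a minimal sketch of a class that satisfies the `Tokenizer` protocol. The `ToyTokenizer` name and its fixed whitespace vocabulary are made up for the example; real implementations such as `TransformersTokenizer` below wrap an actual tokenizer instead.

from typing import List, Tuple, Union

import numpy as np
from numpy.typing import NDArray


class ToyTokenizer:
    """Hypothetical tokenizer over a fixed whitespace-separated vocabulary."""

    eos_token = "<eos>"
    eos_token_id = 0
    pad_token_id = 0

    def __init__(self, vocab: List[str]):
        # Token id 0 is reserved for the EOS/pad token; real tokens start at 1.
        self.vocab = {word: i + 1 for i, word in enumerate(vocab)}
        self.inverse_vocab = {i: w for w, i in self.vocab.items()}
        self.inverse_vocab[0] = self.eos_token

    def encode(
        self, prompt: Union[str, List[str]]
    ) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
        # Assumes every word in the prompts is in the vocabulary.
        prompts = [prompt] if isinstance(prompt, str) else prompt
        ids = [[self.vocab[word] for word in p.split()] for p in prompts]
        max_len = max(len(i) for i in ids)
        # Left-pad with the pad token id, mirroring `padding_side="left"`.
        token_ids = np.array(
            [[self.pad_token_id] * (max_len - len(i)) + i for i in ids],
            dtype=np.int64,
        )
        attention_mask = (token_ids != self.pad_token_id).astype(np.int64)
        return token_ids, attention_mask

    def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
        return [
            " ".join(self.inverse_vocab[int(t)] for t in row) for row in token_ids
        ]


tok = ToyTokenizer(["say", "something"])
token_ids, attention_mask = tok.encode(["say something", "say"])
print(tok.decode(token_ids))  # ['say something', '<eos> say']
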
92 changes: 92 additions & 0 deletions outlines/models/transformers.py
@@ -0,0 +1,92 @@
import math
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import numpy as np
from numpy.typing import NDArray

from outlines.models.tokenizer import Tokenizer

if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer


__all__ = ["transformers"]


class Transformers:
"""Represents a `transformers` model."""

    def __init__(
        self,
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        device: Optional[str] = None,
    ):
        self.device = device if device is not None else "cpu"
        self.model = model.to(self.device)
        self.tokenizer = tokenizer

    def __call__(
        self, input_ids: NDArray[np.int64], attention_mask: NDArray[np.int64]
    ) -> NDArray[np.float64]:
        import torch

        # `transformers` models only accept `input_ids` arrays with at most two
        # dimensions. We thus flatten the batch dimensions of the input array,
        # call the model and reshape the output logits back to the original
        # batch shape.
        batch_shape = input_ids.shape[:-1]
        num_tokens = input_ids.shape[-1]
        input_ids = input_ids.reshape(math.prod(batch_shape), num_tokens)

        with torch.no_grad():
            input_ids = torch.from_numpy(input_ids).to(self.device)
            attention_mask = torch.from_numpy(attention_mask).to(self.device)

            output = self.model(input_ids, attention_mask=attention_mask)

        next_token_logits = output.logits[:, -1, :]
        probs = torch.nn.functional.softmax(next_token_logits, dim=-1).squeeze()
        probs = torch.atleast_2d(probs)
        numpy_probs = probs.cpu().detach().numpy()

        return numpy_probs.reshape(batch_shape + (-1,))


class TransformersTokenizer(Tokenizer):
"""Represents a tokenizer for models in the `transformers` library."""

def __init__(self, model_name: str, **kwargs):
from transformers import AutoTokenizer

kwargs.setdefault("padding_side", "left")
self.tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs)
self.eos_token_id = self.tokenizer.eos_token_id
self.eos_token = self.tokenizer.eos_token

if not self.tokenizer.pad_token_id:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self.pad_token_id = self.eos_token_id
else:
self.pad_token_id = self.tokenizer.pad_token_id
self.pad_token = self.tokenizer.pad_token

    def encode(
        self, prompt: Union[str, List[str]], **kwargs
    ) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
        kwargs["padding"] = True
        kwargs["return_tensors"] = "np"
        output = self.tokenizer(prompt, **kwargs)
        return output["input_ids"], output["attention_mask"]

    def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
        text = self.tokenizer.batch_decode(token_ids)
        return text


def transformers(model_name: str, device: Optional[str] = None, **model_kwargs):
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    tokenizer = TransformersTokenizer(model_name)

    return Transformers(model, tokenizer, device)
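
A rough usage sketch of the pieces above, assuming the `"gpt2"` checkpoint can be downloaded; the prompts and variable names are illustrative only.

import outlines.models as models

model = models.transformers("gpt2", device="cpu")

# Encode a batch of prompts into token ids and an attention mask ...
token_ids, attention_mask = model.tokenizer.encode(["Hello", "Hi there"])

# ... and query the model for the next-token probabilities. The returned
# array has shape `batch_shape + (vocab_size,)`.
probs = model(token_ids, attention_mask)
print(probs.shape)  # (2, 50257) for GPT-2
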
1 change: 1 addition & 0 deletions outlines/text/__init__.py
@@ -1,2 +1,3 @@
from .functions import function
from .generate import continuation
from .prompts import prompt, render
1 change: 1 addition & 0 deletions outlines/text/generate/__init__.py
@@ -0,0 +1 @@
from .continuation import continuation
52 changes: 52 additions & 0 deletions outlines/text/generate/continuation.py
@@ -0,0 +1,52 @@
from typing import List, Optional

import numpy as np
from numpy.typing import NDArray

from outlines.text.generate.sequence import Sequence


class Continuation(Sequence):
"""Represents a completion generation model.

`Completion` instances are unconstrained generation models that stop when an EOS token
has been found or when the maximum number of tokens has been reached.

>> import outlines.text as text
>> sequence = text.sequence(model)("Say something")

"""

    def __init__(self, model, max_tokens: Optional[int]):
        super().__init__(model, max_tokens)

    def is_finished(self, token_ids: NDArray[np.int64]) -> NDArray[np.bool_]:
        """Determine whether the sequences reached the maximum length or end
        with an EOS token.

        In practice, `Sequence`'s `__call__` method only passes the `token_ids`
        of the sequences that haven't been marked as finished already, which is
        why we only need to look for the EOS token in the last element rather
        than in the whole sequence.

        Parameters
        ----------
        token_ids
            The input sequences.

        """
        is_finished = np.zeros((token_ids.shape[0],), dtype=np.bool_)
        is_finished[token_ids[:, -1] == self.model.tokenizer.eos_token_id] = True

        return is_finished

    def postprocess_completions(self, completions: List[str]) -> List[str]:
        """Remove the EOS token from the completions."""
        return [
            completion.replace(self.model.tokenizer.eos_token, "")
            for completion in completions
        ]


def continuation(model, max_tokens: Optional[int] = None):
    return Continuation(model, max_tokens)
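
A short end-to-end sketch of the new generation path, assuming a small causal model such as `"gpt2"`. Calling the generator on a prompt relies on `Sequence.__call__`, which is not part of this diff, accepting a prompt string as the class docstring suggests; the prompt and `max_tokens` value are arbitrary.

import outlines.models as models
import outlines.text.generate as generate

model = models.transformers("gpt2")

# Build an unconstrained generator that stops at the EOS token or after
# at most 10 generated tokens, then sample a completion for one prompt.
generate_text = generate.continuation(model, max_tokens=10)
completion = generate_text("The quick brown fox")
print(completion)
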