Add OpenAIRV random variable
rlouf committed Jun 9, 2023
1 parent 890ad83 commit 137c64a
Showing 4 changed files with 256 additions and 0 deletions.
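
The new `openai` random variable exposes OpenAI's completion (`text-*`) and chat completion (`gpt-*`) endpoints as a token-level random variable, alongside the `transformers` random variable: it decodes the input token ids with `tiktoken`, queries the API, and returns the token ids of the sampled continuations. A minimal usage sketch, assuming `OPENAI_API_KEY` is set and using an illustrative prompt (the calls mirror what the tests added in this commit exercise):

import asyncio

import tiktoken

import outlines.text.random as random

# "text-*" models go through the completion endpoint, "gpt-*" models through chat.
rv = random.openai("text-davinci-003")
tokenizer = tiktoken.encoding_for_model("text-davinci-003")

# Encode the prompt to token ids, draw 3 samples, and get the sampled token ids back.
input_ids = tokenizer.encode_batch(["Once upon a time"])
token_ids = asyncio.run(rv(input_ids, samples=3))
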
2 changes: 2 additions & 0 deletions outlines/text/random/__init__.py
@@ -0,0 +1,2 @@
from .openai_rv import openai
from .transformers_rv import transformers
127 changes: 127 additions & 0 deletions outlines/text/random/openai_rv.py
@@ -0,0 +1,127 @@
import os
from typing import Callable, Dict, List, Union

import numpy as np


class OpenAIRV:
    """Represents a token random variable defined by an OpenAI model."""

    def __init__(self, model_name: str):
        import tiktoken

        self.model_name = model_name
        self.tokenizer = tiktoken.encoding_for_model(model_name)

        if "text-" in model_name:
            self.call_api = call_completion_api
            self.format_prompt = lambda x: x
            self.extract_choice = lambda x: x["text"]
        elif "gpt-" in model_name:
            self.call_api = call_chat_completion_api
            self.format_prompt = lambda x: {"role": "user", "content": x[0]}
            self.extract_choice = lambda x: x["message"]["content"]
        else:
            raise NameError(
                f"The model {model_name} requested is not available. Only the completion and chat completion models are available for OpenAI."
            )

    async def __call__(self, input_ids: Union[str, List[str]], samples: int = 1):
        prompt = self.tokenizer.decode_batch(input_ids)
        # Request a single token (max_tokens=1) at temperature 1.0, with no stop
        # sequences or logit bias, and `samples` completions per prompt.
        response = await self.call_api(
            self.model_name, self.format_prompt(prompt), 1, 1.0, [], {}, samples
        )

        results = [self.extract_choice(choice) for choice in response["choices"]]
        token_ids = np.array(self.tokenizer.encode_batch(results)).reshape(
            len(input_ids), samples, -1
        )

        return token_ids.squeeze()


def openai(model_name: str):
    return OpenAIRV(model_name)


def error_handler(api_call_fn: Callable) -> Callable:
    """Handle OpenAI API errors and missing API key."""

    async def call(*args, **kwargs):
        # `openai` is needed here for the exception classes referenced below.
        import openai

        try:
            os.environ["OPENAI_API_KEY"]
        except KeyError:
            raise OSError(
                "Could not find the `OPENAI_API_KEY` environment variable, which is necessary to call "
                "OpenAI's APIs. Please make sure it is set before re-running your model."
            )

        try:
            # Await the wrapped coroutine so API errors are raised (and caught) here.
            return await api_call_fn(*args, **kwargs)
        except (
            openai.error.RateLimitError,
            openai.error.Timeout,
            openai.error.TryAgain,
            openai.error.APIConnectionError,
            openai.error.ServiceUnavailableError,
        ) as e:
            raise OSError(f"Could not connect to the OpenAI API: {e}")
        except (
            openai.error.AuthenticationError,
            openai.error.PermissionError,
            openai.error.InvalidRequestError,
            openai.error.InvalidAPIType,
        ) as e:
            raise e

    return call


@error_handler
async def call_completion_api(
    model: str,
    prompt: str,
    max_tokens: int,
    temperature: float,
    stop_sequences: List[str],
    logit_bias: Dict[str, int],
    num_samples: int,
):
    import openai

    response = await openai.Completion.acreate(
        engine=model,
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        stop=list(stop_sequences) if len(stop_sequences) > 0 else None,
        logit_bias=logit_bias,
        n=int(num_samples),
    )

    return response


@error_handler
async def call_chat_completion_api(
    model: str,
    messages: List[Dict[str, str]],
    max_tokens: int,
    temperature: float,
    stop_sequences: List[str],
    logit_bias: Dict[str, int],
    num_samples: int,
):
    import openai

    response = await openai.ChatCompletion.acreate(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
        stop=list(stop_sequences) if len(stop_sequences) > 0 else None,
        logit_bias=logit_bias,
        n=int(num_samples),
    )

    return response
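
A note on configuration, following from `error_handler` above: the decorated API calls read `OPENAI_API_KEY` from the environment and raise an `OSError` when it is missing, so the key must be set before the random variable is called. A minimal sketch with a placeholder value:

import os

# Placeholder value for illustration; in practice the key is usually exported in the shell.
os.environ["OPENAI_API_KEY"] = "<your-api-key>"
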
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -38,6 +38,7 @@ test = [
"diffusers",
"pre-commit",
"pytest",
"tiktoken",
"torch",
"transformers"
]
126 changes: 126 additions & 0 deletions tests/text/random/test_openai.py
@@ -0,0 +1,126 @@
import asyncio
import itertools

import pytest
import tiktoken
from numpy.testing import assert_array_equal

import outlines.text.random as random
from outlines.text.random.openai_rv import OpenAIRV


async def mock_completion_api_call(
    model, prompts, max_tokens, temperature, stop_sequence, logit_bias, num_samples
):
    """Mock completion API call.

    The returned dictionary was copied from the OpenAI API reference
    at https://platform.openai.com/docs/api-reference/completions/create
    on 06/09/2023.
    """
    choices = [
        {"text": f"{p}{s}", "index": 0, "logprobs": None, "finish_reason": "length"}
        for p, s in itertools.product(prompts, range(num_samples))
    ]
    return {
        "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
        "object": "text_completion",
        "created": 1589478378,
        "model": "text-davinci-003",
        "choices": choices,
        "usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12},
    }


async def mock_chat_completion_api_call(
    model, message, max_tokens, temperature, stop_sequence, logit_bias, num_samples
):
    """Mock chat completion API call.

    The returned dictionary was copied from the OpenAI API reference
    at https://platform.openai.com/docs/api-reference/chat/create
    on 06/09/2023.
    """
    prompt = message["content"]
    choices = [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": f"{p}{s}",
            },
            "finish_reason": "stop",
        }
        for p, s in itertools.product(prompt, range(num_samples))
    ]
    return {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677652288,
        "choices": choices,
        "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
    }


def test_wrong_name():
    # `tiktoken.encoding_for_model` raises a KeyError for model names it cannot
    # map to an encoding, which happens before the model-family check.
    with pytest.raises(KeyError):
        OpenAIRV("davinci-003")


def test_completion():
    rv = random.openai("text-davinci-003")
    assert isinstance(rv, OpenAIRV)

    tokenizer = tiktoken.encoding_for_model("text-davinci-003")
    rv.call_api = mock_completion_api_call

    prompt = "A"
    input_ids = tokenizer.encode_batch(prompt)
    result = asyncio.run(rv(input_ids))
    assert result.ndim == 1
    assert_array_equal(result, tokenizer.encode("A0"))

    result = asyncio.run(rv(input_ids, samples=3))
    assert result.shape[0] == 3
    assert_array_equal(result, tokenizer.encode_batch(["A0", "A1", "A2"]))


def test_completion_list():
    rv = random.openai("text-davinci-003")
    assert isinstance(rv, OpenAIRV)

    tokenizer = tiktoken.encoding_for_model("text-davinci-003")
    rv.call_api = mock_completion_api_call

    prompts = ["A", "B"]
    input_ids = tokenizer.encode_batch(prompts)

    result = asyncio.run(rv(input_ids))
    assert result.shape[0] == 2
    assert_array_equal(result.reshape(2, -1), tokenizer.encode_batch(["A0", "B0"]))

    result = asyncio.run(rv(input_ids, samples=3))
    assert result.shape[0] == 2
    assert result.shape[1] == 3
    assert_array_equal(
        result.reshape(6, -1),
        tokenizer.encode_batch(["A0", "A1", "A2", "B0", "B1", "B2"]),
    )


def test_chat_completion():
    rv = random.openai("gpt-3.5-turbo")
    assert isinstance(rv, OpenAIRV)

    tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
    rv.call_api = mock_chat_completion_api_call

    prompt = "A"
    input_ids = tokenizer.encode_batch(prompt)
    result = asyncio.run(rv(input_ids))
    assert_array_equal(result, tokenizer.encode("A0"))

    result = asyncio.run(rv(input_ids, samples=3))
    assert_array_equal(result, tokenizer.encode_batch(["A0", "A1", "A2"]))
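
Because these tests swap `call_api` for the mocks above, they exercise the prompt formatting, tokenization, and reshaping logic without contacting the OpenAI API (no API key is required, though `tiktoken` must be installed). They can be run with, for example:

pytest tests/text/random/test_openai.py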
