From 137c64a407878f1c3b96b97581c07fc65fdda68b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 9 Jun 2023 11:02:46 +0200 Subject: [PATCH] Add `OpenAIRV` random variable --- outlines/text/random/__init__.py | 2 + outlines/text/random/openai_rv.py | 127 ++++++++++++++++++++++++++++++ pyproject.toml | 1 + tests/text/random/test_openai.py | 126 +++++++++++++++++++++++++++++ 4 files changed, 256 insertions(+) create mode 100644 outlines/text/random/__init__.py create mode 100644 outlines/text/random/openai_rv.py create mode 100644 tests/text/random/test_openai.py diff --git a/outlines/text/random/__init__.py b/outlines/text/random/__init__.py new file mode 100644 index 000000000..c770b8637 --- /dev/null +++ b/outlines/text/random/__init__.py @@ -0,0 +1,2 @@ +from .openai_rv import openai +from .transformers_rv import transformers diff --git a/outlines/text/random/openai_rv.py b/outlines/text/random/openai_rv.py new file mode 100644 index 000000000..23a9f943c --- /dev/null +++ b/outlines/text/random/openai_rv.py @@ -0,0 +1,127 @@ +import os +from typing import Callable, Dict, List, Union + +import numpy as np + + +class OpenAIRV: + """Represents a token random variable defined by an OpenAI model.""" + + def __init__(self, model_name: str): + import tiktoken + + self.model_name = model_name + self.tokenizer = tiktoken.encoding_for_model(model_name) + + if "text-" in model_name: + self.call_api = call_completion_api + self.format_prompt = lambda x: x + self.extract_choice = lambda x: x["text"] + elif "gpt-" in model_name: + self.call_api = call_chat_completion_api + self.format_prompt = lambda x: {"role": "user", "content": x[0]} + self.extract_choice = lambda x: x["message"]["content"] + else: + raise NameError( + f"The model {model_name} requested is not available. Only the completion and chat completion models are available for OpenAI." + ) + + async def __call__(self, input_ids: Union[str, List[str]], samples: int = 1): + prompt = self.tokenizer.decode_batch(input_ids) + response = await self.call_api( + self.model_name, self.format_prompt(prompt), 1, 1.0, [], {}, samples + ) + + results = [self.extract_choice(choice) for choice in response["choices"]] + token_ids = np.array(self.tokenizer.encode_batch(results)).reshape( + len(input_ids), samples, -1 + ) + + return token_ids.squeeze() + + +def openai(model_name: str): + return OpenAIRV(model_name) + + +def error_handler(api_call_fn: Callable) -> Callable: + """Handle OpenAI API errors and missing API key.""" + + def call(*args, **kwargs): + try: + os.environ["OPENAI_API_KEY"] + except KeyError: + raise OSError( + "Could not find the `OPENAI_API_KEY` environment variable, which is necessary to call " + "OpenAI's APIs. Please make sure it is set before re-running your model." + ) + + try: + return api_call_fn(*args, **kwargs) + except ( + openai.error.RateLimitError, + openai.error.Timeout, + openai.error.TryAgain, + openai.error.APIConnectionError, + openai.error.ServiceUnavailableError, + ) as e: + raise OSError(f"Could not connect to the OpenAI API: {e}") + except ( + openai.error.AuthenticationError, + openai.error.PermissionError, + openai.error.InvalidRequestError, + openai.error.InvalidAPIType, + ) as e: + raise e + + return call + + +@error_handler +async def call_completion_api( + model: str, + prompt: str, + max_tokens: int, + temperature: float, + stop_sequences: List[str], + logit_bias: Dict[str, int], + num_samples: int, +): + import openai + + response = await openai.Completion.acreate( + engine=model, + prompt=prompt, + temperature=temperature, + max_tokens=max_tokens, + stop=list(stop_sequences) if len(stop_sequences) > 0 else None, + logit_bias=logit_bias, + n=int(num_samples), + ) + + return response + + +@error_handler +async def call_chat_completion_api( + model: str, + messages: List[Dict[str, str]], + max_tokens: int, + temperature: float, + stop_sequences: List[str], + logit_bias: Dict[str, int], + num_samples: int, +): + import openai + + response = await openai.ChatCompletion.acreate( + model=model, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + stop=list(stop_sequences) if len(stop_sequences) > 0 else None, + logit_bias=logit_bias, + n=int(num_samples), + ) + + return response diff --git a/pyproject.toml b/pyproject.toml index 745c3e499..6b68305ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ test = [ "diffusers", "pre-commit", "pytest", + "tiktoken", "torch", "transformers" ] diff --git a/tests/text/random/test_openai.py b/tests/text/random/test_openai.py new file mode 100644 index 000000000..487b0d033 --- /dev/null +++ b/tests/text/random/test_openai.py @@ -0,0 +1,126 @@ +import asyncio +import itertools + +import pytest +import tiktoken +from numpy.testing import assert_array_equal + +import outlines.text.random as random +from outlines.text.random.openai_rv import OpenAIRV + + +async def mock_completion_api_call( + model, prompts, max_tokens, temperature, stop_sequence, logit_bias, num_samples +): + """Mock completion API call. + + The returned dictionary was copied from the OpenAI API reference + at https://platform.openai.com/docs/api-reference/completions/create + on 06/09/2023. + + """ + choices = [ + {"text": f"{p}{s}", "index": 0, "logprobs": None, "finish_reason": "length"} + for p, s in itertools.product(prompts, range(num_samples)) + ] + return { + "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7", + "object": "text_completion", + "created": 1589478378, + "model": "text-davinci-003", + "choices": choices, + "usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12}, + } + + +async def mock_chat_completion_api_call( + model, message, max_tokens, temperature, stop_sequence, logit_bias, num_samples +): + """Mock completion API call. + + The returned dictionary was copied from the OpenAI API reference + at https://platform.openai.com/docs/api-reference/completions/create + on 06/09/2023. + + """ + prompt = message["content"] + choices = [ + { + "index": 0, + "message": { + "role": "assistant", + "content": f"{p}{s}", + }, + "finish_reason": "stop", + } + for p, s in itertools.product(prompt, range(num_samples)) + ] + return { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "choices": choices, + "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21}, + } + + +def test_wrong_name(): + with pytest.raises(KeyError): + OpenAIRV("davinci-003") + + +def test_completion(): + rv = random.openai("text-davinci-003") + assert isinstance(rv, OpenAIRV) + + tokenizer = tiktoken.encoding_for_model("text-davinci-003") + rv.call_api = mock_completion_api_call + + prompt = "A" + input_ids = tokenizer.encode_batch(prompt) + result = asyncio.run(rv(input_ids)) + assert result.ndim == 1 + assert_array_equal(result, tokenizer.encode("A0")) + + result = asyncio.run(rv(input_ids, samples=3)) + assert result.shape[0] == 3 + assert_array_equal(result, tokenizer.encode_batch(["A0", "A1", "A2"])) + + +def test_completion_list(): + rv = random.openai("text-davinci-003") + assert isinstance(rv, OpenAIRV) + + tokenizer = tiktoken.encoding_for_model("text-davinci-003") + rv.call_api = mock_completion_api_call + + prompts = ["A", "B"] + input_ids = tokenizer.encode_batch(prompts) + + result = asyncio.run(rv(input_ids)) + assert result.shape[0] == 2 + assert_array_equal(result.reshape(2, -1), tokenizer.encode_batch(["A0", "B0"])) + + result = asyncio.run(rv(input_ids, samples=3)) + assert result.shape[0] == 2 + assert result.shape[1] == 3 + assert_array_equal( + result.reshape(6, -1), + tokenizer.encode_batch(["A0", "A1", "A2", "B0", "B1", "B2"]), + ) + + +def test_chat_completion(): + rv = random.openai("gpt-3.5-turbo") + assert isinstance(rv, OpenAIRV) + + tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo") + rv.call_api = mock_chat_completion_api_call + + prompt = "A" + input_ids = tokenizer.encode_batch(prompt) + result = asyncio.run(rv(input_ids)) + assert_array_equal(result, tokenizer.encode("A0")) + + result = asyncio.run(rv(input_ids, samples=3)) + assert_array_equal(result, tokenizer.encode_batch(["A0", "A1", "A2"]))