From 0abf64137c18b45925d5015bae80429adb46fac6 Mon Sep 17 00:00:00 2001 From: Andrew Chen Wang <60190294+Andrew-Chen-Wang@users.noreply.github.com> Date: Wed, 4 Jan 2023 19:39:57 -0500 Subject: [PATCH] Add async support (#146) * Add async support * Fix aiohttp requests * Fix some syntax errors * Close aiohttp session properly * This is due to a lack of an async __del__ method * Fix code per review * Fix async tests and some mypy errors * Run black * Add todo for multipart form generation * Fix more mypy * Fix exception type * Don't yield twice Co-authored-by: Damien Deville --- README.md | 26 ++ openai/__init__.py | 11 +- openai/api_requestor.py | 255 ++++++++++++++++-- openai/api_resources/abstract/api_resource.py | 43 +++ .../abstract/createable_api_resource.py | 55 +++- .../abstract/deletable_api_resource.py | 16 +- .../abstract/engine_api_resource.py | 121 ++++++++- .../abstract/listable_api_resource.py | 52 +++- .../abstract/nested_resource_class_methods.py | 58 +++- .../abstract/updateable_api_resource.py | 6 + openai/api_resources/answer.py | 5 + openai/api_resources/classification.py | 5 + openai/api_resources/completion.py | 20 ++ openai/api_resources/customer.py | 5 + openai/api_resources/deployment.py | 49 +++- openai/api_resources/edit.py | 24 ++ openai/api_resources/embedding.py | 39 +++ openai/api_resources/engine.py | 25 ++ openai/api_resources/file.py | 157 +++++++++-- openai/api_resources/fine_tune.py | 71 ++++- openai/api_resources/image.py | 111 +++++++- openai/api_resources/moderation.py | 22 +- openai/api_resources/search.py | 21 ++ openai/embeddings_utils.py | 27 ++ openai/openai_object.py | 51 ++++ openai/tests/__init__.py | 0 openai/tests/asyncio/__init__.py | 0 openai/tests/asyncio/test_endpoints.py | 65 +++++ openai/tests/test_long_examples_validator.py | 19 +- setup.py | 3 +- 30 files changed, 1288 insertions(+), 74 deletions(-) create mode 100644 openai/tests/__init__.py create mode 100644 openai/tests/asyncio/__init__.py create mode 
100644 openai/tests/asyncio/test_endpoints.py diff --git a/README.md b/README.md index 5bd40a7919..53bab3ab2a 100644 --- a/README.md +++ b/README.md @@ -211,6 +211,32 @@ image_resp = openai.Image.create(prompt="two dogs playing chess, oil painting", ``` +## Async API + +Async support is available in the API by prepending `a` to a network-bound method: + +```python +import openai +openai.api_key = "sk-..." # supply your API key however you choose + +async def create_completion(): + completion_resp = await openai.Completion.acreate(prompt="This is a test", engine="davinci") + +``` + +To make async requests more efficient, you can pass in your own +``aiohttp.ClientSession``, but you must manually close the client session at the end +of your program/event loop: + +```python +import openai +from aiohttp import ClientSession + +openai.aiosession.set(ClientSession()) +# At the end of your program, close the http session +await openai.aiosession.get().close() +``` + See the [usage guide](https://beta.openai.com/docs/guides/images) for more details. ## Requirements diff --git a/openai/__init__.py b/openai/__init__.py index d935ea8ca5..ef6da5ba58 100644 --- a/openai/__init__.py +++ b/openai/__init__.py @@ -3,7 +3,8 @@ # Originally forked from the MIT-licensed Stripe Python bindings. import os -from typing import Optional +from contextvars import ContextVar +from typing import Optional, TYPE_CHECKING from openai.api_resources import ( Answer, @@ -24,6 +25,9 @@ ) from openai.error import APIError, InvalidRequestError, OpenAIError +if TYPE_CHECKING: + from aiohttp import ClientSession + api_key = os.environ.get("OPENAI_API_KEY") # Path of a file with an API key, whose contents can change. Supercedes # `api_key` if set. 
The main use case is volume-mounted Kubernetes secrets, @@ -44,6 +48,11 @@ debug = False log = None # Set to either 'debug' or 'info', controls console logging +aiosession: ContextVar[Optional["ClientSession"]] = ContextVar( + "aiohttp-session", default=None +) # Acts as a global aiohttp ClientSession that reuses connections. +# This is user-supplied; otherwise, a session is remade for each request. + __all__ = [ "APIError", "Answer", diff --git a/openai/api_requestor.py b/openai/api_requestor.py index 2ae0cbe034..b10730216d 100644 --- a/openai/api_requestor.py +++ b/openai/api_requestor.py @@ -1,12 +1,14 @@ +import asyncio import json import platform import sys import threading import warnings from json import JSONDecodeError -from typing import Dict, Iterator, Optional, Tuple, Union, overload +from typing import AsyncGenerator, Dict, Iterator, Optional, Tuple, Union, overload from urllib.parse import urlencode, urlsplit, urlunsplit +import aiohttp import requests if sys.version_info >= (3, 8): @@ -49,6 +51,20 @@ def _requests_proxies_arg(proxy) -> Optional[Dict[str, str]]: ) +def _aiohttp_proxies_arg(proxy) -> Optional[str]: + """Returns a value suitable for the 'proxies' argument to 'aiohttp.ClientSession.request.""" + if proxy is None: + return None + elif isinstance(proxy, str): + return proxy + elif isinstance(proxy, dict): + return proxy["https"] if "https" in proxy else proxy["http"] + else: + raise ValueError( + "'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys." 
+ ) + + def _make_session() -> requests.Session: if not openai.verify_ssl_certs: warnings.warn("verify_ssl_certs is ignored; openai always verifies.") @@ -63,18 +79,32 @@ def _make_session() -> requests.Session: return s +def parse_stream_helper(line): + if line: + if line == b"data: [DONE]": + # return here will cause GeneratorExit exception in urllib3 + # and it will close http connection with TCP Reset + return None + if hasattr(line, "decode"): + line = line.decode("utf-8") + if line.startswith("data: "): + line = line[len("data: ") :] + return line + return None + + def parse_stream(rbody): for line in rbody: - if line: - if line == b"data: [DONE]": - # return here will cause GeneratorExit exception in urllib3 - # and it will close http connection with TCP Reset - continue - if hasattr(line, "decode"): - line = line.decode("utf-8") - if line.startswith("data: "): - line = line[len("data: ") :] - yield line + _line = parse_stream_helper(line) + if _line is not None: + yield _line + + +async def parse_stream_async(rbody: aiohttp.StreamReader): + async for line in rbody: + _line = parse_stream_helper(line) + if _line is not None: + yield _line class APIRequestor: @@ -186,6 +216,86 @@ def request( resp, got_stream = self._interpret_response(result, stream) return resp, got_stream, self.api_key + @overload + async def arequest( + self, + method, + url, + params, + headers, + files, + stream: Literal[True], + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]: + pass + + @overload + async def arequest( + self, + method, + url, + params=..., + headers=..., + files=..., + *, + stream: Literal[True], + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]: + pass + + @overload + async def arequest( + self, + method, + url, + params=..., + 
headers=..., + files=..., + stream: Literal[False] = ..., + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[OpenAIResponse, bool, str]: + pass + + @overload + async def arequest( + self, + method, + url, + params=..., + headers=..., + files=..., + stream: bool = ..., + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]: + pass + + async def arequest( + self, + method, + url, + params=None, + headers=None, + files=None, + stream: bool = False, + request_id: Optional[str] = None, + request_timeout: Optional[Union[float, Tuple[float, float]]] = None, + ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]: + result = await self.arequest_raw( + method.lower(), + url, + params=params, + supplied_headers=headers, + files=files, + request_id=request_id, + request_timeout=request_timeout, + ) + resp, got_stream = await self._interpret_async_response(result, stream) + return resp, got_stream, self.api_key + def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False): try: error_data = resp["error"] @@ -315,18 +425,15 @@ def _validate_headers( return headers - def request_raw( + def _prepare_request_raw( self, - method, url, - *, - params=None, - supplied_headers: Dict[str, str] = None, - files=None, - stream: bool = False, - request_id: Optional[str] = None, - request_timeout: Optional[Union[float, Tuple[float, float]]] = None, - ) -> requests.Response: + supplied_headers, + method, + params, + files, + request_id: Optional[str], + ) -> Tuple[str, Dict[str, str], Optional[bytes]]: abs_url = "%s%s" % (self.api_base, url) headers = self._validate_headers(supplied_headers) @@ -355,6 +462,24 @@ def request_raw( util.log_info("Request to OpenAI API", method=method, path=abs_url) util.log_debug("Post details", data=data, 
api_version=self.api_version) + return abs_url, headers, data + + def request_raw( + self, + method, + url, + *, + params=None, + supplied_headers: Optional[Dict[str, str]] = None, + files=None, + stream: bool = False, + request_id: Optional[str] = None, + request_timeout: Optional[Union[float, Tuple[float, float]]] = None, + ) -> requests.Response: + abs_url, headers, data = self._prepare_request_raw( + url, supplied_headers, method, params, files, request_id + ) + if not hasattr(_thread_context, "session"): _thread_context.session = _make_session() try: @@ -385,6 +510,71 @@ def request_raw( ) return result + async def arequest_raw( + self, + method, + url, + *, + params=None, + supplied_headers: Optional[Dict[str, str]] = None, + files=None, + request_id: Optional[str] = None, + request_timeout: Optional[Union[float, Tuple[float, float]]] = None, + ) -> aiohttp.ClientResponse: + abs_url, headers, data = self._prepare_request_raw( + url, supplied_headers, method, params, files, request_id + ) + + if isinstance(request_timeout, tuple): + timeout = aiohttp.ClientTimeout( + connect=request_timeout[0], + total=request_timeout[1], + ) + else: + timeout = aiohttp.ClientTimeout( + total=request_timeout if request_timeout else TIMEOUT_SECS + ) + user_set_session = openai.aiosession.get() + + if files: + # TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here. + # For now we use the private `requests` method that is known to have worked so far. 
+ data, content_type = requests.models.RequestEncodingMixin._encode_files( # type: ignore + files, data + ) + headers["Content-Type"] = content_type + request_kwargs = { + "method": method, + "url": abs_url, + "headers": headers, + "data": data, + "proxy": _aiohttp_proxies_arg(openai.proxy), + "timeout": timeout, + } + try: + if user_set_session: + result = await user_set_session.request(**request_kwargs) + else: + async with aiohttp.ClientSession() as session: + result = await session.request(**request_kwargs) + util.log_info( + "OpenAI API response", + path=abs_url, + response_code=result.status, + processing_ms=result.headers.get("OpenAI-Processing-Ms"), + request_id=result.headers.get("X-Request-Id"), + ) + # Don't read the whole stream for debug logging unless necessary. + if openai.log == "debug": + util.log_debug( + "API response body", body=result.content, headers=result.headers + ) + return result + except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e: + raise error.Timeout("Request timed out") from e + except aiohttp.ClientError as e: + raise error.APIConnectionError("Error communicating with OpenAI") from e + def _interpret_response( self, result: requests.Response, stream: bool ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]: @@ -404,6 +594,29 @@ def _interpret_response( False, ) + async def _interpret_async_response( + self, result: aiohttp.ClientResponse, stream: bool + ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]: + """Returns the response(s) and a bool indicating whether it is a stream.""" + if stream and "text/event-stream" in result.headers.get("Content-Type", ""): + return ( + self._interpret_response_line( + line, result.status, result.headers, stream=True + ) + async for line in parse_stream_async(result.content) + ), True + else: + try: + await result.read() + except aiohttp.ClientError as e: + util.log_warn(e, body=result.content) + return ( + self._interpret_response_line( + await 
result.read(), result.status, result.headers, stream=False + ), + False, + ) + def _interpret_response_line( self, rbody, rcode, rheaders, stream: bool ) -> OpenAIResponse: diff --git a/openai/api_resources/abstract/api_resource.py b/openai/api_resources/abstract/api_resource.py index aa7cfe88e1..53a7dec799 100644 --- a/openai/api_resources/abstract/api_resource.py +++ b/openai/api_resources/abstract/api_resource.py @@ -20,6 +20,13 @@ def retrieve( instance.refresh(request_id=request_id, request_timeout=request_timeout) return instance + @classmethod + def aretrieve( + cls, id, api_key=None, request_id=None, request_timeout=None, **params + ): + instance = cls(id, api_key, **params) + return instance.arefresh(request_id=request_id, request_timeout=request_timeout) + def refresh(self, request_id=None, request_timeout=None): self.refresh_from( self.request( @@ -31,6 +38,17 @@ def refresh(self, request_id=None, request_timeout=None): ) return self + async def arefresh(self, request_id=None, request_timeout=None): + self.refresh_from( + await self.arequest( + "get", + self.instance_url(operation="refresh"), + request_id=request_id, + request_timeout=request_timeout, + ) + ) + return self + @classmethod def class_url(cls): if cls == APIResource: @@ -116,6 +134,31 @@ def _static_request( response, api_key, api_version, organization ) + @classmethod + async def _astatic_request( + cls, + method_, + url_, + api_key=None, + api_base=None, + api_type=None, + request_id=None, + api_version=None, + organization=None, + **params, + ): + requestor = api_requestor.APIRequestor( + api_key, + api_version=api_version, + organization=organization, + api_base=api_base, + api_type=api_type, + ) + response, _, api_key = await requestor.arequest( + method_, url_, params, request_id=request_id + ) + return response + @classmethod def _get_api_type_and_version( cls, api_type: Optional[str] = None, api_version: Optional[str] = None diff --git 
a/openai/api_resources/abstract/createable_api_resource.py b/openai/api_resources/abstract/createable_api_resource.py index 39d3e4f504..1361c02627 100644 --- a/openai/api_resources/abstract/createable_api_resource.py +++ b/openai/api_resources/abstract/createable_api_resource.py @@ -7,15 +7,13 @@ class CreateableAPIResource(APIResource): plain_old_data = False @classmethod - def create( + def __prepare_create_requestor( cls, api_key=None, api_base=None, api_type=None, - request_id=None, api_version=None, organization=None, - **params, ): requestor = api_requestor.APIRequestor( api_key, @@ -35,6 +33,26 @@ def create( url = cls.class_url() else: raise error.InvalidAPIType("Unsupported API type %s" % api_type) + return requestor, url + + @classmethod + def create( + cls, + api_key=None, + api_base=None, + api_type=None, + request_id=None, + api_version=None, + organization=None, + **params, + ): + requestor, url = cls.__prepare_create_requestor( + api_key, + api_base, + api_type, + api_version, + organization, + ) response, _, api_key = requestor.request( "post", url, params, request_id=request_id @@ -47,3 +65,34 @@ def create( organization, plain_old_data=cls.plain_old_data, ) + + @classmethod + async def acreate( + cls, + api_key=None, + api_base=None, + api_type=None, + request_id=None, + api_version=None, + organization=None, + **params, + ): + requestor, url = cls.__prepare_create_requestor( + api_key, + api_base, + api_type, + api_version, + organization, + ) + + response, _, api_key = await requestor.arequest( + "post", url, params, request_id=request_id + ) + + return util.convert_to_openai_object( + response, + api_key, + api_version, + organization, + plain_old_data=cls.plain_old_data, + ) diff --git a/openai/api_resources/abstract/deletable_api_resource.py b/openai/api_resources/abstract/deletable_api_resource.py index 220375ca2f..a800ceb812 100644 --- a/openai/api_resources/abstract/deletable_api_resource.py +++ 
b/openai/api_resources/abstract/deletable_api_resource.py @@ -1,4 +1,5 @@ from urllib.parse import quote_plus +from typing import Awaitable from openai import error from openai.api_resources.abstract.api_resource import APIResource @@ -7,7 +8,7 @@ class DeletableAPIResource(APIResource): @classmethod - def delete(cls, sid, api_type=None, api_version=None, **params): + def __prepare_delete(cls, sid, api_type=None, api_version=None): if isinstance(cls, APIResource): raise ValueError(".delete may only be called as a class method now.") @@ -28,7 +29,20 @@ def delete(cls, sid, api_type=None, api_version=None, **params): url = "%s/%s" % (base, extn) else: raise error.InvalidAPIType("Unsupported API type %s" % api_type) + return url + + @classmethod + def delete(cls, sid, api_type=None, api_version=None, **params): + url = cls.__prepare_delete(sid, api_type, api_version) return cls._static_request( "delete", url, api_type=api_type, api_version=api_version, **params ) + + @classmethod + def adelete(cls, sid, api_type=None, api_version=None, **params) -> Awaitable: + url = cls.__prepare_delete(sid, api_type, api_version) + + return cls._astatic_request( + "delete", url, api_type=api_type, api_version=api_version, **params + ) diff --git a/openai/api_resources/abstract/engine_api_resource.py b/openai/api_resources/abstract/engine_api_resource.py index 152313c202..d6fe0d39a9 100644 --- a/openai/api_resources/abstract/engine_api_resource.py +++ b/openai/api_resources/abstract/engine_api_resource.py @@ -61,12 +61,11 @@ def class_url( raise error.InvalidAPIType("Unsupported API type %s" % api_type) @classmethod - def create( + def __prepare_create_request( cls, api_key=None, api_base=None, api_type=None, - request_id=None, api_version=None, organization=None, **params, @@ -112,6 +111,45 @@ def create( organization=organization, ) url = cls.class_url(engine, api_type, api_version) + return ( + deployment_id, + engine, + timeout, + stream, + headers, + request_timeout, + 
typed_api_type, + requestor, + url, + params, + ) + + @classmethod + def create( + cls, + api_key=None, + api_base=None, + api_type=None, + request_id=None, + api_version=None, + organization=None, + **params, + ): + ( + deployment_id, + engine, + timeout, + stream, + headers, + request_timeout, + typed_api_type, + requestor, + url, + params, + ) = cls.__prepare_create_request( + api_key, api_base, api_type, api_version, organization, **params + ) + response, _, api_key = requestor.request( "post", url, @@ -151,6 +189,70 @@ def create( return obj + @classmethod + async def acreate( + cls, + api_key=None, + api_base=None, + api_type=None, + request_id=None, + api_version=None, + organization=None, + **params, + ): + ( + deployment_id, + engine, + timeout, + stream, + headers, + request_timeout, + typed_api_type, + requestor, + url, + params, + ) = cls.__prepare_create_request( + api_key, api_base, api_type, api_version, organization, **params + ) + response, _, api_key = await requestor.arequest( + "post", + url, + params=params, + headers=headers, + stream=stream, + request_id=request_id, + request_timeout=request_timeout, + ) + + if stream: + # must be an iterator + assert not isinstance(response, OpenAIResponse) + return ( + util.convert_to_openai_object( + line, + api_key, + api_version, + organization, + engine=engine, + plain_old_data=cls.plain_old_data, + ) + for line in response + ) + else: + obj = util.convert_to_openai_object( + response, + api_key, + api_version, + organization, + engine=engine, + plain_old_data=cls.plain_old_data, + ) + + if timeout is not None: + await obj.await_(timeout=timeout or None) + + return obj + def instance_url(self): id = self.get("id") @@ -206,3 +308,18 @@ def wait(self, timeout=None): break self.refresh() return self + + async def await_(self, timeout=None): + """Async version of `EngineApiResource.wait`""" + start = time.time() + while self.status != "complete": + self.timeout = ( + min(timeout + start - time.time(), 
MAX_TIMEOUT) + if timeout is not None + else MAX_TIMEOUT + ) + if self.timeout < 0: + del self.timeout + break + await self.arefresh() + return self diff --git a/openai/api_resources/abstract/listable_api_resource.py b/openai/api_resources/abstract/listable_api_resource.py index adbf4e8df9..3e59979f13 100644 --- a/openai/api_resources/abstract/listable_api_resource.py +++ b/openai/api_resources/abstract/listable_api_resource.py @@ -9,15 +9,13 @@ def auto_paging_iter(cls, *args, **params): return cls.list(*args, **params).auto_paging_iter() @classmethod - def list( + def __prepare_list_requestor( cls, api_key=None, - request_id=None, api_version=None, organization=None, api_base=None, api_type=None, - **params, ): requestor = api_requestor.APIRequestor( api_key, @@ -38,6 +36,26 @@ def list( url = cls.class_url() else: raise error.InvalidAPIType("Unsupported API type %s" % api_type) + return requestor, url + + @classmethod + def list( + cls, + api_key=None, + request_id=None, + api_version=None, + organization=None, + api_base=None, + api_type=None, + **params, + ): + requestor, url = cls.__prepare_list_requestor( + api_key, + api_version, + organization, + api_base, + api_type, + ) response, _, api_key = requestor.request( "get", url, params, request_id=request_id @@ -47,3 +65,31 @@ def list( ) openai_object._retrieve_params = params return openai_object + + @classmethod + async def alist( + cls, + api_key=None, + request_id=None, + api_version=None, + organization=None, + api_base=None, + api_type=None, + **params, + ): + requestor, url = cls.__prepare_list_requestor( + api_key, + api_version, + organization, + api_base, + api_type, + ) + + response, _, api_key = await requestor.arequest( + "get", url, params, request_id=request_id + ) + openai_object = util.convert_to_openai_object( + response, api_key, api_version, organization + ) + openai_object._retrieve_params = params + return openai_object diff --git 
a/openai/api_resources/abstract/nested_resource_class_methods.py b/openai/api_resources/abstract/nested_resource_class_methods.py index c86e59fbf6..bfa5bcd873 100644 --- a/openai/api_resources/abstract/nested_resource_class_methods.py +++ b/openai/api_resources/abstract/nested_resource_class_methods.py @@ -3,8 +3,12 @@ from openai import api_requestor, util -def nested_resource_class_methods( - resource, path=None, operations=None, resource_plural=None +def _nested_resource_class_methods( + resource, + path=None, + operations=None, + resource_plural=None, + async_=False, ): if resource_plural is None: resource_plural = "%ss" % resource @@ -43,8 +47,34 @@ def nested_resource_request( response, api_key, api_version, organization ) + async def anested_resource_request( + cls, + method, + url, + api_key=None, + request_id=None, + api_version=None, + organization=None, + **params, + ): + requestor = api_requestor.APIRequestor( + api_key, api_version=api_version, organization=organization + ) + response, _, api_key = await requestor.arequest( + method, url, params, request_id=request_id + ) + return util.convert_to_openai_object( + response, api_key, api_version, organization + ) + resource_request_method = "%ss_request" % resource - setattr(cls, resource_request_method, classmethod(nested_resource_request)) + setattr( + cls, + resource_request_method, + classmethod( + anested_resource_request if async_ else nested_resource_request + ), + ) for operation in operations: if operation == "create": @@ -100,3 +130,25 @@ def list_nested_resources(cls, id, **params): return cls return wrapper + + +def nested_resource_class_methods( + resource, + path=None, + operations=None, + resource_plural=None, +): + return _nested_resource_class_methods( + resource, path, operations, resource_plural, async_=False + ) + + +def anested_resource_class_methods( + resource, + path=None, + operations=None, + resource_plural=None, +): + return _nested_resource_class_methods( + resource, path, 
operations, resource_plural, async_=True + ) diff --git a/openai/api_resources/abstract/updateable_api_resource.py b/openai/api_resources/abstract/updateable_api_resource.py index e7289d12d3..245f9b80b3 100644 --- a/openai/api_resources/abstract/updateable_api_resource.py +++ b/openai/api_resources/abstract/updateable_api_resource.py @@ -1,4 +1,5 @@ from urllib.parse import quote_plus +from typing import Awaitable from openai.api_resources.abstract.api_resource import APIResource @@ -8,3 +9,8 @@ class UpdateableAPIResource(APIResource): def modify(cls, sid, **params): url = "%s/%s" % (cls.class_url(), quote_plus(sid)) return cls._static_request("post", url, **params) + + @classmethod + def amodify(cls, sid, **params) -> Awaitable: + url = "%s/%s" % (cls.class_url(), quote_plus(sid)) + return cls._astatic_request("patch", url, **params) diff --git a/openai/api_resources/answer.py b/openai/api_resources/answer.py index 33de3cb7e9..be8c4f1ac8 100644 --- a/openai/api_resources/answer.py +++ b/openai/api_resources/answer.py @@ -10,3 +10,8 @@ def get_url(self): def create(cls, **params): instance = cls() return instance.request("post", cls.get_url(), params) + + @classmethod + def acreate(cls, **params): + instance = cls() + return instance.arequest("post", cls.get_url(), params) diff --git a/openai/api_resources/classification.py b/openai/api_resources/classification.py index 6423c6946a..823f521b96 100644 --- a/openai/api_resources/classification.py +++ b/openai/api_resources/classification.py @@ -10,3 +10,8 @@ def get_url(self): def create(cls, **params): instance = cls() return instance.request("post", cls.get_url(), params) + + @classmethod + def acreate(cls, **params): + instance = cls() + return instance.arequest("post", cls.get_url(), params) diff --git a/openai/api_resources/completion.py b/openai/api_resources/completion.py index 429597b46e..6912b4b730 100644 --- a/openai/api_resources/completion.py +++ b/openai/api_resources/completion.py @@ -28,3 +28,23 @@ def 
create(cls, *args, **kwargs): raise util.log_info("Waiting for model to warm up", error=e) + + @classmethod + async def acreate(cls, *args, **kwargs): + """ + Creates a new completion for the provided prompt and parameters. + + See https://beta.openai.com/docs/api-reference/completions/create for a list + of valid parameters. + """ + start = time.time() + timeout = kwargs.pop("timeout", None) + + while True: + try: + return await super().acreate(*args, **kwargs) + except TryAgain as e: + if timeout is not None and time.time() > start + timeout: + raise + + util.log_info("Waiting for model to warm up", error=e) diff --git a/openai/api_resources/customer.py b/openai/api_resources/customer.py index 571adf8eac..cb9779a2f1 100644 --- a/openai/api_resources/customer.py +++ b/openai/api_resources/customer.py @@ -10,3 +10,8 @@ def get_url(self, customer, endpoint): def create(cls, customer, endpoint, **params): instance = cls() return instance.request("post", cls.get_url(customer, endpoint), params) + + @classmethod + def acreate(cls, customer, endpoint, **params): + instance = cls() + return instance.arequest("post", cls.get_url(customer, endpoint), params) diff --git a/openai/api_resources/deployment.py b/openai/api_resources/deployment.py index 5850e0c9fb..2f3fcd1307 100644 --- a/openai/api_resources/deployment.py +++ b/openai/api_resources/deployment.py @@ -11,10 +11,7 @@ class Deployment(CreateableAPIResource, ListableAPIResource, DeletableAPIResourc OBJECT_NAME = "deployments" @classmethod - def create(cls, *args, **kwargs): - """ - Creates a new deployment for the provided prompt and parameters. - """ + def _check_create(cls, *args, **kwargs): typed_api_type, _ = cls._get_api_type_and_version( kwargs.get("api_type", None), None ) @@ -45,10 +42,24 @@ def create(cls, *args, **kwargs): param="scale_settings", ) + @classmethod + def create(cls, *args, **kwargs): + """ + Creates a new deployment for the provided prompt and parameters. 
+ """ + cls._check_create(*args, **kwargs) return super().create(*args, **kwargs) @classmethod - def list(cls, *args, **kwargs): + def acreate(cls, *args, **kwargs): + """ + Creates a new deployment for the provided prompt and parameters. + """ + cls._check_create(*args, **kwargs) + return super().acreate(*args, **kwargs) + + @classmethod + def _check_list(cls, *args, **kwargs): typed_api_type, _ = cls._get_api_type_and_version( kwargs.get("api_type", None), None ) @@ -57,10 +68,18 @@ def list(cls, *args, **kwargs): "Deployment operations are only available for the Azure API type." ) + @classmethod + def list(cls, *args, **kwargs): + cls._check_list(*args, **kwargs) return super().list(*args, **kwargs) @classmethod - def delete(cls, *args, **kwargs): + def alist(cls, *args, **kwargs): + cls._check_list(*args, **kwargs) + return super().alist(*args, **kwargs) + + @classmethod + def _check_delete(cls, *args, **kwargs): typed_api_type, _ = cls._get_api_type_and_version( kwargs.get("api_type", None), None ) @@ -69,10 +88,18 @@ def delete(cls, *args, **kwargs): "Deployment operations are only available for the Azure API type." ) + @classmethod + def delete(cls, *args, **kwargs): + cls._check_delete(*args, **kwargs) return super().delete(*args, **kwargs) @classmethod - def retrieve(cls, *args, **kwargs): + def adelete(cls, *args, **kwargs): + cls._check_delete(*args, **kwargs) + return super().adelete(*args, **kwargs) + + @classmethod + def _check_retrieve(cls, *args, **kwargs): typed_api_type, _ = cls._get_api_type_and_version( kwargs.get("api_type", None), None ) @@ -81,4 +108,12 @@ def retrieve(cls, *args, **kwargs): "Deployment operations are only available for the Azure API type." 
) + @classmethod + def retrieve(cls, *args, **kwargs): + cls._check_retrieve(*args, **kwargs) return super().retrieve(*args, **kwargs) + + @classmethod + def aretrieve(cls, *args, **kwargs): + cls._check_retrieve(*args, **kwargs) + return super().aretrieve(*args, **kwargs) diff --git a/openai/api_resources/edit.py b/openai/api_resources/edit.py index fe66b6f0f4..985f062ddb 100644 --- a/openai/api_resources/edit.py +++ b/openai/api_resources/edit.py @@ -31,3 +31,27 @@ def create(cls, *args, **kwargs): raise util.log_info("Waiting for model to warm up", error=e) + + @classmethod + async def acreate(cls, *args, **kwargs): + """ + Creates a new edit for the provided input, instruction, and parameters. + """ + start = time.time() + timeout = kwargs.pop("timeout", None) + + api_type = kwargs.pop("api_type", None) + typed_api_type = cls._get_api_type_and_version(api_type=api_type)[0] + if typed_api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): + raise error.InvalidAPIType( + "This operation is not supported by the Azure OpenAI API yet." + ) + + while True: + try: + return await super().acreate(*args, **kwargs) + except TryAgain as e: + if timeout is not None and time.time() > start + timeout: + raise + + util.log_info("Waiting for model to warm up", error=e) diff --git a/openai/api_resources/embedding.py b/openai/api_resources/embedding.py index 85ede2c088..679f97973b 100644 --- a/openai/api_resources/embedding.py +++ b/openai/api_resources/embedding.py @@ -50,3 +50,42 @@ def create(cls, *args, **kwargs): raise util.log_info("Waiting for model to warm up", error=e) + + @classmethod + async def acreate(cls, *args, **kwargs): + """ + Creates a new embedding for the provided input and parameters. + + See https://beta.openai.com/docs/api-reference/embeddings for a list + of valid parameters. 
+ """ + start = time.time() + timeout = kwargs.pop("timeout", None) + + user_provided_encoding_format = kwargs.get("encoding_format", None) + + # If encoding format was not explicitly specified, we opaquely use base64 for performance + if not user_provided_encoding_format: + kwargs["encoding_format"] = "base64" + + while True: + try: + response = await super().acreate(*args, **kwargs) + + # If a user specifies base64, we'll just return the encoded string. + # This is only for the default case. + if not user_provided_encoding_format: + for data in response.data: + + # If an engine isn't using this optimization, don't do anything + if type(data["embedding"]) == str: + data["embedding"] = np.frombuffer( + base64.b64decode(data["embedding"]), dtype="float32" + ).tolist() + + return response + except TryAgain as e: + if timeout is not None and time.time() > start + timeout: + raise + + util.log_info("Waiting for model to warm up", error=e) diff --git a/openai/api_resources/engine.py b/openai/api_resources/engine.py index 11c8ec9ec9..93140819a9 100644 --- a/openai/api_resources/engine.py +++ b/openai/api_resources/engine.py @@ -27,6 +27,23 @@ def generate(self, timeout=None, **params): util.log_info("Waiting for model to warm up", error=e) + async def agenerate(self, timeout=None, **params): + start = time.time() + while True: + try: + return await self.arequest( + "post", + self.instance_url() + "/generate", + params, + stream=params.get("stream"), + plain_old_data=True, + ) + except TryAgain as e: + if timeout is not None and time.time() > start + timeout: + raise + + util.log_info("Waiting for model to warm up", error=e) + def search(self, **params): if self.typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD): return self.request("post", self.instance_url("search"), params) @@ -35,6 +52,14 @@ def search(self, **params): else: raise InvalidAPIType("Unsupported API type %s" % self.api_type) + def asearch(self, **params): + if self.typed_api_type in (ApiType.AZURE, 
ApiType.AZURE_AD): + return self.arequest("post", self.instance_url("search"), params) + elif self.typed_api_type == ApiType.OPEN_AI: + return self.arequest("post", self.instance_url() + "/search", params) + else: + raise InvalidAPIType("Unsupported API type %s" % self.api_type) + def embeddings(self, **params): warnings.warn( "Engine.embeddings is deprecated, use Embedding.create", DeprecationWarning diff --git a/openai/api_resources/file.py b/openai/api_resources/file.py index aba7117fea..3654dd2d2e 100644 --- a/openai/api_resources/file.py +++ b/openai/api_resources/file.py @@ -12,7 +12,7 @@ class File(ListableAPIResource, DeletableAPIResource): OBJECT_NAME = "files" @classmethod - def create( + def __prepare_file_create( cls, file, purpose, @@ -56,13 +56,69 @@ def create( ) else: files.append(("file", ("file", file, "application/octet-stream"))) + + return requestor, url, files + + @classmethod + def create( + cls, + file, + purpose, + model=None, + api_key=None, + api_base=None, + api_type=None, + api_version=None, + organization=None, + user_provided_filename=None, + ): + requestor, url, files = cls.__prepare_file_create( + file, + purpose, + model, + api_key, + api_base, + api_type, + api_version, + organization, + user_provided_filename, + ) response, _, api_key = requestor.request("post", url, files=files) return util.convert_to_openai_object( response, api_key, api_version, organization ) @classmethod - def download( + async def acreate( + cls, + file, + purpose, + model=None, + api_key=None, + api_base=None, + api_type=None, + api_version=None, + organization=None, + user_provided_filename=None, + ): + requestor, url, files = cls.__prepare_file_create( + file, + purpose, + model, + api_key, + api_base, + api_type, + api_version, + organization, + user_provided_filename, + ) + response, _, api_key = await requestor.arequest("post", url, files=files) + return util.convert_to_openai_object( + response, api_key, api_version, organization + ) + + @classmethod 
+    def __prepare_file_download(
         cls,
         id,
         api_key=None,
@@ -84,17 +140,33 @@ def download(
         if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
             base = cls.class_url()
-            url = "/%s%s/%s/content?api-version=%s" % (
+            url = "/%s%s/%s/content?api-version=%s" % (
                 cls.azure_api_prefix,
                 base,
                 id,
                 api_version,
             )
         elif typed_api_type == ApiType.OPEN_AI:
-            url = f"{cls.class_url()}/{id}/content"
+            url = "%s/%s/content" % (cls.class_url(), id)
         else:
             raise error.InvalidAPIType("Unsupported API type %s" % api_type)
+        return requestor, url
+
+    @classmethod
+    def download(
+        cls,
+        id,
+        api_key=None,
+        api_base=None,
+        api_type=None,
+        api_version=None,
+        organization=None,
+    ):
+        requestor, url = cls.__prepare_file_download(
+            id, api_key, api_base, api_type, api_version, organization
+        )
+
         result = requestor.request_raw("get", url)
         if not 200 <= result.status_code < 300:
             raise requestor.handle_error_response(
@@ -107,25 +179,32 @@ def download(
         return result.content

     @classmethod
-    def find_matching_files(
+    async def adownload(
         cls,
-        name,
-        bytes,
-        purpose,
+        id,
         api_key=None,
         api_base=None,
         api_type=None,
         api_version=None,
         organization=None,
     ):
-        """Find already uploaded files with the same name, size, and purpose."""
-        all_files = cls.list(
-            api_key=api_key,
-            api_base=api_base or openai.api_base,
-            api_type=api_type,
-            api_version=api_version,
-            organization=organization,
-        ).get("data", [])
+        requestor, url = cls.__prepare_file_download(
+            id, api_key, api_base, api_type, api_version, organization
+        )
+
+        result = await requestor.arequest_raw("get", url)
+        if not 200 <= result.status < 300:
+            raise requestor.handle_error_response(
+                result.content,
+                result.status,
+                json.loads(cast(bytes, result.content)),
+                result.headers,
+                stream_error=False,
+            )
+        return result.content
+
+    @classmethod
+    def __find_matching_files(cls, name, all_files, purpose):
         matching_files = []
         basename = os.path.basename(name)
         for f in all_files:
@@ -140,3 +219,49 @@ def find_matching_files(
             continue
matching_files.append(f) return matching_files + + @classmethod + def find_matching_files( + cls, + name, + bytes, + purpose, + api_key=None, + api_base=None, + api_type=None, + api_version=None, + organization=None, + ): + """Find already uploaded files with the same name, size, and purpose.""" + all_files = cls.list( + api_key=api_key, + api_base=api_base or openai.api_base, + api_type=api_type, + api_version=api_version, + organization=organization, + ).get("data", []) + return cls.__find_matching_files(name, all_files, purpose) + + @classmethod + async def afind_matching_files( + cls, + name, + bytes, + purpose, + api_key=None, + api_base=None, + api_type=None, + api_version=None, + organization=None, + ): + """Find already uploaded files with the same name, size, and purpose.""" + all_files = ( + await cls.alist( + api_key=api_key, + api_base=api_base or openai.api_base, + api_type=api_type, + api_version=api_version, + organization=organization, + ) + ).get("data", []) + return cls.__find_matching_files(name, all_files, purpose) diff --git a/openai/api_resources/fine_tune.py b/openai/api_resources/fine_tune.py index 1b5d92d861..65dba48836 100644 --- a/openai/api_resources/fine_tune.py +++ b/openai/api_resources/fine_tune.py @@ -16,7 +16,7 @@ class FineTune(ListableAPIResource, CreateableAPIResource, DeletableAPIResource) OBJECT_NAME = "fine-tunes" @classmethod - def cancel( + def _prepare_cancel( cls, id, api_key=None, @@ -44,10 +44,50 @@ def cancel( raise error.InvalidAPIType("Unsupported API type %s" % api_type) instance = cls(id, api_key, **params) + return instance, url + + @classmethod + def cancel( + cls, + id, + api_key=None, + api_type=None, + request_id=None, + api_version=None, + **params, + ): + instance, url = cls._prepare_cancel( + id, + api_key, + api_type, + request_id, + api_version, + **params, + ) return instance.request("post", url, request_id=request_id) @classmethod - def stream_events( + def acancel( + cls, + id, + api_key=None, + 
api_type=None, + request_id=None, + api_version=None, + **params, + ): + instance, url = cls._prepare_cancel( + id, + api_key, + api_type, + request_id, + api_version, + **params, + ) + return instance.arequest("post", url, request_id=request_id) + + @classmethod + def _prepare_stream_events( cls, id, api_key=None, @@ -85,7 +125,32 @@ def stream_events( else: raise error.InvalidAPIType("Unsupported API type %s" % api_type) - response, _, api_key = requestor.request( + return requestor, url + + @classmethod + async def stream_events( + cls, + id, + api_key=None, + api_base=None, + api_type=None, + request_id=None, + api_version=None, + organization=None, + **params, + ): + requestor, url = cls._prepare_stream_events( + id, + api_key, + api_base, + api_type, + request_id, + api_version, + organization, + **params, + ) + + response, _, api_key = await requestor.arequest( "get", url, params, stream=True, request_id=request_id ) diff --git a/openai/api_resources/image.py b/openai/api_resources/image.py index ebb77676df..089200015d 100644 --- a/openai/api_resources/image.py +++ b/openai/api_resources/image.py @@ -22,7 +22,12 @@ def create( return instance.request("post", cls._get_url("generations"), params) @classmethod - def create_variation( + def acreate(cls, **params): + instance = cls() + return instance.arequest("post", cls._get_url("generations"), params) + + @classmethod + def _prepare_create_variation( cls, image, api_key=None, @@ -47,6 +52,28 @@ def create_variation( for key, value in params.items(): files.append((key, (None, value))) files.append(("image", ("image", image, "application/octet-stream"))) + return requestor, url, files + + @classmethod + def create_variation( + cls, + image, + api_key=None, + api_base=None, + api_type=None, + api_version=None, + organization=None, + **params, + ): + requestor, url, files = cls._prepare_create_variation( + image, + api_key, + api_base, + api_type, + api_version, + organization, + **params, + ) response, _, api_key 
= requestor.request("post", url, files=files) @@ -55,7 +82,34 @@ def create_variation( ) @classmethod - def create_edit( + async def acreate_variation( + cls, + image, + api_key=None, + api_base=None, + api_type=None, + api_version=None, + organization=None, + **params, + ): + requestor, url, files = cls._prepare_create_variation( + image, + api_key, + api_base, + api_type, + api_version, + organization, + **params, + ) + + response, _, api_key = await requestor.arequest("post", url, files=files) + + return util.convert_to_openai_object( + response, api_key, api_version, organization + ) + + @classmethod + def _prepare_create_edit( cls, image, mask, @@ -82,9 +136,62 @@ def create_edit( files.append((key, (None, value))) files.append(("image", ("image", image, "application/octet-stream"))) files.append(("mask", ("mask", mask, "application/octet-stream"))) + return requestor, url, files + + @classmethod + def create_edit( + cls, + image, + mask, + api_key=None, + api_base=None, + api_type=None, + api_version=None, + organization=None, + **params, + ): + requestor, url, files = cls._prepare_create_edit( + image, + mask, + api_key, + api_base, + api_type, + api_version, + organization, + **params, + ) response, _, api_key = requestor.request("post", url, files=files) return util.convert_to_openai_object( response, api_key, api_version, organization ) + + @classmethod + async def acreate_edit( + cls, + image, + mask, + api_key=None, + api_base=None, + api_type=None, + api_version=None, + organization=None, + **params, + ): + requestor, url, files = cls._prepare_create_edit( + image, + mask, + api_key, + api_base, + api_type, + api_version, + organization, + **params, + ) + + response, _, api_key = await requestor.arequest("post", url, files=files) + + return util.convert_to_openai_object( + response, api_key, api_version, organization + ) diff --git a/openai/api_resources/moderation.py b/openai/api_resources/moderation.py index 52f997fb26..4b8b58c6d9 100644 --- 
a/openai/api_resources/moderation.py +++ b/openai/api_resources/moderation.py @@ -11,7 +11,7 @@ def get_url(self): return "/moderations" @classmethod - def create(cls, input: Union[str, List[str]], model: Optional[str] = None, api_key: Optional[str] = None): + def _prepare_create(cls, input, model, api_key): if model is not None and model not in cls.VALID_MODEL_NAMES: raise ValueError( f"The parameter model should be chosen from {cls.VALID_MODEL_NAMES} " @@ -22,4 +22,24 @@ def create(cls, input: Union[str, List[str]], model: Optional[str] = None, api_k params = {"input": input} if model is not None: params["model"] = model + return instance, params + + @classmethod + def create( + cls, + input: Union[str, List[str]], + model: Optional[str] = None, + api_key: Optional[str] = None, + ): + instance, params = cls._prepare_create(input, model, api_key) return instance.request("post", cls.get_url(), params) + + @classmethod + def acreate( + cls, + input: Union[str, List[str]], + model: Optional[str] = None, + api_key: Optional[str] = None, + ): + instance, params = cls._prepare_create(input, model, api_key) + return instance.arequest("post", cls.get_url(), params) diff --git a/openai/api_resources/search.py b/openai/api_resources/search.py index adc113c1c4..0f9cdab604 100644 --- a/openai/api_resources/search.py +++ b/openai/api_resources/search.py @@ -28,3 +28,24 @@ def create(cls, *args, **kwargs): raise util.log_info("Waiting for model to warm up", error=e) + + @classmethod + async def acreate(cls, *args, **kwargs): + """ + Creates a new search for the provided input and parameters. + + See https://beta.openai.com/docs/api-reference/search for a list + of valid parameters. 
+ """ + + start = time.time() + timeout = kwargs.pop("timeout", None) + + while True: + try: + return await super().acreate(*args, **kwargs) + except TryAgain as e: + if timeout is not None and time.time() > start + timeout: + raise + + util.log_info("Waiting for model to warm up", error=e) diff --git a/openai/embeddings_utils.py b/openai/embeddings_utils.py index 47a04e6582..c4e8a2f448 100644 --- a/openai/embeddings_utils.py +++ b/openai/embeddings_utils.py @@ -23,6 +23,19 @@ def get_embedding(text: str, engine="text-similarity-davinci-001") -> List[float return openai.Embedding.create(input=[text], engine=engine)["data"][0]["embedding"] +@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) +async def aget_embedding( + text: str, engine="text-similarity-davinci-001" +) -> List[float]: + + # replace newlines, which can negatively affect performance. + text = text.replace("\n", " ") + + return (await openai.Embedding.acreate(input=[text], engine=engine))["data"][0][ + "embedding" + ] + + @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) def get_embeddings( list_of_text: List[str], engine="text-similarity-babbage-001" @@ -37,6 +50,20 @@ def get_embeddings( return [d["embedding"] for d in data] +@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) +async def aget_embeddings( + list_of_text: List[str], engine="text-similarity-babbage-001" +) -> List[List[float]]: + assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048." + + # replace newlines, which can negatively affect performance. + list_of_text = [text.replace("\n", " ") for text in list_of_text] + + data = (await openai.Embedding.acreate(input=list_of_text, engine=engine)).data + data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input. 
+ return [d["embedding"] for d in data] + + def cosine_similarity(a, b): return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) diff --git a/openai/openai_object.py b/openai/openai_object.py index 5bfa29e45f..c0af6bbc2a 100644 --- a/openai/openai_object.py +++ b/openai/openai_object.py @@ -207,6 +207,57 @@ def request( plain_old_data=plain_old_data, ) + async def arequest( + self, + method, + url, + params=None, + headers=None, + stream=False, + plain_old_data=False, + request_id: Optional[str] = None, + request_timeout: Optional[Union[float, Tuple[float, float]]] = None, + ): + if params is None: + params = self._retrieve_params + requestor = api_requestor.APIRequestor( + key=self.api_key, + api_base=self.api_base_override or self.api_base(), + api_type=self.api_type, + api_version=self.api_version, + organization=self.organization, + ) + response, stream, api_key = await requestor.arequest( + method, + url, + params=params, + stream=stream, + headers=headers, + request_id=request_id, + request_timeout=request_timeout, + ) + + if stream: + assert not isinstance(response, OpenAIResponse) # must be an iterator + return ( + util.convert_to_openai_object( + line, + api_key, + self.api_version, + self.organization, + plain_old_data=plain_old_data, + ) + for line in response + ) + else: + return util.convert_to_openai_object( + response, + api_key, + self.api_version, + self.organization, + plain_old_data=plain_old_data, + ) + def __repr__(self): ident_parts = [type(self).__name__] diff --git a/openai/tests/__init__.py b/openai/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openai/tests/asyncio/__init__.py b/openai/tests/asyncio/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openai/tests/asyncio/test_endpoints.py b/openai/tests/asyncio/test_endpoints.py new file mode 100644 index 0000000000..e5c6d012cd --- /dev/null +++ b/openai/tests/asyncio/test_endpoints.py @@ -0,0 +1,65 @@ +import io +import json + 
+import pytest + +import openai +from openai import error + + +pytestmark = [pytest.mark.asyncio] + + +# FILE TESTS +async def test_file_upload(): + result = await openai.File.acreate( + file=io.StringIO(json.dumps({"text": "test file data"})), + purpose="search", + ) + assert result.purpose == "search" + assert "id" in result + + result = await openai.File.aretrieve(id=result.id) + assert result.status == "uploaded" + + +# COMPLETION TESTS +async def test_completions(): + result = await openai.Completion.acreate( + prompt="This was a test", n=5, engine="ada" + ) + assert len(result.choices) == 5 + + +async def test_completions_multiple_prompts(): + result = await openai.Completion.acreate( + prompt=["This was a test", "This was another test"], n=5, engine="ada" + ) + assert len(result.choices) == 10 + + +async def test_completions_model(): + result = await openai.Completion.acreate(prompt="This was a test", n=5, model="ada") + assert len(result.choices) == 5 + assert result.model.startswith("ada") + + +async def test_timeout_raises_error(): + # A query that should take awhile to return + with pytest.raises(error.Timeout): + await openai.Completion.acreate( + prompt="test" * 1000, + n=10, + model="ada", + max_tokens=100, + request_timeout=0.01, + ) + + +async def test_timeout_does_not_error(): + # A query that should be fast + await openai.Completion.acreate( + prompt="test", + model="ada", + request_timeout=10, + ) diff --git a/openai/tests/test_long_examples_validator.py b/openai/tests/test_long_examples_validator.py index 7f3e4c8cf1..6346b25a02 100644 --- a/openai/tests/test_long_examples_validator.py +++ b/openai/tests/test_long_examples_validator.py @@ -20,29 +20,28 @@ def test_long_examples_validator() -> None: # the order of these matters unprepared_training_data = [ {"prompt": long_prompt, "completion": long_completion}, # 1 of 2 duplicates - {"prompt": short_prompt, "completion": short_completion}, + {"prompt": short_prompt, "completion": 
short_completion}, {"prompt": long_prompt, "completion": long_completion}, # 2 of 2 duplicates - ] with NamedTemporaryFile(suffix="jsonl", mode="w") as training_data: for prompt_completion_row in unprepared_training_data: training_data.write(json.dumps(prompt_completion_row) + "\n") training_data.flush() - + prepared_data_cmd_output = subprocess.run( - [f"openai tools fine_tunes.prepare_data -f {training_data.name}"], - stdout=subprocess.PIPE, - text=True, + [f"openai tools fine_tunes.prepare_data -f {training_data.name}"], + stdout=subprocess.PIPE, + text=True, input="y\ny\ny\ny\ny", # apply all recommendations, one at a time stderr=subprocess.PIPE, encoding="utf-8", - shell=True + shell=True, ) # validate data was prepared successfully - assert prepared_data_cmd_output.stderr == "" + assert prepared_data_cmd_output.stderr == "" # validate get_long_indexes() applied during optional_fn() call in long_examples_validator() assert "indices of the long examples has changed" in prepared_data_cmd_output.stdout - - return prepared_data_cmd_output.stdout \ No newline at end of file + + return prepared_data_cmd_output.stdout diff --git a/setup.py b/setup.py index 9b318d326e..aa112f7931 100644 --- a/setup.py +++ b/setup.py @@ -26,9 +26,10 @@ "openpyxl>=3.0.7", # Needed for CLI fine-tuning data preparation tool xlsx format "numpy", 'typing_extensions;python_version<"3.8"', # Needed for type hints for mypy + "aiohttp", # Needed for async support ], extras_require={ - "dev": ["black~=21.6b0", "pytest==6.*"], + "dev": ["black~=21.6b0", "pytest==6.*", "pytest-asyncio", "pytest-mock"], "wandb": ["wandb"], "embeddings": [ "scikit-learn>=1.0.2", # Needed for embedding utils, versions >= 1.1 require python 3.8