From e26e7f7cd237b696255d0e663513198647f8c2ae Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 5 Nov 2024 19:01:17 -0800 Subject: [PATCH 01/28] First WIP prototype of async mode, refs #507 --- llm/__init__.py | 6 + llm/cli.py | 56 ++++-- llm/default_plugins/openai_models.py | 60 ++++++- llm/models.py | 249 ++++++++++++++++++++++++--- 4 files changed, 328 insertions(+), 43 deletions(-) diff --git a/llm/__init__.py b/llm/__init__.py index 0ea6c242..de838418 100644 --- a/llm/__init__.py +++ b/llm/__init__.py @@ -4,6 +4,7 @@ NeedsKeyException, ) from .models import ( + AsyncModel, Attachment, Conversation, Model, @@ -26,6 +27,7 @@ __all__ = [ "hookimpl", + "get_async_model", "get_model", "get_key", "user_dir", @@ -143,6 +145,10 @@ def get_model_aliases() -> Dict[str, Model]: return model_aliases +def get_async_model(model_id: str) -> AsyncModel: + return get_model(model_id).get_async_model() + + class UnknownModelError(KeyError): pass diff --git a/llm/cli.py b/llm/cli.py index 941831c5..90eb78c5 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -1,3 +1,4 @@ +import asyncio import click from click_default_group import DefaultGroup from dataclasses import asdict @@ -11,6 +12,7 @@ Template, UnknownModelError, encode, + get_async_model, get_default_model, get_default_embedding_model, get_embedding_models_with_aliases, @@ -193,6 +195,7 @@ def cli(): ) @click.option("--key", help="API key to use") @click.option("--save", help="Save prompt with this template name") +@click.option("async_", "--async", is_flag=True, help="Run prompt asynchronously") def prompt( prompt, system, @@ -209,6 +212,7 @@ def prompt( conversation_id, key, save, + async_, ): """ Execute a prompt @@ -325,7 +329,10 @@ def read_prompt(): # Now resolve the model try: - model = model_aliases[model_id] + if async_: + model = get_async_model(model_id) + else: + model = get_model(model_id) except KeyError: raise click.ClickException("'{}' is not a known model".format(model_id)) @@ -363,21 +370,48 @@ def read_prompt(): prompt_method = conversation.prompt try: - response = prompt_method( - prompt, attachments=resolved_attachments, system=system, **validated_options - ) - if should_stream: - for chunk in response: - print(chunk, end="") - sys.stdout.flush() - print("") + if async_: + + async def inner(): + if should_stream: + async for chunk in prompt_method( + prompt, + attachments=resolved_attachments, + system=system, + **validated_options, + ): + print(chunk, end="") + sys.stdout.flush() + print("") + else: + response = await prompt_method( + prompt, + attachments=resolved_attachments, + system=system, + **validated_options, + ) + print(response.text()) + + asyncio.run(inner()) else: - print(response.text()) + response = prompt_method( + prompt, + attachments=resolved_attachments, + system=system, + **validated_options, + ) + if should_stream: + for chunk in response: + print(chunk, end="") + sys.stdout.flush() + print("") + else: + print(response.text()) except Exception as ex: raise click.ClickException(str(ex)) # Log to the database - if (logs_on() or log) and not no_log: + if (logs_on() or log) and not no_log and not async_: log_path = logs_db_path() (log_path.parent).mkdir(parents=True, exist_ok=True) db = sqlite_utils.Database(log_path) diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index 5cbb02bb..777bd346 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -1,4 +1,4 @@ -from llm import EmbeddingModel, Model, hookimpl +from llm import 
AsyncModel, EmbeddingModel, Model, hookimpl import llm from llm.utils import dicts_to_table_string, remove_dict_none_values, logging_client import click @@ -254,6 +254,9 @@ class Chat(Model): default_max_tokens = None + def get_async_model(self): + return AsyncChat(self.model_id, self.key) + class Options(SharedOptions): json_object: Optional[bool] = Field( description="Output a valid JSON object {...}. Prompt must mention JSON.", @@ -297,10 +300,8 @@ def __init__( def __str__(self): return "OpenAI Chat: {}".format(self.model_id) - def execute(self, prompt, stream, response, conversation=None): + def build_messages(self, prompt, conversation): messages = [] - if prompt.system and not self.allows_system_prompt: - raise NotImplementedError("Model does not support system prompts") current_system = None if conversation is not None: for prev_response in conversation.responses: @@ -345,7 +346,12 @@ def execute(self, prompt, stream, response, conversation=None): {"type": "image_url", "image_url": {"url": url}} ) messages.append({"role": "user", "content": attachment_message}) + return messages + def execute(self, prompt, stream, response, conversation=None): + if prompt.system and not self.allows_system_prompt: + raise NotImplementedError("Model does not support system prompts") + messages = self.build_messages(prompt, conversation) kwargs = self.build_kwargs(prompt, stream) client = self.get_client() if stream: @@ -376,7 +382,7 @@ def execute(self, prompt, stream, response, conversation=None): yield completion.choices[0].message.content response._prompt_json = redact_data_urls({"messages": messages}) - def get_client(self): + def get_client(self, async_=False): kwargs = {} if self.api_base: kwargs["base_url"] = self.api_base @@ -396,7 +402,10 @@ def get_client(self): kwargs["default_headers"] = self.headers if os.environ.get("LLM_OPENAI_SHOW_RESPONSES"): kwargs["http_client"] = logging_client() - return openai.OpenAI(**kwargs) + if async_: + return openai.AsyncOpenAI(**kwargs) + else: + return openai.OpenAI(**kwargs) def build_kwargs(self, prompt, stream): kwargs = dict(not_nulls(prompt.options)) @@ -410,6 +419,45 @@ def build_kwargs(self, prompt, stream): return kwargs +class AsyncChat(AsyncModel, Chat): + needs_key = "openai" + key_env_var = "OPENAI_API_KEY" + + async def execute(self, prompt, stream, response, conversation=None): + if prompt.system and not self.allows_system_prompt: + raise NotImplementedError("Model does not support system prompts") + messages = self.build_messages(prompt, conversation) + kwargs = self.build_kwargs(prompt, stream) + client = self.get_client(async_=True) + if stream: + completion = await client.chat.completions.create( + model=self.model_name or self.model_id, + messages=messages, + stream=True, + **kwargs, + ) + chunks = [] + async for chunk in completion: + chunks.append(chunk) + try: + content = chunk.choices[0].delta.content + except IndexError: + content = None + if content is not None: + yield content + response.response_json = remove_dict_none_values(combine_chunks(chunks)) + else: + completion = await client.chat.completions.create( + model=self.model_name or self.model_id, + messages=messages, + stream=False, + **kwargs, + ) + response.response_json = remove_dict_none_values(completion.model_dump()) + yield completion.choices[0].message.content + response._prompt_json = redact_data_urls({"messages": messages}) + + class Completion(Chat): class Options(SharedOptions): logprobs: Optional[int] = Field( diff --git a/llm/models.py b/llm/models.py index 
838e25b1..d41b17e9 100644 --- a/llm/models.py +++ b/llm/models.py @@ -8,7 +8,19 @@ import puremagic import re import time -from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union +from typing import ( + Any, + AsyncIterator, + Dict, + Generic, + Iterable, + Iterator, + List, + Optional, + Set, + TypeVar, + Union, +) from abc import ABC, abstractmethod import json from pydantic import BaseModel @@ -144,13 +156,19 @@ def from_row(cls, row): ) -class Response(ABC): +ModelT = TypeVar("ModelT", bound=Union["Model", "AsyncModel"]) +ConversationT = TypeVar( + "ConversationT", bound=Optional[Union["Conversation", "AsyncConversation"]] +) + + +class _BaseResponse(ABC, Generic[ModelT, ConversationT]): def __init__( self, prompt: Prompt, - model: "Model", + model: ModelT, stream: bool, - conversation: Optional[Conversation] = None, + conversation: ConversationT = None, ): self.prompt = prompt self._prompt_json = None @@ -161,28 +179,9 @@ def __init__( self.response_json = None self.conversation = conversation self.attachments: List[Attachment] = [] - - def __iter__(self) -> Iterator[str]: - self._start = time.monotonic() - self._start_utcnow = datetime.datetime.utcnow() - if self._done: - yield from self._chunks - for chunk in self.model.execute( - self.prompt, - stream=self.stream, - response=self, - conversation=self.conversation, - ): - yield chunk - self._chunks.append(chunk) - if self.conversation: - self.conversation.responses.append(self) - self._end = time.monotonic() - self._done = True - - def _force(self): - if not self._done: - list(self) + self._start: Optional[float] = None + self._end: Optional[float] = None + self._start_utcnow: Optional[datetime.datetime] = None def __str__(self) -> str: return self.text() @@ -203,6 +202,30 @@ def datetime_utc(self) -> str: self._force() return self._start_utcnow.isoformat() + +class Response(_BaseResponse["Model", Optional["Conversation"]]): + def _force(self): + if not self._done: + list(self) + + def __iter__(self) -> Iterator[str]: + self._start = time.monotonic() + self._start_utcnow = datetime.datetime.utcnow() + if self._done: + yield from self._chunks + for chunk in self.model.execute( + self.prompt, + stream=self.stream, + response=self, + conversation=self.conversation, + ): + yield chunk + self._chunks.append(chunk) + if self.conversation: + self.conversation.responses.append(self) + self._end = time.monotonic() + self._done = True + def log_to_db(self, db): conversation = self.conversation if not conversation: @@ -257,6 +280,51 @@ def log_to_db(self, db): }, ) + +class AsyncResponse(_BaseResponse["AsyncModel", Optional["AsyncConversation"]]): + async def _force(self): + if not self._done: + async for _ in self: + pass + + async def __aiter__(self) -> AsyncIterator[str]: + self._start = time.monotonic() + self._start_utcnow = datetime.datetime.utcnow() + if self._done: + for chunk in self._chunks: + yield chunk + return + + async for chunk in self.model.execute( + self.prompt, + stream=self.stream, + response=self, + conversation=self.conversation, + ): + yield chunk + self._chunks.append(chunk) + if self.conversation: + self.conversation.responses.append(self) + self._end = time.monotonic() + self._done = True + + # Override base methods to make them async + async def text(self) -> str: + await self._force() + return "".join(self._chunks) + + async def json(self) -> Optional[Dict[str, Any]]: + await self._force() + return self.response_json + + async def duration_ms(self) -> int: + await self._force() + return 
int((self._end - self._start) * 1000) + + async def datetime_utc(self) -> str: + await self._force() + return self._start_utcnow.isoformat() + @classmethod def fake( cls, @@ -362,6 +430,135 @@ def get_key(self): raise NeedsKeyException(message) +ResponseT = TypeVar("ResponseT") +ConversationT = TypeVar("ConversationT") + + +class _BaseModel(ABC, _get_key_mixin, Generic[ResponseT, ConversationT]): + model_id: str + + # API key handling + key: Optional[str] = None + needs_key: Optional[str] = None + key_env_var: Optional[str] = None + + # Model characteristics + can_stream: bool = False + attachment_types: Set = set() + + class Options(_Options): + pass + + def _validate_attachments( + self, attachments: Optional[List[Attachment]] = None + ) -> None: + """Shared attachment validation logic""" + if attachments and not self.attachment_types: + raise ValueError( + "This model does not support attachments, but some were provided" + ) + for attachment in attachments or []: + attachment_type = attachment.resolve_type() + if attachment_type not in self.attachment_types: + raise ValueError( + "This model does not support attachments of type '{}', only {}".format( + attachment_type, ", ".join(self.attachment_types) + ) + ) + + def __str__(self) -> str: + return "{}: {}".format(self.__class__.__name__, self.model_id) + + def __repr__(self): + return "<{} '{}'>".format(self.__class__.__name__, self.model_id) + + +class Model(_BaseModel["Response", "Conversation"]): + def conversation(self) -> "Conversation": + return Conversation(model=self) + + @abstractmethod + def execute( + self, + prompt: Prompt, + stream: bool, + response: "Response", + conversation: Optional["Conversation"], + ) -> Iterator[str]: + """ + Execute a prompt and yield chunks of text, or yield a single big chunk. + Any additional useful information about the execution should be assigned to the response. + """ + pass + + def prompt( + self, + prompt: str, + *, + attachments: Optional[List[Attachment]] = None, + system: Optional[str] = None, + stream: bool = True, + **options + ) -> "Response": + self._validate_attachments(attachments) + return self.response( + Prompt( + prompt, + attachments=attachments, + system=system, + model=self, + options=self.Options(**options), + ), + stream=stream, + ) + + def response(self, prompt: Prompt, stream: bool = True) -> "Response": + return Response(prompt, self, stream) + + +class AsyncModel(_BaseModel["AsyncResponse", "AsyncConversation"]): + def conversation(self) -> "AsyncConversation": + return AsyncConversation(model=self) + + @abstractmethod + async def execute( + self, + prompt: Prompt, + stream: bool, + response: "AsyncResponse", + conversation: Optional["AsyncConversation"], + ) -> AsyncIterator[str]: + """ + Execute a prompt and yield chunks of text, or yield a single big chunk. + Any additional useful information about the execution should be assigned to the response. 
+ """ + pass + + def prompt( + self, + prompt: str, + *, + attachments: Optional[List[Attachment]] = None, + system: Optional[str] = None, + stream: bool = True, + **options + ) -> "AsyncResponse": + self._validate_attachments(attachments) + return self.response( + Prompt( + prompt, + attachments=attachments, + system=system, + model=self, + options=self.Options(**options), + ), + stream=stream, + ) + + def response(self, prompt: Prompt, stream: bool = True) -> "AsyncResponse": + return AsyncResponse(prompt, self, stream) + + class Model(ABC, _get_key_mixin): model_id: str From 1d8c3f85a4f29d10acd5658094d257fc39ddcf13 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 5 Nov 2024 19:05:09 -0800 Subject: [PATCH 02/28] Fix for llm hi --async --no-stream, refs #507 --- llm/cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llm/cli.py b/llm/cli.py index 90eb78c5..762ef8f5 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -384,13 +384,13 @@ async def inner(): sys.stdout.flush() print("") else: - response = await prompt_method( + response = prompt_method( prompt, attachments=resolved_attachments, system=system, **validated_options, ) - print(response.text()) + print(await response.text()) asyncio.run(inner()) else: From b27b275cdc5e33edea2123e647a1fe56fa048ebb Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 5 Nov 2024 19:11:33 -0800 Subject: [PATCH 03/28] Fix for coroutine in __repr__ Refs https://github.com/simonw/llm/issues/507#issuecomment-2458639308 --- llm/models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/llm/models.py b/llm/models.py index d41b17e9..ec83b47b 100644 --- a/llm/models.py +++ b/llm/models.py @@ -387,8 +387,11 @@ def from_row(cls, db, row): return response def __repr__(self): + text = '... not yet awaited ...' 
+ if self._done: + text = "".join(self._chunks) return "".format( - self.prompt.prompt, self.text() + self.prompt.prompt, text ) From 44e6be188af98aeaa9448ad50674f746098c7914 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 5 Nov 2024 19:47:32 -0800 Subject: [PATCH 04/28] register_model is now async aware Refs https://github.com/simonw/llm/issues/507#issuecomment-2458658134 --- llm/__init__.py | 56 +++++++++++++++++++++++----- llm/cli.py | 4 +- llm/default_plugins/openai_models.py | 46 +++++++++++++++++------ llm/models.py | 7 ++-- 4 files changed, 85 insertions(+), 28 deletions(-) diff --git a/llm/__init__.py b/llm/__init__.py index de838418..18ea3055 100644 --- a/llm/__init__.py +++ b/llm/__init__.py @@ -76,11 +76,11 @@ def get_models_with_aliases() -> List["ModelWithAliases"]: for alias, model_id in configured_aliases.items(): extra_model_aliases.setdefault(model_id, []).append(alias) - def register(model, aliases=None): + def register(model, async_model=None, aliases=None): alias_list = list(aliases or []) if model.model_id in extra_model_aliases: alias_list.extend(extra_model_aliases[model.model_id]) - model_aliases.append(ModelWithAliases(model, alias_list)) + model_aliases.append(ModelWithAliases(model, async_model, alias_list)) pm.hook.register_models(register=register) @@ -136,30 +136,66 @@ def get_embedding_model_aliases() -> Dict[str, EmbeddingModel]: return model_aliases +def get_async_model_aliases() -> Dict[str, AsyncModel]: + async_model_aliases = {} + for model_with_aliases in get_models_with_aliases(): + if model_with_aliases.async_model: + for alias in model_with_aliases.aliases: + async_model_aliases[alias] = model_with_aliases.async_model + async_model_aliases[model_with_aliases.model.model_id] = ( + model_with_aliases.async_model + ) + return async_model_aliases + + def get_model_aliases() -> Dict[str, Model]: model_aliases = {} for model_with_aliases in get_models_with_aliases(): - for alias in model_with_aliases.aliases: - model_aliases[alias] = model_with_aliases.model - model_aliases[model_with_aliases.model.model_id] = model_with_aliases.model + if model_with_aliases.model: + for alias in model_with_aliases.aliases: + model_aliases[alias] = model_with_aliases.model + model_aliases[model_with_aliases.model.model_id] = model_with_aliases.model return model_aliases -def get_async_model(model_id: str) -> AsyncModel: - return get_model(model_id).get_async_model() - - class UnknownModelError(KeyError): pass +def get_async_model(name: Optional[str] = None) -> AsyncModel: + aliases = get_async_model_aliases() + name = name or get_default_model() + try: + return aliases[name] + except KeyError: + # Does a sync model exist? + sync_model = None + try: + sync_model = get_model(name) + except UnknownModelError: + pass + if sync_model: + raise UnknownModelError("Unknown async model (sync model exists): " + name) + else: + raise UnknownModelError("Unknown model: " + name) + + def get_model(name: Optional[str] = None) -> Model: aliases = get_model_aliases() name = name or get_default_model() try: return aliases[name] except KeyError: - raise UnknownModelError("Unknown model: " + name) + # Does an async model exist? 
+ async_model = None + try: + async_model = get_async_model(name) + except UnknownModelError: + pass + if async_model: + raise UnknownModelError("Unknown model (async model exists): " + name) + else: + raise UnknownModelError("Unknown model: " + name) def get_key( diff --git a/llm/cli.py b/llm/cli.py index 762ef8f5..663588ef 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -333,8 +333,8 @@ def read_prompt(): model = get_async_model(model_id) else: model = get_model(model_id) - except KeyError: - raise click.ClickException("'{}' is not a known model".format(model_id)) + except UnknownModelError as ex: + raise click.ClickException(ex) # Provide the API key, if one is needed and has been provided if model.needs_key: diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index 777bd346..301481c7 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -23,21 +23,43 @@ @hookimpl def register_models(register): - register(Chat("gpt-3.5-turbo"), aliases=("3.5", "chatgpt")) - register(Chat("gpt-3.5-turbo-16k"), aliases=("chatgpt-16k", "3.5-16k")) - register(Chat("gpt-4"), aliases=("4", "gpt4")) - register(Chat("gpt-4-32k"), aliases=("4-32k",)) + register( + Chat("gpt-3.5-turbo"), AsyncChat("gpt-3.5-turbo"), aliases=("3.5", "chatgpt") + ) + register( + Chat("gpt-3.5-turbo-16k"), + AsyncChat("gpt-3.5-turbo-16k"), + aliases=("chatgpt-16k", "3.5-16k"), + ) + register(Chat("gpt-4"), AsyncChat("gpt-4"), aliases=("4", "gpt4")) + register(Chat("gpt-4-32k"), AsyncChat("gpt-4-32k"), aliases=("4-32k",)) # GPT-4 Turbo models - register(Chat("gpt-4-1106-preview")) - register(Chat("gpt-4-0125-preview")) - register(Chat("gpt-4-turbo-2024-04-09")) - register(Chat("gpt-4-turbo"), aliases=("gpt-4-turbo-preview", "4-turbo", "4t")) + register(Chat("gpt-4-1106-preview"), AsyncChat("gpt-4-1106-preview")) + register(Chat("gpt-4-0125-preview"), AsyncChat("gpt-4-0125-preview")) + register(Chat("gpt-4-turbo-2024-04-09"), AsyncChat("gpt-4-turbo-2024-04-09")) + register( + Chat("gpt-4-turbo"), + AsyncChat("gpt-4-turbo"), + aliases=("gpt-4-turbo-preview", "4-turbo", "4t"), + ) # GPT-4o - register(Chat("gpt-4o", vision=True), aliases=("4o",)) - register(Chat("gpt-4o-mini", vision=True), aliases=("4o-mini",)) + register( + Chat("gpt-4o", vision=True), AsyncChat("gpt-4o", vision=True), aliases=("4o",) + ) + register( + Chat("gpt-4o-mini", vision=True), + AsyncChat("gpt-4o", vision=True), + aliases=("4o-mini",), + ) # o1 - register(Chat("o1-preview", can_stream=False, allows_system_prompt=False)) - register(Chat("o1-mini", can_stream=False, allows_system_prompt=False)) + register( + Chat("o1-preview", can_stream=False, allows_system_prompt=False), + AsyncChat("o1-preview", can_stream=False, allows_system_prompt=False), + ) + register( + Chat("o1-mini", can_stream=False, allows_system_prompt=False), + AsyncChat("o1-mini", can_stream=False, allows_system_prompt=False), + ) # The -instruct completion model register( Completion("gpt-3.5-turbo-instruct", default_max_tokens=256), diff --git a/llm/models.py b/llm/models.py index ec83b47b..0e5bd10f 100644 --- a/llm/models.py +++ b/llm/models.py @@ -387,12 +387,10 @@ def from_row(cls, db, row): return response def __repr__(self): - text = '... not yet awaited ...' + text = "... not yet awaited ..." 
if self._done: text = "".join(self._chunks) - return "".format( - self.prompt.prompt, text - ) + return "".format(self.prompt.prompt, text) class Options(BaseModel): @@ -695,6 +693,7 @@ def embed_batch(self, items: Iterable[Union[str, bytes]]) -> Iterator[List[float @dataclass class ModelWithAliases: model: Model + async_model: AsyncModel aliases: Set[str] From 2b6f5ccc36e8a8cd2c2f4344c11999b20ef1106a Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 5 Nov 2024 20:22:13 -0800 Subject: [PATCH 05/28] Refactor Chat and AsyncChat to use _Shared base class Refs https://github.com/simonw/llm/issues/507#issuecomment-2458692338 --- llm/default_plugins/openai_models.py | 79 +++++++++++++--------------- 1 file changed, 37 insertions(+), 42 deletions(-) diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index 301481c7..d9f1b15f 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -270,15 +270,11 @@ def validate_logit_bias(cls, logit_bias): return validated_logit_bias -class Chat(Model): +class _Shared: needs_key = "openai" key_env_var = "OPENAI_API_KEY" - default_max_tokens = None - def get_async_model(self): - return AsyncChat(self.model_id, self.key) - class Options(SharedOptions): json_object: Optional[bool] = Field( description="Output a valid JSON object {...}. Prompt must mention JSON.", @@ -370,40 +366,6 @@ def build_messages(self, prompt, conversation): messages.append({"role": "user", "content": attachment_message}) return messages - def execute(self, prompt, stream, response, conversation=None): - if prompt.system and not self.allows_system_prompt: - raise NotImplementedError("Model does not support system prompts") - messages = self.build_messages(prompt, conversation) - kwargs = self.build_kwargs(prompt, stream) - client = self.get_client() - if stream: - completion = client.chat.completions.create( - model=self.model_name or self.model_id, - messages=messages, - stream=True, - **kwargs, - ) - chunks = [] - for chunk in completion: - chunks.append(chunk) - try: - content = chunk.choices[0].delta.content - except IndexError: - content = None - if content is not None: - yield content - response.response_json = remove_dict_none_values(combine_chunks(chunks)) - else: - completion = client.chat.completions.create( - model=self.model_name or self.model_id, - messages=messages, - stream=False, - **kwargs, - ) - response.response_json = remove_dict_none_values(completion.model_dump()) - yield completion.choices[0].message.content - response._prompt_json = redact_data_urls({"messages": messages}) - def get_client(self, async_=False): kwargs = {} if self.api_base: @@ -441,10 +403,43 @@ def build_kwargs(self, prompt, stream): return kwargs -class AsyncChat(AsyncModel, Chat): - needs_key = "openai" - key_env_var = "OPENAI_API_KEY" +class Chat(_Shared, Model): + def execute(self, prompt, stream, response, conversation=None): + if prompt.system and not self.allows_system_prompt: + raise NotImplementedError("Model does not support system prompts") + messages = self.build_messages(prompt, conversation) + kwargs = self.build_kwargs(prompt, stream) + client = self.get_client() + if stream: + completion = client.chat.completions.create( + model=self.model_name or self.model_id, + messages=messages, + stream=True, + **kwargs, + ) + chunks = [] + for chunk in completion: + chunks.append(chunk) + try: + content = chunk.choices[0].delta.content + except IndexError: + content = None + if content is not None: + yield content + 
response.response_json = remove_dict_none_values(combine_chunks(chunks)) + else: + completion = client.chat.completions.create( + model=self.model_name or self.model_id, + messages=messages, + stream=False, + **kwargs, + ) + response.response_json = remove_dict_none_values(completion.model_dump()) + yield completion.choices[0].message.content + response._prompt_json = redact_data_urls({"messages": messages}) + +class AsyncChat(_Shared, AsyncModel): async def execute(self, prompt, stream, response, conversation=None): if prompt.system and not self.allows_system_prompt: raise NotImplementedError("Model does not support system prompts") From d9ed54fb8c21477ffa0d07e6c0cdbdad6d20c468 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 6 Nov 2024 03:24:43 -0800 Subject: [PATCH 06/28] fixed function name --- llm/default_plugins/openai_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index 9bc9ce05..5796c8f8 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -526,7 +526,7 @@ async def execute(self, prompt, stream, response, conversation=None): ) response.response_json = remove_dict_none_values(completion.model_dump()) yield completion.choices[0].message.content - response._prompt_json = redact_data_urls({"messages": messages}) + response._prompt_json = redact_data({"messages": messages}) class Completion(Chat): From 55830df9c1e05d38f1f94659771bad41097ffa2d Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 6 Nov 2024 16:34:02 -0800 Subject: [PATCH 07/28] Fix for infinite loop --- llm/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/llm/__init__.py b/llm/__init__.py index 18ea3055..285229dc 100644 --- a/llm/__init__.py +++ b/llm/__init__.py @@ -171,7 +171,7 @@ def get_async_model(name: Optional[str] = None) -> AsyncModel: # Does a sync model exist? sync_model = None try: - sync_model = get_model(name) + sync_model = get_model(name, _skip_async=True) except UnknownModelError: pass if sync_model: @@ -180,13 +180,15 @@ def get_async_model(name: Optional[str] = None) -> AsyncModel: raise UnknownModelError("Unknown model: " + name) -def get_model(name: Optional[str] = None) -> Model: +def get_model(name: Optional[str] = None, _skip_async: bool = False) -> Model: aliases = get_model_aliases() name = name or get_default_model() try: return aliases[name] except KeyError: # Does an async model exist? 
+ if _skip_async: + raise UnknownModelError("Unknown model: " + name) async_model = None try: async_model = get_async_model(name) From 5466a18d08eeb1d290d747c46984d2956ee03cc3 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 6 Nov 2024 16:34:18 -0800 Subject: [PATCH 08/28] Applied Black --- llm/default_plugins/openai_models.py | 43 +++++++--------------------- 1 file changed, 11 insertions(+), 32 deletions(-) diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index 5796c8f8..2559fd2b 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -25,57 +25,36 @@ def register_models(register): # GPT-4o register( - Chat("gpt-4o", vision=True), - AsyncChat("gpt-4o", vision=True), - aliases=("4o",) + Chat("gpt-4o", vision=True), AsyncChat("gpt-4o", vision=True), aliases=("4o",) ) register( Chat("gpt-4o-mini", vision=True), AsyncChat("gpt-4o-mini", vision=True), - aliases=("4o-mini",) + aliases=("4o-mini",), ) register( Chat("gpt-4o-audio-preview", audio=True), - AsyncChat("gpt-4o-audio-preview", audio=True) + AsyncChat("gpt-4o-audio-preview", audio=True), ) # 3.5 and 4 register( - Chat("gpt-3.5-turbo"), - AsyncChat("gpt-3.5-turbo"), - aliases=("3.5", "chatgpt") + Chat("gpt-3.5-turbo"), AsyncChat("gpt-3.5-turbo"), aliases=("3.5", "chatgpt") ) register( Chat("gpt-3.5-turbo-16k"), AsyncChat("gpt-3.5-turbo-16k"), - aliases=("chatgpt-16k", "3.5-16k") - ) - register( - Chat("gpt-4"), - AsyncChat("gpt-4"), - aliases=("4", "gpt4") - ) - register( - Chat("gpt-4-32k"), - AsyncChat("gpt-4-32k"), - aliases=("4-32k",) + aliases=("chatgpt-16k", "3.5-16k"), ) + register(Chat("gpt-4"), AsyncChat("gpt-4"), aliases=("4", "gpt4")) + register(Chat("gpt-4-32k"), AsyncChat("gpt-4-32k"), aliases=("4-32k",)) # GPT-4 Turbo models - register( - Chat("gpt-4-1106-preview"), - AsyncChat("gpt-4-1106-preview") - ) - register( - Chat("gpt-4-0125-preview"), - AsyncChat("gpt-4-0125-preview") - ) - register( - Chat("gpt-4-turbo-2024-04-09"), - AsyncChat("gpt-4-turbo-2024-04-09") - ) + register(Chat("gpt-4-1106-preview"), AsyncChat("gpt-4-1106-preview")) + register(Chat("gpt-4-0125-preview"), AsyncChat("gpt-4-0125-preview")) + register(Chat("gpt-4-turbo-2024-04-09"), AsyncChat("gpt-4-turbo-2024-04-09")) register( Chat("gpt-4-turbo"), AsyncChat("gpt-4-turbo"), - aliases=("gpt-4-turbo-preview", "4-turbo", "4t") + aliases=("gpt-4-turbo-preview", "4-turbo", "4t"), ) # o1 register( From 330952867b8c182bbf9b9c9e92d2a7889ac50f53 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 7 Nov 2024 00:35:35 +0000 Subject: [PATCH 09/28] Ran cog --- docs/help.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/help.md b/docs/help.md index 0e28494a..a437339b 100644 --- a/docs/help.md +++ b/docs/help.md @@ -116,6 +116,7 @@ Options: --cid, --conversation TEXT Continue the conversation with the given ID. --key TEXT API key to use --save TEXT Save prompt with this template name + --async Run prompt asynchronously --help Show this message and exit. 
``` From d310df5208c6d1f4cf75c209a7eab9e0ef06ea04 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 6 Nov 2024 16:43:02 -0800 Subject: [PATCH 10/28] Applied Black --- tests/test_chat.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_chat.py b/tests/test_chat.py index 01b2a0c0..f4e15861 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -80,7 +80,10 @@ def test_chat_basic(mock_model, logs_db): # Now continue that conversation mock_model.enqueue(["continued"]) result2 = runner.invoke( - llm.cli.cli, ["chat", "-m", "mock", "-c"], input="Continue\nquit\n" + llm.cli.cli, + ["chat", "-m", "mock", "-c"], + input="Continue\nquit\n", + catch_exceptions=False, ) assert result2.exit_code == 0 assert result2.output == ( From 61dfc1db6aaec9b6852efe0e86b64fe5c9ba6818 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 6 Nov 2024 16:43:30 -0800 Subject: [PATCH 11/28] Add Response.from_row() classmethod back again It does not matter that this is a blocking call, since it is a classmethod --- llm/models.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/llm/models.py b/llm/models.py index 0e5bd10f..4d58cf25 100644 --- a/llm/models.py +++ b/llm/models.py @@ -202,6 +202,43 @@ def datetime_utc(self) -> str: self._force() return self._start_utcnow.isoformat() + @classmethod + def from_row(cls, db, row): + from llm import get_model + + model = get_model(row["model"]) + + response = cls( + model=model, + prompt=Prompt( + prompt=row["prompt"], + model=model, + attachments=[], + system=row["system"], + options=model.Options(**json.loads(row["options_json"])), + ), + stream=False, + ) + response.id = row["id"] + response._prompt_json = json.loads(row["prompt_json"] or "null") + response.response_json = json.loads(row["response_json"] or "null") + response._done = True + response._chunks = [row["response"]] + # Attachments + response.attachments = [ + Attachment.from_row(arow) + for arow in db.query( + """ + select attachments.* from attachments + join prompt_attachments on attachments.id = prompt_attachments.attachment_id + where prompt_attachments.response_id = ? 
+ order by prompt_attachments."order" + """, + [row["id"]], + ) + ] + return response + class Response(_BaseResponse["Model", Optional["Conversation"]]): def _force(self): From b3a6ec7b1dd6197fd8aaf06f59833994ec6c17b4 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 6 Nov 2024 17:05:49 -0800 Subject: [PATCH 12/28] Made mypy happy with llm/models.py --- llm/models.py | 178 ++++++++++++++++++++------------------------------ 1 file changed, 70 insertions(+), 108 deletions(-) diff --git a/llm/models.py b/llm/models.py index 4d58cf25..25a016b7 100644 --- a/llm/models.py +++ b/llm/models.py @@ -26,6 +26,13 @@ from pydantic import BaseModel from ulid import ULID +ModelT = TypeVar("ModelT", bound=Union["Model", "AsyncModel"]) +ConversationT = TypeVar( + "ConversationT", bound=Optional[Union["Conversation", "AsyncConversation"]] +) +ResponseT = TypeVar("ResponseT") + + CONVERSATION_NAME_LENGTH = 32 @@ -131,7 +138,7 @@ def prompt( system: Optional[str] = None, stream: bool = True, **options - ): + ) -> "Response": return Response( Prompt( prompt, @@ -156,10 +163,44 @@ def from_row(cls, row): ) -ModelT = TypeVar("ModelT", bound=Union["Model", "AsyncModel"]) -ConversationT = TypeVar( - "ConversationT", bound=Optional[Union["Conversation", "AsyncConversation"]] -) +@dataclass +class AsyncConversation: + model: "AsyncModel" + id: str = field(default_factory=lambda: str(ULID()).lower()) + name: Optional[str] = None + responses: List["AsyncResponse"] = field(default_factory=list) + + def prompt( + self, + prompt: Optional[str], + *, + attachments: Optional[List[Attachment]] = None, + system: Optional[str] = None, + stream: bool = True, + **options + ) -> "AsyncResponse": + return AsyncResponse( + Prompt( + prompt, + model=self.model, + attachments=attachments, + system=system, + options=self.model.Options(**options), + ), + self.model, + stream, + conversation=self, + ) + + @classmethod + def from_row(cls, row): + from llm import get_model + + return cls( + model=get_model(row["model"]), + id=row["id"], + name=row["name"], + ) class _BaseResponse(ABC, Generic[ModelT, ConversationT]): @@ -168,7 +209,7 @@ def __init__( prompt: Prompt, model: ModelT, stream: bool, - conversation: ConversationT = None, + conversation: Optional[ConversationT] = None, ): self.prompt = prompt self._prompt_json = None @@ -183,25 +224,6 @@ def __init__( self._end: Optional[float] = None self._start_utcnow: Optional[datetime.datetime] = None - def __str__(self) -> str: - return self.text() - - def text(self) -> str: - self._force() - return "".join(self._chunks) - - def json(self) -> Optional[Dict[str, Any]]: - self._force() - return self.response_json - - def duration_ms(self) -> int: - self._force() - return int((self._end - self._start) * 1000) - - def datetime_utc(self) -> str: - self._force() - return self._start_utcnow.isoformat() - @classmethod def from_row(cls, db, row): from llm import get_model @@ -241,10 +263,29 @@ def from_row(cls, db, row): class Response(_BaseResponse["Model", Optional["Conversation"]]): + def __str__(self) -> str: + return self.text() + def _force(self): if not self._done: list(self) + def text(self) -> str: + self._force() + return "".join(self._chunks) + + def json(self) -> Optional[Dict[str, Any]]: + self._force() + return self.response_json + + def duration_ms(self) -> int: + self._force() + return int(((self._end or 0) - (self._start or 0)) * 1000) + + def datetime_utc(self) -> str: + self._force() + return self._start_utcnow.isoformat() if self._start_utcnow else "" + def __iter__(self) 
-> Iterator[str]: self._start = time.monotonic() self._start_utcnow = datetime.datetime.utcnow() @@ -332,7 +373,7 @@ async def __aiter__(self) -> AsyncIterator[str]: yield chunk return - async for chunk in self.model.execute( + async for chunk in await self.model.execute( self.prompt, stream=self.stream, response=self, @@ -356,16 +397,16 @@ async def json(self) -> Optional[Dict[str, Any]]: async def duration_ms(self) -> int: await self._force() - return int((self._end - self._start) * 1000) + return int(((self._end or 0) - (self._start or 0)) * 1000) async def datetime_utc(self) -> str: await self._force() - return self._start_utcnow.isoformat() + return self._start_utcnow.isoformat() if self._start_utcnow else "" @classmethod def fake( cls, - model: "Model", + model: "AsyncModel", prompt: str, *attachments: List[Attachment], system: str, @@ -468,10 +509,6 @@ def get_key(self): raise NeedsKeyException(message) -ResponseT = TypeVar("ResponseT") -ConversationT = TypeVar("ConversationT") - - class _BaseModel(ABC, _get_key_mixin, Generic[ResponseT, ConversationT]): model_id: str @@ -597,81 +634,6 @@ def response(self, prompt: Prompt, stream: bool = True) -> "AsyncResponse": return AsyncResponse(prompt, self, stream) -class Model(ABC, _get_key_mixin): - model_id: str - - # API key handling - key: Optional[str] = None - needs_key: Optional[str] = None - key_env_var: Optional[str] = None - - # Model characteristics - can_stream: bool = False - attachment_types: Set = set() - - class Options(_Options): - pass - - def conversation(self): - return Conversation(model=self) - - @abstractmethod - def execute( - self, - prompt: Prompt, - stream: bool, - response: Response, - conversation: Optional[Conversation], - ) -> Iterator[str]: - """ - Execute a prompt and yield chunks of text, or yield a single big chunk. - Any additional useful information about the execution should be assigned to the response. - """ - pass - - def prompt( - self, - prompt: str, - *, - attachments: Optional[List[Attachment]] = None, - system: Optional[str] = None, - stream: bool = True, - **options - ): - # Validate attachments - if attachments and not self.attachment_types: - raise ValueError( - "This model does not support attachments, but some were provided" - ) - for attachment in attachments or []: - attachment_type = attachment.resolve_type() - if attachment_type not in self.attachment_types: - raise ValueError( - "This model does not support attachments of type '{}', only {}".format( - attachment_type, ", ".join(self.attachment_types) - ) - ) - return self.response( - Prompt( - prompt, - attachments=attachments, - system=system, - model=self, - options=self.Options(**options), - ), - stream=stream, - ) - - def response(self, prompt: Prompt, stream: bool = True) -> Response: - return Response(prompt, self, stream) - - def __str__(self) -> str: - return "{}: {}".format(self.__class__.__name__, self.model_id) - - def __repr__(self): - return "".format(self.model_id) - - class EmbeddingModel(ABC, _get_key_mixin): model_id: str key: Optional[str] = None From 91732d00bef0c2404ef25353018005361de6331f Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 6 Nov 2024 17:08:47 -0800 Subject: [PATCH 13/28] mypy fixes for openai_models.py I am unhappy with this, had to duplicate some code. 
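The duplication referred to here appears, judging from the diff below, to be the `needs_key` / `key_env_var` / `default_max_tokens` attributes and the inner `Options` class, which end up declared on both `Chat` and `AsyncChat` rather than once on the `_Shared` mixin. A minimal sketch of the resulting shape (simplified, not the actual source):

```python
# Rough sketch of the duplicated declarations; class bodies are simplified.
class _Shared:
    # Shared behaviour (build_messages, get_client, build_kwargs) lives here.
    pass


class Chat(_Shared):  # also subclasses llm.Model in the real code
    needs_key = "openai"
    key_env_var = "OPENAI_API_KEY"
    default_max_tokens = None


class AsyncChat(_Shared):  # also subclasses llm.AsyncModel in the real code
    needs_key = "openai"  # same attributes repeated to keep mypy happy
    key_env_var = "OPENAI_API_KEY"
    default_max_tokens = None
```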
--- llm/default_plugins/openai_models.py | 30 ++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index 2559fd2b..a4ebadcc 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -299,16 +299,6 @@ def _attachment(attachment): class _Shared: - needs_key = "openai" - key_env_var = "OPENAI_API_KEY" - default_max_tokens = None - - class Options(SharedOptions): - json_object: Optional[bool] = Field( - description="Output a valid JSON object {...}. Prompt must mention JSON.", - default=None, - ) - def __init__( self, model_id, @@ -437,6 +427,16 @@ def build_kwargs(self, prompt, stream): class Chat(_Shared, Model): + needs_key = "openai" + key_env_var = "OPENAI_API_KEY" + default_max_tokens = None + + class Options(SharedOptions): + json_object: Optional[bool] = Field( + description="Output a valid JSON object {...}. Prompt must mention JSON.", + default=None, + ) + def execute(self, prompt, stream, response, conversation=None): if prompt.system and not self.allows_system_prompt: raise NotImplementedError("Model does not support system prompts") @@ -473,6 +473,16 @@ def execute(self, prompt, stream, response, conversation=None): class AsyncChat(_Shared, AsyncModel): + needs_key = "openai" + key_env_var = "OPENAI_API_KEY" + default_max_tokens = None + + class Options(SharedOptions): + json_object: Optional[bool] = Field( + description="Output a valid JSON object {...}. Prompt must mention JSON.", + default=None, + ) + async def execute(self, prompt, stream, response, conversation=None): if prompt.system and not self.allows_system_prompt: raise NotImplementedError("Model does not support system prompts") From 2e1045d8eee089ce8d2c53efe90a3e415e3dfbc4 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 6 Nov 2024 17:31:47 -0800 Subject: [PATCH 14/28] First test for AsyncModel --- llm/models.py | 2 +- setup.py | 1 + tests/conftest.py | 32 ++++++++++++++++++++++++++++++-- tests/test_async.py | 10 ++++++++++ 4 files changed, 42 insertions(+), 3 deletions(-) create mode 100644 tests/test_async.py diff --git a/llm/models.py b/llm/models.py index 25a016b7..1e6c165e 100644 --- a/llm/models.py +++ b/llm/models.py @@ -373,7 +373,7 @@ async def __aiter__(self) -> AsyncIterator[str]: yield chunk return - async for chunk in await self.model.execute( + async for chunk in self.model.execute( self.prompt, stream=self.stream, response=self, diff --git a/setup.py b/setup.py index 6f500815..24b5acd2 100644 --- a/setup.py +++ b/setup.py @@ -55,6 +55,7 @@ def get_long_description(): "pytest", "numpy", "pytest-httpx>=0.33.0", + "pytest-asyncio", "cogapp", "mypy>=1.10.0", "black>=24.1.0", diff --git a/tests/conftest.py b/tests/conftest.py index bcdb8854..7eb3dd56 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -75,6 +75,29 @@ def execute(self, prompt, stream, response, conversation): break +class AsyncMockModel(llm.AsyncModel): + model_id = "mock" + + def __init__(self): + self.history = [] + self._queue = [] + + def enqueue(self, messages): + assert isinstance(messages, list) + self._queue.append(messages) + + async def execute(self, prompt, stream, response, conversation): + self.history.append((prompt, stream, response, conversation)) + while True: + try: + messages = self._queue.pop(0) + for message in messages: + yield message + break + except IndexError: + break + + class EmbedDemo(llm.EmbeddingModel): model_id = "embed-demo" batch_size = 10 @@ 
-118,8 +141,13 @@ def mock_model():
     return MockModel()
 
 
+@pytest.fixture
+def async_mock_model():
+    return AsyncMockModel()
+
+
 @pytest.fixture(autouse=True)
-def register_embed_demo_model(embed_demo, mock_model):
+def register_embed_demo_model(embed_demo, mock_model, async_mock_model):
     class MockModelsPlugin:
         __name__ = "MockModelsPlugin"
 
@@ -131,7 +159,7 @@ def register_embedding_models(self, register):
 
         @llm.hookimpl
         def register_models(self, register):
-            register(mock_model)
+            register(mock_model, async_model=async_mock_model)
 
     pm.register(MockModelsPlugin(), name="undo-mock-models-plugin")
     try:
diff --git a/tests/test_async.py b/tests/test_async.py
new file mode 100644
index 00000000..c7d3f9d9
--- /dev/null
+++ b/tests/test_async.py
@@ -0,0 +1,10 @@
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_async_model(async_mock_model):
+    gathered = []
+    async_mock_model.enqueue(["hello world"])
+    async for chunk in async_mock_model.prompt("hello"):
+        gathered.append(chunk)
+    assert gathered == ["hello world"]

From f311dbf619fa3aeccf035fee93eca932b8bed21b Mon Sep 17 00:00:00 2001
From: Simon Willison
Date: Thu, 7 Nov 2024 18:36:32 -0800
Subject: [PATCH 15/28] Still have not quite got this working

---
 llm/models.py | 57 +++++++++++++++++++++++++++++++++------------------
 pytest.ini    |  3 ++-
 2 files changed, 39 insertions(+), 21 deletions(-)

diff --git a/llm/models.py b/llm/models.py
index 1e6c165e..afd7ad5c 100644
--- a/llm/models.py
+++ b/llm/models.py
@@ -365,28 +365,45 @@ async def _force(self):
             async for _ in self:
                 pass
 
-    async def __aiter__(self) -> AsyncIterator[str]:
-        self._start = time.monotonic()
-        self._start_utcnow = datetime.datetime.utcnow()
+    def __aiter__(self):
+        # __aiter__ should return self directly, not be async
+        return self
+
+    async def __anext__(self) -> str:
+        if self._start is None:
+            self._start = time.monotonic()
+            self._start_utcnow = datetime.datetime.utcnow()
         if self._done:
-            for chunk in self._chunks:
-                yield chunk
-            return
-
-        async for chunk in self.model.execute(
-            self.prompt,
-            stream=self.stream,
-            response=self,
-            conversation=self.conversation,
-        ):
-            yield chunk
-            self._chunks.append(chunk)
-        if self.conversation:
-            self.conversation.responses.append(self)
-        self._end = time.monotonic()
-        self._done = True
+            if not self._chunks:
+                raise StopAsyncIteration
+            chunk = self._chunks.pop(0)
+            if not self._chunks:
+                raise StopAsyncIteration
+            return chunk
+        try:
+            iterator = self.model.execute(
+                self.prompt,
+                stream=self.stream,
+                response=self,
+                conversation=self.conversation,
+            )
+            async for chunk in iterator:
+                self._chunks.append(chunk)
+                return chunk
+
+            if self.conversation:
+                self.conversation.responses.append(self)
+            self._end = time.monotonic()
+            self._done = True
+
+            raise StopAsyncIteration
+        except StopAsyncIteration:
+            if self.conversation:
+                self.conversation.responses.append(self)
+            self._end = time.monotonic()
+            self._done = True
+            raise
 
-    # Override base methods to make them async
     async def text(self) -> str:
         await self._force()
         return "".join(self._chunks)
diff --git a/pytest.ini b/pytest.ini
index 8658fc91..ba352d26 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,4 +1,5 @@
 [pytest]
 filterwarnings =
     ignore:The `schema` method is deprecated.*:DeprecationWarning
-    ignore:Support for class-based `config` is deprecated*:DeprecationWarning
\ No newline at end of file
+    ignore:Support for class-based `config` is deprecated*:DeprecationWarning
+asyncio_default_fixture_loop_scope = function
\ No newline at end of file
From
4f3e82a172e8418a9cd26ac20c4d63cd19e80a6e Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 12 Nov 2024 21:10:03 -0800 Subject: [PATCH 16/28] Fix for not loading plugins during tests, refs #626 --- llm/__init__.py | 5 ++++- llm/cli.py | 4 ++++ llm/plugins.py | 57 ++++++++++++++++++++++++++++--------------------- 3 files changed, 41 insertions(+), 25 deletions(-) diff --git a/llm/__init__.py b/llm/__init__.py index 285229dc..16ff9aed 100644 --- a/llm/__init__.py +++ b/llm/__init__.py @@ -17,7 +17,7 @@ ) from .embeddings import Collection from .templates import Template -from .plugins import pm +from .plugins import pm, load_plugins import click from typing import Dict, List, Optional import json @@ -82,6 +82,7 @@ def register(model, async_model=None, aliases=None): alias_list.extend(extra_model_aliases[model.model_id]) model_aliases.append(ModelWithAliases(model, async_model, alias_list)) + load_plugins() pm.hook.register_models(register=register) return model_aliases @@ -104,6 +105,7 @@ def register(model, aliases=None): alias_list.extend(extra_model_aliases[model.model_id]) model_aliases.append(EmbeddingModelWithAliases(model, alias_list)) + load_plugins() pm.hook.register_embedding_models(register=register) return model_aliases @@ -115,6 +117,7 @@ def get_embedding_models(): def register(model, aliases=None): models.append(model) + load_plugins() pm.hook.register_embedding_models(register=register) return models diff --git a/llm/cli.py b/llm/cli.py index bb9075e5..8fe25144 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -1817,6 +1817,10 @@ def render_errors(errors): return "\n".join(output) +from .plugins import load_plugins + +load_plugins() + pm.hook.register_commands(cli=cli) diff --git a/llm/plugins.py b/llm/plugins.py index 933725c7..5c00b9e6 100644 --- a/llm/plugins.py +++ b/llm/plugins.py @@ -12,27 +12,36 @@ LLM_LOAD_PLUGINS = os.environ.get("LLM_LOAD_PLUGINS", None) -if not hasattr(sys, "_called_from_test") and LLM_LOAD_PLUGINS is None: - # Only load plugins if not running tests - pm.load_setuptools_entrypoints("llm") - - -# Load any plugins specified in LLM_LOAD_PLUGINS") -if LLM_LOAD_PLUGINS is not None: - for package_name in [name for name in LLM_LOAD_PLUGINS.split(",") if name.strip()]: - try: - distribution = metadata.distribution(package_name) # Updated call - llm_entry_points = [ - ep for ep in distribution.entry_points if ep.group == "llm" - ] - for entry_point in llm_entry_points: - mod = entry_point.load() - pm.register(mod, name=entry_point.name) - # Ensure name can be found in plugin_to_distinfo later: - pm._plugin_distinfo.append((mod, distribution)) # type: ignore - except metadata.PackageNotFoundError: - sys.stderr.write(f"Plugin {package_name} could not be found\n") - -for plugin in DEFAULT_PLUGINS: - mod = importlib.import_module(plugin) - pm.register(mod, plugin) +_loaded = False + + +def load_plugins(): + global _loaded + if _loaded: + return + _loaded = True + if not hasattr(sys, "_called_from_test") and LLM_LOAD_PLUGINS is None: + # Only load plugins if not running tests + pm.load_setuptools_entrypoints("llm") + + # Load any plugins specified in LLM_LOAD_PLUGINS") + if LLM_LOAD_PLUGINS is not None: + for package_name in [ + name for name in LLM_LOAD_PLUGINS.split(",") if name.strip() + ]: + try: + distribution = metadata.distribution(package_name) # Updated call + llm_entry_points = [ + ep for ep in distribution.entry_points if ep.group == "llm" + ] + for entry_point in llm_entry_points: + mod = entry_point.load() + pm.register(mod, name=entry_point.name) + 
# Ensure name can be found in plugin_to_distinfo later: + pm._plugin_distinfo.append((mod, distribution)) # type: ignore + except metadata.PackageNotFoundError: + sys.stderr.write(f"Plugin {package_name} could not be found\n") + + for plugin in DEFAULT_PLUGINS: + mod = importlib.import_module(plugin) + pm.register(mod, plugin) From 145b5cdc22b63b052a790cec3b39c1905d324f2c Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 12 Nov 2024 21:11:05 -0800 Subject: [PATCH 17/28] audio/wav not audio/wave, refs #603 --- docs/plugins/advanced-model-plugins.md | 2 +- docs/usage.md | 2 +- llm/default_plugins/openai_models.py | 4 ++-- tests/test_cli_openai_models.py | 7 +++---- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/docs/plugins/advanced-model-plugins.md b/docs/plugins/advanced-model-plugins.md index a201237c..b9a16885 100644 --- a/docs/plugins/advanced-model-plugins.md +++ b/docs/plugins/advanced-model-plugins.md @@ -79,7 +79,7 @@ def _attachment(attachment): if attachment.resolve_type().startswith("image/"): return {"type": "image_url", "image_url": {"url": url}} else: - format_ = "wav" if attachment.resolve_type() == "audio/wave" else "mp3" + format_ = "wav" if attachment.resolve_type() == "audio/wav" else "mp3" return { "type": "input_audio", "input_audio": { diff --git a/docs/usage.md b/docs/usage.md index 942f2d6a..1160985d 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -305,7 +305,7 @@ OpenAI Chat: gpt-4o-audio-preview seed: int json_object: boolean Attachment types: - audio/mpeg, audio/wave + audio/mpeg, audio/wav OpenAI Chat: gpt-3.5-turbo (aliases: 3.5, chatgpt) Options: temperature: float diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index a4ebadcc..f9dc6b30 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -288,7 +288,7 @@ def _attachment(attachment): if attachment.resolve_type().startswith("image/"): return {"type": "image_url", "image_url": {"url": url}} else: - format_ = "wav" if attachment.resolve_type() == "audio/wave" else "mp3" + format_ = "wav" if attachment.resolve_type() == "audio/wav" else "mp3" return { "type": "input_audio", "input_audio": { @@ -341,7 +341,7 @@ def __init__( if audio: self.attachment_types.update( { - "audio/wave", + "audio/wav", "audio/mpeg", } ) diff --git a/tests/test_cli_openai_models.py b/tests/test_cli_openai_models.py index f341e385..7cbab726 100644 --- a/tests/test_cli_openai_models.py +++ b/tests/test_cli_openai_models.py @@ -65,9 +65,8 @@ def test_only_gpt4_audio_preview_allows_mp3_or_wav(httpx_mock, model, filetype): method="HEAD", url=f"https://www.example.com/example.{filetype}", content=b"binary-data", - headers={"Content-Type": "audio/mpeg" if filetype == "mp3" else "audio/wave"}, + headers={"Content-Type": "audio/mpeg" if filetype == "mp3" else "audio/wav"}, ) - # Another mock for the correct model if model == "gpt-4o-audio-preview": httpx_mock.add_response( method="POST", @@ -116,7 +115,7 @@ def test_only_gpt4_audio_preview_allows_mp3_or_wav(httpx_mock, model, filetype): url=f"https://www.example.com/example.{filetype}", content=b"binary-data", headers={ - "Content-Type": "audio/mpeg" if filetype == "mp3" else "audio/wave" + "Content-Type": "audio/mpeg" if filetype == "mp3" else "audio/wav" }, ) runner = CliRunner() @@ -140,7 +139,7 @@ def test_only_gpt4_audio_preview_allows_mp3_or_wav(httpx_mock, model, filetype): ) else: assert result.exit_code == 1 - long = "audio/mpeg" if filetype == "mp3" else "audio/wave" + long = 
"audio/mpeg" if filetype == "mp3" else "audio/wav" assert ( f"This model does not support attachments of type '{long}'" in result.output ) From 8ab5ea3c65812165baf3dc2a83377e6c8c150372 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 12 Nov 2024 21:41:06 -0800 Subject: [PATCH 18/28] Black and mypy and ruff all happy --- llm/cli.py | 4 +--- llm/default_plugins/openai_models.py | 6 ++++-- llm/models.py | 31 ++++++++++++++-------------- 3 files changed, 20 insertions(+), 21 deletions(-) diff --git a/llm/cli.py b/llm/cli.py index 8fe25144..5472c10d 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -31,7 +31,7 @@ ) from .migrations import migrate -from .plugins import pm +from .plugins import pm, load_plugins import base64 import httpx import pathlib @@ -1817,8 +1817,6 @@ def render_errors(errors): return "\n".join(output) -from .plugins import load_plugins - load_plugins() pm.hook.register_commands(cli=cli) diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index f9dc6b30..e0f625a8 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -16,7 +16,7 @@ from pydantic.fields import Field from pydantic.class_validators import validator as field_validator # type: ignore [no-redef] -from typing import List, Iterable, Iterator, Optional, Union +from typing import AsyncGenerator, List, Iterable, Iterator, Optional, Union import json import yaml @@ -483,7 +483,9 @@ class Options(SharedOptions): default=None, ) - async def execute(self, prompt, stream, response, conversation=None): + async def execute( + self, prompt, stream, response, conversation=None + ) -> AsyncGenerator[str, None]: if prompt.system and not self.allows_system_prompt: raise NotImplementedError("Model does not support system prompts") messages = self.build_messages(prompt, conversation) diff --git a/llm/models.py b/llm/models.py index afd7ad5c..eba8baf8 100644 --- a/llm/models.py +++ b/llm/models.py @@ -10,7 +10,7 @@ import time from typing import ( Any, - AsyncIterator, + AsyncGenerator, Dict, Generic, Iterable, @@ -380,23 +380,22 @@ async def __anext__(self) -> str: if not self._chunks: raise StopAsyncIteration return chunk - try: - iterator = self.model.execute( + + # Get and store the generator if we don't have it yet + if not hasattr(self, "_generator"): + generator = self.model.execute( self.prompt, stream=self.stream, response=self, conversation=self.conversation, ) - async for chunk in iterator: - self._chunks.append(chunk) - return chunk - - if self.conversation: - self.conversation.responses.append(self) - self._end = time.monotonic() - self._done = True + self._generator = generator - raise StopAsyncIteration + # Use the generator + try: + chunk = await self._generator.__anext__() + self._chunks.append(chunk) + return chunk except StopAsyncIteration: if self.conversation: self.conversation.responses.append(self) @@ -619,12 +618,12 @@ async def execute( stream: bool, response: "AsyncResponse", conversation: Optional["AsyncConversation"], - ) -> AsyncIterator[str]: + ) -> AsyncGenerator[str, None]: """ - Execute a prompt and yield chunks of text, or yield a single big chunk. - Any additional useful information about the execution should be assigned to the response. + Returns an async generator that executes the prompt and yields chunks of text, + or yields a single big chunk. 
""" - pass + yield "" def prompt( self, From c4a75833d0ae554c5157fdeec8c885dd47fe002f Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 13 Nov 2024 06:32:28 -0800 Subject: [PATCH 19/28] Refactor to avoid generics --- llm/models.py | 272 ++++++++++++++++++++------------------------------ 1 file changed, 109 insertions(+), 163 deletions(-) diff --git a/llm/models.py b/llm/models.py index 6f5fcef9..0437d17b 100644 --- a/llm/models.py +++ b/llm/models.py @@ -11,7 +11,6 @@ Any, AsyncGenerator, Dict, - Generic, Iterable, Iterator, List, @@ -113,7 +112,7 @@ def __init__( attachments=None, system=None, prompt_json=None, - options=None + options=None, ): self.prompt = prompt self.model = model @@ -124,12 +123,25 @@ def __init__( @dataclass -class Conversation: - model: "Model" +class _BaseConversation: + model: "_BaseModel" id: str = field(default_factory=lambda: str(ULID()).lower()) name: Optional[str] = None - responses: List["Response"] = field(default_factory=list) + responses: List["_BaseResponse"] = field(default_factory=list) + + @classmethod + def from_row(cls, row): + from llm import get_model + + return cls( + model=get_model(row["model"]), + id=row["id"], + name=row["name"], + ) + +@dataclass +class Conversation(_BaseConversation): def prompt( self, prompt: Optional[str], @@ -137,7 +149,7 @@ def prompt( attachments: Optional[List[Attachment]] = None, system: Optional[str] = None, stream: bool = True, - **options + **options, ) -> "Response": return Response( Prompt( @@ -152,24 +164,9 @@ def prompt( conversation=self, ) - @classmethod - def from_row(cls, row): - from llm import get_model - - return cls( - model=get_model(row["model"]), - id=row["id"], - name=row["name"], - ) - @dataclass -class AsyncConversation: - model: "AsyncModel" - id: str = field(default_factory=lambda: str(ULID()).lower()) - name: Optional[str] = None - responses: List["AsyncResponse"] = field(default_factory=list) - +class AsyncConversation(_BaseConversation): def prompt( self, prompt: Optional[str], @@ -177,7 +174,7 @@ def prompt( attachments: Optional[List[Attachment]] = None, system: Optional[str] = None, stream: bool = True, - **options + **options, ) -> "AsyncResponse": return AsyncResponse( Prompt( @@ -192,24 +189,20 @@ def prompt( conversation=self, ) - @classmethod - def from_row(cls, row): - from llm import get_model - return cls( - model=get_model(row["model"]), - id=row["id"], - name=row["name"], - ) +class _BaseResponse: + """Base response class shared between sync and async responses""" + prompt: "Prompt" + stream: bool + conversation: Optional["_BaseConversation"] = None -class _BaseResponse(ABC, Generic[ModelT, ConversationT]): def __init__( self, prompt: Prompt, - model: ModelT, + model: "_BaseModel", stream: bool, - conversation: Optional[ConversationT] = None, + conversation: Optional[_BaseConversation] = None, ): self.prompt = prompt self._prompt_json = None @@ -261,49 +254,6 @@ def from_row(cls, db, row): ] return response - -class Response(_BaseResponse["Model", Optional["Conversation"]]): - def __str__(self) -> str: - return self.text() - - def _force(self): - if not self._done: - list(self) - - def text(self) -> str: - self._force() - return "".join(self._chunks) - - def json(self) -> Optional[Dict[str, Any]]: - self._force() - return self.response_json - - def duration_ms(self) -> int: - self._force() - return int(((self._end or 0) - (self._start or 0)) * 1000) - - def datetime_utc(self) -> str: - self._force() - return self._start_utcnow.isoformat() if self._start_utcnow else "" - - 
def __iter__(self) -> Iterator[str]: - self._start = time.monotonic() - self._start_utcnow = datetime.datetime.utcnow() - if self._done: - yield from self._chunks - for chunk in self.model.execute( - self.prompt, - stream=self.stream, - response=self, - conversation=self.conversation, - ): - yield chunk - self._chunks.append(chunk) - if self.conversation: - self.conversation.responses.append(self) - self._end = time.monotonic() - self._done = True - def log_to_db(self, db): conversation = self.conversation if not conversation: @@ -359,20 +309,65 @@ def log_to_db(self, db): ) -class AsyncResponse(_BaseResponse["AsyncModel", Optional["AsyncConversation"]]): - async def _force(self): +class Response(_BaseResponse): + model: "Model" + conversation: Optional["Conversation"] = None + + def __str__(self) -> str: + return self.text() + + def _force(self): if not self._done: - async for _ in self: - pass + list(self) + + def text(self) -> str: + self._force() + return "".join(self._chunks) + + def json(self) -> Optional[Dict[str, Any]]: + self._force() + return self.response_json + + def duration_ms(self) -> int: + self._force() + return int(((self._end or 0) - (self._start or 0)) * 1000) + + def datetime_utc(self) -> str: + self._force() + return self._start_utcnow.isoformat() if self._start_utcnow else "" + + def __iter__(self) -> Iterator[str]: + self._start = time.monotonic() + self._start_utcnow = datetime.datetime.utcnow() + if self._done: + yield from self._chunks + return + + for chunk in self.model.execute( + self.prompt, + stream=self.stream, + response=self, + conversation=self.conversation, + ): + yield chunk + self._chunks.append(chunk) + + if self.conversation: + self.conversation.responses.append(self) + self._end = time.monotonic() + self._done = True + + +class AsyncResponse(_BaseResponse): + model: "AsyncModel" + conversation: Optional["AsyncConversation"] = None def __aiter__(self): - # __aiter__ should return self directly, not be async + self._start = time.monotonic() + self._start_utcnow = datetime.datetime.utcnow() return self async def __anext__(self) -> str: - if self._start is None: - self._start = time.monotonic() - self._start_utcnow = datetime.datetime.utcnow() if self._done: if not self._chunks: raise StopAsyncIteration @@ -381,17 +376,14 @@ async def __anext__(self) -> str: raise StopAsyncIteration return chunk - # Get and store the generator if we don't have it yet if not hasattr(self, "_generator"): - generator = self.model.execute( + self._generator = self.model.execute( self.prompt, stream=self.stream, response=self, conversation=self.conversation, ) - self._generator = generator - # Use the generator try: chunk = await self._generator.__anext__() self._chunks.append(chunk) @@ -403,6 +395,11 @@ async def __anext__(self) -> str: self._done = True raise + async def _force(self): + if not self._done: + async for _ in self: + pass + async def text(self) -> str: await self._force() return "".join(self._chunks) @@ -426,7 +423,7 @@ def fake( prompt: str, *attachments: List[Attachment], system: str, - response: str + response: str, ): "Utility method to help with writing tests" response_obj = cls( @@ -443,43 +440,6 @@ def fake( response_obj._chunks = [response] return response_obj - @classmethod - def from_row(cls, db, row): - from llm import get_model - - model = get_model(row["model"]) - - response = cls( - model=model, - prompt=Prompt( - prompt=row["prompt"], - model=model, - attachments=[], - system=row["system"], - 
options=model.Options(**json.loads(row["options_json"])), - ), - stream=False, - ) - response.id = row["id"] - response._prompt_json = json.loads(row["prompt_json"] or "null") - response.response_json = json.loads(row["response_json"] or "null") - response._done = True - response._chunks = [row["response"]] - # Attachments - response.attachments = [ - Attachment.from_row(arow) - for arow in db.query( - """ - select attachments.* from attachments - join prompt_attachments on attachments.id = prompt_attachments.attachment_id - where prompt_attachments.response_id = ? - order by prompt_attachments."order" - """, - [row["id"]], - ) - ] - return response - def __repr__(self): text = "... not yet awaited ..." if self._done: @@ -525,15 +485,11 @@ def get_key(self): raise NeedsKeyException(message) -class _BaseModel(ABC, _get_key_mixin, Generic[ResponseT, ConversationT]): +class _BaseModel(ABC, _get_key_mixin): model_id: str - - # API key handling key: Optional[str] = None needs_key: Optional[str] = None key_env_var: Optional[str] = None - - # Model characteristics can_stream: bool = False attachment_types: Set = set() @@ -543,18 +499,14 @@ class Options(_Options): def _validate_attachments( self, attachments: Optional[List[Attachment]] = None ) -> None: - """Shared attachment validation logic""" if attachments and not self.attachment_types: - raise ValueError( - "This model does not support attachments, but some were provided" - ) + raise ValueError("This model does not support attachments") for attachment in attachments or []: attachment_type = attachment.resolve_type() if attachment_type not in self.attachment_types: raise ValueError( - "This model does not support attachments of type '{}', only {}".format( - attachment_type, ", ".join(self.attachment_types) - ) + f"This model does not support attachments of type '{attachment_type}', " + f"only {', '.join(self.attachment_types)}" ) def __str__(self) -> str: @@ -564,8 +516,8 @@ def __repr__(self): return "<{} '{}'>".format(self.__class__.__name__, self.model_id) -class Model(_BaseModel["Response", "Conversation"]): - def conversation(self) -> "Conversation": +class Model(_BaseModel): + def conversation(self) -> Conversation: return Conversation(model=self) @abstractmethod @@ -573,13 +525,9 @@ def execute( self, prompt: Prompt, stream: bool, - response: "Response", - conversation: Optional["Conversation"], + response: Response, + conversation: Optional[Conversation], ) -> Iterator[str]: - """ - Execute a prompt and yield chunks of text, or yield a single big chunk. - Any additional useful information about the execution should be assigned to the response. 
- """ pass def prompt( @@ -589,10 +537,10 @@ def prompt( attachments: Optional[List[Attachment]] = None, system: Optional[str] = None, stream: bool = True, - **options - ) -> "Response": + **options, + ) -> Response: self._validate_attachments(attachments) - return self.response( + return Response( Prompt( prompt, attachments=attachments, @@ -600,15 +548,16 @@ def prompt( model=self, options=self.Options(**options), ), - stream=stream, + self, + stream, ) def response(self, prompt: Prompt, stream: bool = True) -> "Response": return Response(prompt, self, stream) -class AsyncModel(_BaseModel["AsyncResponse", "AsyncConversation"]): - def conversation(self) -> "AsyncConversation": +class AsyncModel(_BaseModel): + def conversation(self) -> AsyncConversation: return AsyncConversation(model=self) @abstractmethod @@ -616,13 +565,9 @@ async def execute( self, prompt: Prompt, stream: bool, - response: "AsyncResponse", - conversation: Optional["AsyncConversation"], + response: AsyncResponse, + conversation: Optional[AsyncConversation], ) -> AsyncGenerator[str, None]: - """ - Returns an async generator that executes the prompt and yields chunks of text, - or yields a single big chunk. - """ yield "" def prompt( @@ -632,10 +577,10 @@ def prompt( attachments: Optional[List[Attachment]] = None, system: Optional[str] = None, stream: bool = True, - **options - ) -> "AsyncResponse": + **options, + ) -> AsyncResponse: self._validate_attachments(attachments) - return self.response( + return AsyncResponse( Prompt( prompt, attachments=attachments, @@ -643,7 +588,8 @@ def prompt( model=self, options=self.Options(**options), ), - stream=stream, + self, + stream, ) def response(self, prompt: Prompt, stream: bool = True) -> "AsyncResponse": From 9b1e72047027acd7aaf492a3cd203868eb9dee3c Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 13 Nov 2024 12:01:47 -0800 Subject: [PATCH 20/28] Removed obsolete response() method --- llm/models.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/llm/models.py b/llm/models.py index 0437d17b..3ed61bf3 100644 --- a/llm/models.py +++ b/llm/models.py @@ -552,9 +552,6 @@ def prompt( stream, ) - def response(self, prompt: Prompt, stream: bool = True) -> "Response": - return Response(prompt, self, stream) - class AsyncModel(_BaseModel): def conversation(self) -> AsyncConversation: @@ -592,9 +589,6 @@ def prompt( stream, ) - def response(self, prompt: Prompt, stream: bool = True) -> "AsyncResponse": - return AsyncResponse(prompt, self, stream) - class EmbeddingModel(ABC, _get_key_mixin): model_id: str From 1c83a4edaa168e610a074aa6005b113fb69c2f4d Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 13 Nov 2024 12:09:35 -0800 Subject: [PATCH 21/28] Support text = await async_mock_model.prompt("hello") --- llm/models.py | 3 +++ tests/test_async.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/llm/models.py b/llm/models.py index 3ed61bf3..8539d083 100644 --- a/llm/models.py +++ b/llm/models.py @@ -416,6 +416,9 @@ async def datetime_utc(self) -> str: await self._force() return self._start_utcnow.isoformat() if self._start_utcnow else "" + def __await__(self): + return self.text().__await__() + @classmethod def fake( cls, diff --git a/tests/test_async.py b/tests/test_async.py index c7d3f9d9..db4cf529 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -8,3 +8,7 @@ async def test_async_model(async_mock_model): async for chunk in async_mock_model.prompt("hello"): gathered.append(chunk) assert gathered == ["hello world"] + # Not as an iterator + 
async_mock_model.enqueue(["hello world"]) + text = await async_mock_model.prompt("hello") + assert text == "hello world" From ceb60d2b919661231464b36c56a2b8e3fb2b0f9c Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 13 Nov 2024 12:14:21 -0800 Subject: [PATCH 22/28] Initial docs for llm.get_async_model() and await model.prompt() Refs #507 --- docs/python-api.md | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/docs/python-api.md b/docs/python-api.md index ae135a68..55387d07 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -99,7 +99,7 @@ print(response.text()) ``` Some models do not use API keys at all. -## Streaming responses +### Streaming responses For models that support it you can stream responses as they are generated, like this: @@ -112,6 +112,31 @@ The `response.text()` method described earlier does this for you - it runs throu If a response has been evaluated, `response.text()` will continue to return the same string. +## Async models + +Some plugins provide async versions of their supported models, suitable for use with Python [asyncio](https://docs.python.org/3/library/asyncio.html). + +To use an async model, use the `llm.get_async_model()` function instead of `llm.get_model()`: + +```python +import llm +model = llm.get_async_model("gpt-4o") +``` +You can then run a prompt using `await model.prompt(...)`: + +```python +result = await model.prompt( + "Five surprising names for a pet pelican" +) +``` +Or use `async for chunk in ...` to stream the response as it is generated: +```python +async for chunk in model.prompt( + "Five surprising names for a pet pelican" +): + print(chunk, end="") +``` + ## Conversations LLM supports *conversations*, where you ask follow-up questions of a model as part of an ongoing conversation. From 5f66149be8c117bbc1c5f22e4d367f6933bc63ba Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 13 Nov 2024 12:25:47 -0800 Subject: [PATCH 23/28] Initial async model plugin creation docs --- docs/plugins/advanced-model-plugins.md | 48 ++++++++++++++++++++++++++ docs/python-api.md | 2 ++ 2 files changed, 50 insertions(+) diff --git a/docs/plugins/advanced-model-plugins.md b/docs/plugins/advanced-model-plugins.md index b9a16885..26e53eb7 100644 --- a/docs/plugins/advanced-model-plugins.md +++ b/docs/plugins/advanced-model-plugins.md @@ -5,6 +5,52 @@ The {ref}`model plugin tutorial ` covers the basics of de This document covers more advanced topics. +## Async models + +Plugins can optionally provide an asynchronous version of their model, suitable for use with Python [asyncio](https://docs.python.org/3/library/asyncio.html). This is particularly useful for remote models accessible by an HTTP API. + +The async version of a model subclasses `llm.AsyncModel` instead of `llm.Model`. It must implement an `async def execute()` async generator method instead of `def execute()`. 
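In skeleton form, the contract is simply an async generator that yields string chunks; a minimal sketch (not tied to any particular backend, and the class and model names here are illustrative only):

```python
from typing import AsyncGenerator

import llm


class MyAsyncEcho(llm.AsyncModel):
    model_id = "my-async-echo"

    async def execute(
        self, prompt, stream, response, conversation=None
    ) -> AsyncGenerator[str, None]:
        # Yield one or more chunks of text; a real plugin would await an HTTP API here
        yield f"echo: {prompt.prompt}"
```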
+ +This example shows a subset of the OpenAI default plugin illustrating how this method might work: + + +```python +from typing import AsyncGenerator +import llm + +class MyAsyncModel(llm.AsyncModel): + # This cn duplicate the model_id of the sync model: + model_id = "my-model-id" + + async def execute( + self, prompt, stream, response, conversation=None + ) -> AsyncGenerator[str, None]: + if stream: + completion = await client.chat.completions.create( + model=self.model_id, + messages=messages, + stream=True, + ) + async for chunk in completion: + yield chunk.choices[0].delta.content + else: + completion = await client.chat.completions.create( + model=self.model_name or self.model_id, + messages=messages, + stream=False, + ) + yield completion.choices[0].message.content +``` +This async model instance should then be passed to the `register()` method in the `register_models()` plugin hook: + +```python +@hookimpl +def register_models(register): + register( + MyModel(), MyAsyncModel(), aliases=("my-model-aliases",) + ) +``` + (advanced-model-plugins-attachments)= ## Attachments for multi-modal models @@ -12,6 +58,8 @@ Models such as GPT-4o, Claude 3.5 Sonnet and Google's Gemini 1.5 are multi-modal LLM calls these **attachments**. Models can specify the types of attachments they accept and then implement special code in the `.execute()` method to handle them. +See {ref}`the Python attachments documentation ` for details on using attachments in the Python API. + ### Specifying attachment types A `Model` subclass can list the types of attachments it accepts by defining a `attachment_types` class attribute: diff --git a/docs/python-api.md b/docs/python-api.md index 55387d07..8d437a0d 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -112,6 +112,8 @@ The `response.text()` method described earlier does this for you - it runs throu If a response has been evaluated, `response.text()` will continue to return the same string. +(python-api-async)= + ## Async models Some plugins provide async versions of their supported models, suitable for use with Python [asyncio](https://docs.python.org/3/library/asyncio.html). From 66847153cb9fb9c22bfc51ad5a46511dd0a9170a Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 13 Nov 2024 12:30:14 -0800 Subject: [PATCH 24/28] duration_ms ANY to pass test --- tests/test_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_chat.py b/tests/test_chat.py index f4e15861..285fa476 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -179,7 +179,7 @@ def test_chat_options(mock_model, logs_db): "response": "Some text", "response_json": None, "conversation_id": ANY, - "duration_ms": 0, + "duration_ms": ANY, "datetime_utc": ANY, } ] From 52799217260c9f4259185cf24cbc3e723b15dc99 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 13 Nov 2024 12:32:48 -0800 Subject: [PATCH 25/28] llm models --async option Refs https://github.com/simonw/llm/pull/613#issuecomment-2474724406 --- docs/help.md | 1 + llm/cli.py | 9 +++++++-- tests/test_llm.py | 8 ++++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/docs/help.md b/docs/help.md index 0f39a7a1..9db540a3 100644 --- a/docs/help.md +++ b/docs/help.md @@ -323,6 +323,7 @@ Usage: llm models list [OPTIONS] Options: --options Show options for each model, if available + --async List async models --help Show this message and exit. 
``` diff --git a/llm/cli.py b/llm/cli.py index 431a2490..5a9f20b4 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -1015,14 +1015,19 @@ def models(): @click.option( "--options", is_flag=True, help="Show options for each model, if available" ) -def models_list(options): +@click.option("async_", "--async", is_flag=True, help="List async models") +def models_list(options, async_): "List available models" models_that_have_shown_options = set() for model_with_aliases in get_models_with_aliases(): + if async_ and not model_with_aliases.async_model: + continue extra = "" if model_with_aliases.aliases: extra = " (aliases: {})".format(", ".join(model_with_aliases.aliases)) - model = model_with_aliases.model + model = ( + model_with_aliases.model if not async_ else model_with_aliases.async_model + ) output = str(model) + extra if options and model.Options.schema()["properties"]: output += "\n Options:" diff --git a/tests/test_llm.py b/tests/test_llm.py index a0058713..0e54cc91 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -555,6 +555,14 @@ def test_llm_models_options(user_path): result = runner.invoke(cli, ["models", "--options"], catch_exceptions=False) assert result.exit_code == 0 assert EXPECTED_OPTIONS.strip() in result.output + assert "AsyncMockModel: mock" not in result.output + + +def test_llm_models_async(user_path): + runner = CliRunner() + result = runner.invoke(cli, ["models", "--async"], catch_exceptions=False) + assert result.exit_code == 0 + assert "AsyncMockModel: mock" in result.output def test_llm_user_dir(tmpdir, monkeypatch): From 63220402f71bdbb0a6cee7d0736dd3262724f148 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 13 Nov 2024 13:31:47 -0800 Subject: [PATCH 26/28] Removed obsolete TypeVars --- llm/models.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/llm/models.py b/llm/models.py index 8539d083..82777edb 100644 --- a/llm/models.py +++ b/llm/models.py @@ -16,7 +16,6 @@ List, Optional, Set, - TypeVar, Union, ) from .utils import mimetype_from_path, mimetype_from_string @@ -25,13 +24,6 @@ from pydantic import BaseModel from ulid import ULID -ModelT = TypeVar("ModelT", bound=Union["Model", "AsyncModel"]) -ConversationT = TypeVar( - "ConversationT", bound=Optional[Union["Conversation", "AsyncConversation"]] -) -ResponseT = TypeVar("ResponseT") - - CONVERSATION_NAME_LENGTH = 32 From e677e2c9c5c7082e2d2f72c7795fd56e64774f49 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 13 Nov 2024 16:48:14 -0800 Subject: [PATCH 27/28] Expanded register_models() docs for async --- docs/plugins/advanced-model-plugins.md | 3 +++ docs/plugins/plugin-hooks.md | 17 ++++++++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/docs/plugins/advanced-model-plugins.md b/docs/plugins/advanced-model-plugins.md index 26e53eb7..1793c751 100644 --- a/docs/plugins/advanced-model-plugins.md +++ b/docs/plugins/advanced-model-plugins.md @@ -5,6 +5,8 @@ The {ref}`model plugin tutorial ` covers the basics of de This document covers more advanced topics. +(advanced-model-plugins-async)= + ## Async models Plugins can optionally provide an asynchronous version of their model, suitable for use with Python [asyncio](https://docs.python.org/3/library/asyncio.html). This is particularly useful for remote models accessible by an HTTP API. 
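The `--async` flag added to `llm models list` in PATCH 25 above assumes each entry returned by `get_models_with_aliases()` now carries an optional async counterpart next to the sync model. The definition of that container is not shown in these patches, so the following is only a rough sketch of the shape the CLI code appears to rely on; the class and field names are guesses:

```python
from dataclasses import dataclass, field
from typing import List, Optional

import llm


@dataclass
class ModelWithAliases:
    # Hypothetical container yielded by get_models_with_aliases():
    # the sync model, its optional async twin, and any registered aliases.
    model: llm.Model
    async_model: Optional[llm.AsyncModel] = None
    aliases: List[str] = field(default_factory=list)


def async_entries(entries: List[ModelWithAliases]) -> List[llm.AsyncModel]:
    # Mirrors the filtering in models_list(..., async_=True): entries with no
    # async model are skipped, otherwise the async variant is listed.
    return [entry.async_model for entry in entries if entry.async_model]
```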
@@ -52,6 +54,7 @@ def register_models(register): ``` (advanced-model-plugins-attachments)= + ## Attachments for multi-modal models Models such as GPT-4o, Claude 3.5 Sonnet and Google's Gemini 1.5 are multi-modal: they accept input in the form of images and maybe even audio, video and other formats. diff --git a/docs/plugins/plugin-hooks.md b/docs/plugins/plugin-hooks.md index 1d7d58f6..0f38cd64 100644 --- a/docs/plugins/plugin-hooks.md +++ b/docs/plugins/plugin-hooks.md @@ -42,5 +42,20 @@ class HelloWorld(llm.Model): def execute(self, prompt, stream, response): return ["hello world"] ``` +If your model includes an async version, you can register that too: + +```python +class AsyncHelloWorld(llm.AsyncModel): + model_id = "helloworld" + + async def execute(self, prompt, stream, response): + return ["hello world"] + +@llm.hookimpl +def register_models(register): + register(HelloWorld(), AsyncHelloWorld(), aliases=("hw",)) +``` +This demonstrates how to register a model with both sync and async versions, and how to specify an alias for that model. + +The {ref}`model plugin tutorial ` describes how to use this hook in detail. Asynchronous models {ref}`are described here `. -{ref}`tutorial-model-plugin` describes how to use this hook in detail. From cb2f1510b836f7bbdf30f6b94a2de8abab2293fe Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 13 Nov 2024 17:33:31 -0800 Subject: [PATCH 28/28] await model.prompt() now returns AsyncResponse Refs https://github.com/simonw/llm/pull/613#issuecomment-2475157822 --- docs/python-api.md | 5 +++-- llm/__init__.py | 2 ++ llm/models.py | 3 ++- tests/test_async.py | 5 ++++- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/docs/python-api.md b/docs/python-api.md index 8d437a0d..0450031a 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -127,16 +127,17 @@ model = llm.get_async_model("gpt-4o") You can then run a prompt using `await model.prompt(...)`: ```python -result = await model.prompt( +response = await model.prompt( "Five surprising names for a pet pelican" ) +print(await response.text()) ``` Or use `async for chunk in ...` to stream the response as it is generated: ```python async for chunk in model.prompt( "Five surprising names for a pet pelican" ): - print(chunk, end="") + print(chunk, end="", flush=True) ``` ## Conversations diff --git a/llm/__init__.py b/llm/__init__.py index 16ff9aed..d6df280f 100644 --- a/llm/__init__.py +++ b/llm/__init__.py @@ -5,6 +5,7 @@ ) from .models import ( AsyncModel, + AsyncResponse, Attachment, Conversation, Model, @@ -31,6 +32,7 @@ "get_model", "get_key", "user_dir", + "AsyncResponse", "Attachment", "Collection", "Conversation", diff --git a/llm/models.py b/llm/models.py index 82777edb..cb9c7ab3 100644 --- a/llm/models.py +++ b/llm/models.py @@ -391,6 +391,7 @@ async def _force(self): if not self._done: async for _ in self: pass + return self async def text(self) -> str: await self._force() @@ -409,7 +410,7 @@ async def datetime_utc(self) -> str: return self._start_utcnow.isoformat() if self._start_utcnow else "" def __await__(self): - return self.text().__await__() + return self._force().__await__() @classmethod def fake( diff --git a/tests/test_async.py b/tests/test_async.py index db4cf529..a84dd97d 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -1,3 +1,4 @@ +import llm import pytest @@ -10,5 +11,7 @@ async def test_async_model(async_mock_model): assert gathered == ["hello world"] # Not as an iterator async_mock_model.enqueue(["hello world"]) - text = await 
async_mock_model.prompt("hello") + response = await async_mock_model.prompt("hello") + text = await response.text() assert text == "hello world" + assert isinstance(response, llm.AsyncResponse)
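With this final patch `await model.prompt(...)` resolves to an `AsyncResponse` rather than a plain string, so the awaited form and the streaming form shown in `docs/python-api.md` both go through the same response object. A minimal end-to-end sketch of the documented usage, assuming an OpenAI API key is already configured and the `gpt-4o` model is available:

```python
import asyncio

import llm


async def main():
    model = llm.get_async_model("gpt-4o")

    # Awaiting the prompt gives back an AsyncResponse; its text() is also awaitable
    response = await model.prompt("Five surprising names for a pet pelican")
    print(await response.text())

    # Alternatively, consume the same prompt as an async stream of chunks
    async for chunk in model.prompt("Five surprising names for a pet pelican"):
        print(chunk, end="", flush=True)
    print()


asyncio.run(main())
```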