llm.get_async_model(), llm.AsyncModel base class and OpenAI async models #613

Merged · Nov 14, 2024 · 30 commits

Changes shown below are from 20 of the 30 commits.

Commits
e26e7f7 First WIP prototype of async mode, refs #507 (simonw, Nov 6, 2024)
1d8c3f8 Fix for llm hi --async --no-stream, refs #507 (simonw, Nov 6, 2024)
b27b275 Fix for coroutine in __repr__ (simonw, Nov 6, 2024)
44e6be1 register_model is now async aware (simonw, Nov 6, 2024)
2b6f5cc Refactor Chat and AsyncChat to use _Shared base class (simonw, Nov 6, 2024)
7f6bea4 Merge branch 'main' into asyncio (simonw, Nov 6, 2024)
d9ed54f fixed function name (simonw, Nov 6, 2024)
55830df Fix for infinite loop (simonw, Nov 7, 2024)
5466a18 Applied Black (simonw, Nov 7, 2024)
3309528 Ran cog (github-actions[bot], Nov 7, 2024)
d310df5 Applied Black (simonw, Nov 7, 2024)
61dfc1d Add Response.from_row() classmethod back again (simonw, Nov 7, 2024)
b3a6ec7 Made mypy happy with llm/models.py (simonw, Nov 7, 2024)
91732d0 mypy fixes for openai_models.py (simonw, Nov 7, 2024)
2e1045d First test for AsyncModel (simonw, Nov 7, 2024)
f311dbf Still have not quite got this working (simonw, Nov 8, 2024)
4f3e82a Fix for not loading plugins during tests, refs #626 (simonw, Nov 13, 2024)
145b5cd audio/wav not audio/wave, refs #603 (simonw, Nov 13, 2024)
8ab5ea3 Black and mypy and ruff all happy (simonw, Nov 13, 2024)
9e82131 Merge branch 'main' into asyncio (simonw, Nov 13, 2024)
c4a7583 Refactor to avoid generics (simonw, Nov 13, 2024)
9b1e720 Removed obsolete response() method (simonw, Nov 13, 2024)
1c83a4e Support text = await async_mock_model.prompt("hello") (simonw, Nov 13, 2024)
ceb60d2 Initial docs for llm.get_async_model() and await model.prompt() (simonw, Nov 13, 2024)
5f66149 Initial async model plugin creation docs (simonw, Nov 13, 2024)
6684715 duration_ms ANY to pass test (simonw, Nov 13, 2024)
5279921 llm models --async option (simonw, Nov 13, 2024)
6322040 Removed obsolete TypeVars (simonw, Nov 13, 2024)
e677e2c Expanded register_models() docs for async (simonw, Nov 14, 2024)
cb2f151 await model.prompt() now returns AsyncResponse (simonw, Nov 14, 2024)
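
Taken together, these commits add an async counterpart to the existing sync API. Based on the commit messages and the diff below, a minimal sketch of the new usage looks roughly like this (the model ID is illustrative, and details may differ from the merged version):

```python
import asyncio

import llm


async def main():
    # Resolve an async model by ID or alias; with no argument this
    # falls back to the configured default model
    model = llm.get_async_model("gpt-4o-mini")  # illustrative model ID

    # Non-streaming: prompt() returns a response whose text is awaited
    response = model.prompt("Describe a pelican")
    print(await response.text())

    # Streaming: the same call can be iterated with async for
    async for chunk in model.prompt("Describe a walrus"):
        print(chunk, end="")
    print("")


asyncio.run(main())
```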
docs/help.md: 1 change (1 addition, 0 deletions)

````diff
@@ -121,6 +121,7 @@ Options:
   --cid, --conversation TEXT  Continue the conversation with the given ID.
   --key TEXT                  API key to use
   --save TEXT                 Save prompt with this template name
+  --async                     Run prompt asynchronously
   --help                      Show this message and exit.
 ```
````
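
Per the help text above, the flag is exposed on the default prompt command, so an invocation would look roughly like this (prompt text illustrative):

```sh
llm 'Three pun names for a pelican cafe' --async
```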
llm/__init__.py: 58 changes (51 additions, 7 deletions)

```diff
@@ -4,6 +4,7 @@
     NeedsKeyException,
 )
 from .models import (
+    AsyncModel,
     Attachment,
     Conversation,
     Model,
@@ -26,6 +27,7 @@
 
 __all__ = [
     "hookimpl",
+    "get_async_model",
     "get_model",
     "get_key",
     "user_dir",
@@ -74,11 +76,11 @@ def get_models_with_aliases() -> List["ModelWithAliases"]:
     for alias, model_id in configured_aliases.items():
         extra_model_aliases.setdefault(model_id, []).append(alias)
 
-    def register(model, aliases=None):
+    def register(model, async_model=None, aliases=None):
         alias_list = list(aliases or [])
         if model.model_id in extra_model_aliases:
             alias_list.extend(extra_model_aliases[model.model_id])
-        model_aliases.append(ModelWithAliases(model, alias_list))
+        model_aliases.append(ModelWithAliases(model, async_model, alias_list))
 
     load_plugins()
     pm.hook.register_models(register=register)
@@ -137,26 +139,68 @@ def get_embedding_model_aliases() -> Dict[str, EmbeddingModel]:
     return model_aliases
 
 
+def get_async_model_aliases() -> Dict[str, AsyncModel]:
+    async_model_aliases = {}
+    for model_with_aliases in get_models_with_aliases():
+        if model_with_aliases.async_model:
+            for alias in model_with_aliases.aliases:
+                async_model_aliases[alias] = model_with_aliases.async_model
+            async_model_aliases[model_with_aliases.model.model_id] = (
+                model_with_aliases.async_model
+            )
+    return async_model_aliases
+
+
 def get_model_aliases() -> Dict[str, Model]:
     model_aliases = {}
     for model_with_aliases in get_models_with_aliases():
-        for alias in model_with_aliases.aliases:
-            model_aliases[alias] = model_with_aliases.model
-        model_aliases[model_with_aliases.model.model_id] = model_with_aliases.model
+        if model_with_aliases.model:
+            for alias in model_with_aliases.aliases:
+                model_aliases[alias] = model_with_aliases.model
+            model_aliases[model_with_aliases.model.model_id] = model_with_aliases.model
     return model_aliases
 
 
 class UnknownModelError(KeyError):
     pass
 
 
-def get_model(name: Optional[str] = None) -> Model:
+def get_async_model(name: Optional[str] = None) -> AsyncModel:
+    aliases = get_async_model_aliases()
+    name = name or get_default_model()
+    try:
+        return aliases[name]
+    except KeyError:
+        # Does a sync model exist?
+        sync_model = None
+        try:
+            sync_model = get_model(name, _skip_async=True)
+        except UnknownModelError:
+            pass
+        if sync_model:
+            raise UnknownModelError("Unknown async model (sync model exists): " + name)
+        else:
+            raise UnknownModelError("Unknown model: " + name)
+
+
+def get_model(name: Optional[str] = None, _skip_async: bool = False) -> Model:
     aliases = get_model_aliases()
     name = name or get_default_model()
     try:
         return aliases[name]
     except KeyError:
-        raise UnknownModelError("Unknown model: " + name)
+        # Does an async model exist?
+        if _skip_async:
+            raise UnknownModelError("Unknown model: " + name)
+        async_model = None
+        try:
+            async_model = get_async_model(name)
+        except UnknownModelError:
+            pass
+        if async_model:
+            raise UnknownModelError("Unknown model (async model exists): " + name)
+        else:
+            raise UnknownModelError("Unknown model: " + name)
 
 
 def get_key(
```
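
One design detail in the new lookup functions: get_async_model() and get_model() each probe the other registry on a miss, so the UnknownModelError message can hint that the model exists in the other mode, and the _skip_async flag stops the two from recursing into each other. A small sketch of the resulting behavior, using a hypothetical model ID:

```python
import llm

try:
    # Hypothetical ID for a model registered without an async implementation
    llm.get_async_model("sync-only-model")
except llm.UnknownModelError as ex:
    # The error message carries the hint added in this PR:
    # "Unknown async model (sync model exists): sync-only-model"
    print(ex)
```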
llm/cli.py: 64 changes (48 additions, 16 deletions)

```diff
@@ -1,3 +1,4 @@
+import asyncio
 import click
 from click_default_group import DefaultGroup
 from dataclasses import asdict
@@ -11,6 +12,7 @@
     Template,
     UnknownModelError,
     encode,
+    get_async_model,
     get_default_model,
     get_default_embedding_model,
     get_embedding_models_with_aliases,
@@ -29,7 +31,7 @@
 )
 
 from .migrations import migrate
-from .plugins import pm
+from .plugins import pm, load_plugins
 from .utils import mimetype_from_path, mimetype_from_string
 import base64
 import httpx
@@ -199,6 +201,7 @@ def cli():
 )
 @click.option("--key", help="API key to use")
 @click.option("--save", help="Save prompt with this template name")
+@click.option("async_", "--async", is_flag=True, help="Run prompt asynchronously")
 def prompt(
     prompt,
     system,
@@ -215,6 +218,7 @@
     conversation_id,
     key,
     save,
+    async_,
 ):
     """
     Execute a prompt
@@ -337,9 +341,12 @@ def read_prompt():
 
     # Now resolve the model
     try:
-        model = model_aliases[model_id]
-    except KeyError:
-        raise click.ClickException("'{}' is not a known model".format(model_id))
+        if async_:
+            model = get_async_model(model_id)
+        else:
+            model = get_model(model_id)
+    except UnknownModelError as ex:
+        raise click.ClickException(ex)
 
     # Provide the API key, if one is needed and has been provided
     if model.needs_key:
@@ -375,21 +382,48 @@ def read_prompt():
         prompt_method = conversation.prompt
 
     try:
-        response = prompt_method(
-            prompt, attachments=resolved_attachments, system=system, **validated_options
-        )
-        if should_stream:
-            for chunk in response:
-                print(chunk, end="")
-                sys.stdout.flush()
-            print("")
+        if async_:
+
+            async def inner():
+                if should_stream:
+                    async for chunk in prompt_method(
+                        prompt,
+                        attachments=resolved_attachments,
+                        system=system,
+                        **validated_options,
+                    ):
+                        print(chunk, end="")
+                        sys.stdout.flush()
+                    print("")
+                else:
+                    response = prompt_method(
+                        prompt,
+                        attachments=resolved_attachments,
+                        system=system,
+                        **validated_options,
+                    )
+                    print(await response.text())
+
+            asyncio.run(inner())
         else:
-            print(response.text())
+            response = prompt_method(
+                prompt,
+                attachments=resolved_attachments,
+                system=system,
+                **validated_options,
+            )
+            if should_stream:
+                for chunk in response:
+                    print(chunk, end="")
+                    sys.stdout.flush()
+                print("")
+            else:
+                print(response.text())
     except Exception as ex:
         raise click.ClickException(str(ex))
 
     # Log to the database
-    if (logs_on() or log) and not no_log:
+    if (logs_on() or log) and not no_log and not async_:
         log_path = logs_db_path()
         (log_path.parent).mkdir(parents=True, exist_ok=True)
         db = sqlite_utils.Database(log_path)
@@ -1810,8 +1844,6 @@ def render_errors(errors):
     return "\n".join(output)
 
 
-from .plugins import load_plugins
-
 load_plugins()
 
 pm.hook.register_commands(cli=cli)
```
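
The CLI change wraps the async path in a local coroutine and drives it with a single asyncio.run() call, streaming with async for or awaiting response.text() otherwise; note that database logging is skipped for async runs at this point in the diff. A standalone sketch of the same pattern (model ID and prompt are illustrative):

```python
import asyncio

import llm


def run_prompt(prompt_text: str, should_stream: bool = True) -> None:
    model = llm.get_async_model("gpt-4o-mini")  # illustrative model ID

    async def inner():
        if should_stream:
            # Stream chunks to stdout as they arrive
            async for chunk in model.prompt(prompt_text):
                print(chunk, end="", flush=True)
            print("")
        else:
            # Await the full response text in one go
            response = model.prompt(prompt_text)
            print(await response.text())

    # One event loop per invocation, mirroring the diff's asyncio.run(inner())
    asyncio.run(inner())


run_prompt("A haiku about otters")
```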