change default model for google stt and add aws llm test case #1552

Open · wants to merge 12 commits into main
5 changes: 5 additions & 0 deletions .changeset/brave-brooms-rest.md
@@ -0,0 +1,5 @@
+---
+"livekit-plugins-google": patch
+---
+
+google stt: change default model to `latest_long`
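
For context, a minimal usage sketch (assuming the public `google.STT` constructor exported by this plugin): omitting `model` now selects `latest_long`, and callers that depended on the old default can pin it explicitly.

```python
from livekit.plugins import google

stt = google.STT()                        # now defaults to model="latest_long"
stt_legacy = google.STT(model="chirp_2")  # the previous default, still selectable
```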
6 changes: 6 additions & 0 deletions .changeset/soft-tips-care.md
@@ -0,0 +1,6 @@
+---
+"livekit-plugins-anthropic": patch
+"livekit-plugins-aws": patch
+---
+
+don't pass functions in params when tool choice is set to none
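
A sketch of the behavior this fixes; the helper below is illustrative only, not the plugins' actual API. With `tool_choice="none"`, the request now carries no tool definitions at all, instead of sending tools alongside a choice the provider cannot honor.

```python
from typing import Any

def build_request_opts(tools: list[dict], tool_choice: str | None) -> dict[str, Any]:
    # Illustrative mirror of the new plugin behavior, not the plugins' API.
    opts: dict[str, Any] = {}
    if tool_choice == "none":
        opts["tools"] = []  # no tool definitions reach the provider at all
        return opts
    opts["tools"] = tools
    if tool_choice == "required":
        opts["tool_choice"] = {"type": "any"}   # Anthropic's "must call some tool"
    elif tool_choice == "auto":
        opts["tool_choice"] = {"type": "auto"}  # model decides
    return opts
```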
@@ -171,7 +171,7 @@ def chat(
 
             opts["tools"] = fncs_desc
             if tool_choice is not None:
-                anthropic_tool_choice: dict[str, Any] = {"type": "auto"}
+                anthropic_tool_choice: dict[str, Any] | None = {"type": "auto"}
                 if isinstance(tool_choice, ToolChoice):
                     if tool_choice.type == "function":
                         anthropic_tool_choice = {
@@ -181,9 +181,13 @@ def chat(
                 elif isinstance(tool_choice, str):
                     if tool_choice == "required":
                         anthropic_tool_choice = {"type": "any"}
-                if parallel_tool_calls is not None and parallel_tool_calls is False:
-                    anthropic_tool_choice["disable_parallel_tool_use"] = True
-                opts["tool_choice"] = anthropic_tool_choice
+                    elif tool_choice == "none":
+                        opts["tools"] = []
+                        anthropic_tool_choice = None
+                if anthropic_tool_choice is not None:
+                    if parallel_tool_calls is False:
+                        anthropic_tool_choice["disable_parallel_tool_use"] = True
+                    opts["tool_choice"] = anthropic_tool_choice
 
         latest_system_message: anthropic.types.TextBlockParam = _latest_system_message(
             chat_ctx, caching=self._opts.caching
@@ -134,14 +134,13 @@ def _build_image(image: llm.ChatImage, cache_key: Any) -> dict:
                 height=image.inference_height,
                 strategy="scale_aspect_fit",
             )
-            encoded_data = utils.images.encode(image.image, opts)
-            image._cache[cache_key] = base64.b64encode(encoded_data).decode("utf-8")
+            image._cache[cache_key] = utils.images.encode(image.image, opts)
 
         return {
             "image": {
                 "format": "jpeg",
                 "source": {
-                    "bytes": image._cache[cache_key].encode("utf-8"),
+                    "bytes": image._cache[cache_key],
                },
            }
        }
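Worth spelling out why this is a fix rather than a cleanup: the old path base64-encoded the JPEG and then re-encoded that text as UTF-8, while Bedrock's Converse API expects raw image bytes in `source.bytes`. A small sketch of the difference; `jpeg_bytes` is a hypothetical stand-in for the encoder output:

```python
import base64

jpeg_bytes = b"\xff\xd8\xff\xe0"  # hypothetical JPEG header bytes

# before: raw bytes -> base64 text -> UTF-8 bytes of that text (wrong payload)
old_payload = base64.b64encode(jpeg_bytes).decode("utf-8").encode("utf-8")

# after: the encoder output is cached and sent as-is
new_payload = jpeg_bytes

assert old_payload != new_payload
```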
61 changes: 43 additions & 18 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/llm.py
@@ -184,27 +184,28 @@ async def _run(self) -> None:
         try:
             opts: dict[str, Any] = {}
             messages, system_instruction = _build_aws_ctx(self._chat_ctx, id(self))
-            if messages[0]["role"] != "user":
-                messages.insert(
-                    0,
-                    {"role": "user", "content": [{"text": "(empty)"}]},
-                )
+            messages = _merge_messages(messages)
 
+            def _get_tool_config() -> dict[str, Any] | None:
+                if not (self._fnc_ctx and self._fnc_ctx.ai_functions):
+                    return None
 
-            if self._fnc_ctx and self._fnc_ctx.ai_functions:
                 tools = _build_tools(self._fnc_ctx)
-                tool_config: dict[str, Any] = {"tools": tools}
+                config: dict[str, Any] = {"tools": tools}
 
                 if isinstance(self._tool_choice, ToolChoice):
-                    tool_config["toolChoice"] = {
-                        "tool": {"name": self._tool_choice.name}
-                    }
+                    config["toolChoice"] = {"tool": {"name": self._tool_choice.name}}
                 elif self._tool_choice == "required":
-                    tool_config["toolChoice"] = {"any": {}}
+                    config["toolChoice"] = {"any": {}}
                 elif self._tool_choice == "auto":
-                    tool_config["toolChoice"] = {"auto": {}}
+                    config["toolChoice"] = {"auto": {}}
                 else:
-                    raise ValueError("aws bedrock llm: invalid tool choice")
+                    return None
+
+                return config
+
+            tool_config = _get_tool_config()
+            if tool_config:
                 opts["toolConfig"] = tool_config
 
             if self._additional_request_fields:
@@ -224,7 +225,7 @@ async def _run(self) -> None:
                 messages=messages,
                 system=[system_instruction],
                 inferenceConfig=inference_config,
-                **opts,
+                **_strip_nones(opts),
             ) # type: ignore
 
             request_id = response["ResponseMetadata"]["RequestId"]
@@ -281,16 +282,16 @@ def _parse_chunk(self, request_id: str, chunk: dict) -> llm.ChatChunk | None:
         return None
 
     def _try_build_function(self, request_id: str, chunk: dict) -> llm.ChatChunk | None:
-        if not self._tool_call_id:
+        if self._tool_call_id is None:
             logger.warning("aws bedrock llm: no tool call id in the response")
             return None
-        if not self._fnc_name:
+        if self._fnc_name is None:
             logger.warning("aws bedrock llm: no function name in the response")
             return None
-        if not self._fnc_raw_arguments:
+        if self._fnc_raw_arguments is None:
             logger.warning("aws bedrock llm: no function arguments in the response")
             return None
-        if not self._fnc_ctx:
+        if self._fnc_ctx is None:
             logger.warning(
                 "aws bedrock llm: stream tried to run function without function context"
             )
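
The switch from truthiness checks to `is None` is subtle but load-bearing: a tool invoked with zero arguments can legitimately stream an empty string for `_fnc_raw_arguments`, and the old checks would log a warning and drop that valid call. A minimal illustration:

```python
fnc_raw_arguments = ""  # a zero-argument tool call can stream empty args

assert not fnc_raw_arguments          # truthiness: looks "missing", would warn and bail
assert fnc_raw_arguments is not None  # identity: correctly "present but empty"
```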
@@ -320,5 +321,29 @@ def _try_build_function(self, request_id: str, chunk: dict) -> llm.ChatChunk | None:
             )
 
 
+def _merge_messages(
+    messages: list[dict],
+) -> list[dict]:
+    # Anthropic enforces alternating messages
+    combined_messages: list[dict] = []
+    for m in messages:
+        if len(combined_messages) == 0 or m["role"] != combined_messages[-1]["role"]:
+            combined_messages.append(m)
+            continue
+        last_message = combined_messages[-1]
+        if not isinstance(last_message["content"], list) or not isinstance(
+            m["content"], list
+        ):
+            logger.error("message content is not a list")
+            continue
+
+        last_message["content"].extend(m["content"])
+
+    if len(combined_messages) == 0 or combined_messages[0]["role"] != "user":
+        combined_messages.insert(0, {"role": "user", "content": [{"text": "(empty)"}]})
+
+    return combined_messages
+
+
+def _strip_nones(d: dict[str, Any]) -> dict[str, Any]:
+    return {k: v for k, v in d.items() if v is not None}
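
A quick behavioral sketch of the two new helpers, assuming they are in scope as defined above and that messages use the Converse-style `{"role", "content"}` dicts from this file:

```python
msgs = [
    {"role": "assistant", "content": [{"text": "hi"}]},
    {"role": "assistant", "content": [{"text": "there"}]},
]
merged = _merge_messages(msgs)
# consecutive same-role messages are folded into one, and a leading
# "(empty)" user turn is prepended so the conversation starts with "user":
# [{"role": "user", "content": [{"text": "(empty)"}]},
#  {"role": "assistant", "content": [{"text": "hi"}, {"text": "there"}]}]

assert _strip_nones({"toolConfig": None, "maxTokens": 100}) == {"maxTokens": 100}
```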
@@ -193,8 +193,7 @@ def _build_gemini_image_part(image: llm.ChatImage, cache_key: Any) -> types.Part:
                 height=image.inference_height,
                 strategy="scale_aspect_fit",
             )
-            encoded_data = utils.images.encode(image.image, opts)
-            image._cache[cache_key] = base64.b64encode(encoded_data).decode("utf-8")
+            image._cache[cache_key] = utils.images.encode(image.image, opts)
 
         return types.Part.from_bytes(
             data=image._cache[cache_key], mime_type="image/jpeg"
@@ -10,6 +10,8 @@
     "medical_conversation",
     "chirp",
     "chirp_2",
+    "latest_long",
+    "latest_short",
 ]
 
 SpeechLanguages = Literal[
@@ -61,7 +61,7 @@ class STTOptions:
     interim_results: bool
     punctuate: bool
     spoken_punctuation: bool
-    model: SpeechModels
+    model: SpeechModels | str
     sample_rate: int
     keywords: List[tuple[str, float]] | None
@@ -93,7 +93,7 @@ def __init__(
         interim_results: bool = True,
         punctuate: bool = True,
         spoken_punctuation: bool = False,
-        model: SpeechModels = "chirp_2",
+        model: SpeechModels | str = "latest_long",
         location: str = "us-central1",
         sample_rate: int = 16000,
         credentials_info: dict | None = None,
@@ -106,6 +106,19 @@
         Credentials must be provided, either by using the ``credentials_info`` dict, or reading
         from the file specified in ``credentials_file`` or via Application Default Credentials as
         described in https://cloud.google.com/docs/authentication/application-default-credentials
+
+        args:
+            languages(LanguageCode): list of language codes to recognize (default: "en-US")
+            detect_language(bool): whether to detect the language of the audio (default: True)
+            interim_results(bool): whether to return interim results (default: True)
+            punctuate(bool): whether to punctuate the audio (default: True)
+            spoken_punctuation(bool): whether to use spoken punctuation (default: False)
+            model(SpeechModels | str): the model to use for recognition (default: "latest_long")
+            location(str): the location to use for recognition (default: "us-central1")
+            sample_rate(int): the sample rate of the audio (default: 16000)
+            credentials_info(dict): the credentials info to use for recognition (default: None)
+            credentials_file(str): the credentials file to use for recognition (default: None)
+            keywords(List[tuple[str, float]]): list of keywords to recognize (default: None)
         """
         super().__init__(
             capabilities=stt.STTCapabilities(streaming=True, interim_results=True)
8 changes: 3 additions & 5 deletions tests/test_llm.py
@@ -9,7 +9,7 @@
 import pytest
 from livekit.agents import APIConnectionError, llm
 from livekit.agents.llm import ChatContext, FunctionContext, TypeInfo, ai_callable
-from livekit.plugins import anthropic, google, openai
+from livekit.plugins import anthropic, aws, google, openai
 from livekit.rtc import VideoBufferType, VideoFrame
 

Expand Down Expand Up @@ -101,7 +101,7 @@ def test_hashable_typeinfo():
pytest.param(lambda: anthropic.LLM(), id="anthropic"),
pytest.param(lambda: google.LLM(), id="google"),
pytest.param(lambda: google.LLM(vertexai=True), id="google-vertexai"),
# .param(lambda: aws.LLM(), id="aws"),
pytest.param(lambda: aws.LLM(), id="aws"),
]


@@ -348,9 +348,7 @@ async def test_tool_choice_options(
     print(calls)
 
     call_names = {call.call_info.function_info.name for call in calls}
-    if tool_choice == "none" and isinstance(input_llm, anthropic.LLM):
-        assert True
-    else:
+    if tool_choice == "none":
         assert call_names == expected_calls, (
             f"Test '{description}' failed: Expected calls {expected_calls}, but got {call_names}"
         )