diff --git a/.changeset/chatty-grapes-scream.md b/.changeset/chatty-grapes-scream.md deleted file mode 100644 index ffb96e558..000000000 --- a/.changeset/chatty-grapes-scream.md +++ /dev/null @@ -1,5 +0,0 @@ ---- -"livekit-plugins-google": minor ---- - -Add support for google STT chirp_2 model. diff --git a/.changeset/famous-meals-sell.md b/.changeset/famous-meals-sell.md deleted file mode 100644 index 6569784e2..000000000 --- a/.changeset/famous-meals-sell.md +++ /dev/null @@ -1,5 +0,0 @@ ---- -"livekit-plugins-openai": patch ---- - -project id fix for google diff --git a/.changeset/moody-snails-serve.md b/.changeset/moody-snails-serve.md deleted file mode 100644 index 61813e027..000000000 --- a/.changeset/moody-snails-serve.md +++ /dev/null @@ -1,5 +0,0 @@ ---- -"livekit-plugins-google": patch ---- - -fix: add retry logic for google stt abort exception diff --git a/.changeset/tall-garlics-carry.md b/.changeset/tall-garlics-carry.md deleted file mode 100644 index bd2829f40..000000000 --- a/.changeset/tall-garlics-carry.md +++ /dev/null @@ -1,11 +0,0 @@ ---- -"livekit-agents": patch -"livekit-plugins-azure": patch -"livekit-plugins-cartesia": patch -"livekit-plugins-elevenlabs": patch -"livekit-plugins-google": patch -"livekit-plugins-openai": patch -"livekit-plugins-playht": patch ---- - -feat: tts retry & tts.FallbackAdapter diff --git a/.changeset/tough-boats-appear.md b/.changeset/tough-boats-appear.md deleted file mode 100644 index 80b77cd33..000000000 --- a/.changeset/tough-boats-appear.md +++ /dev/null @@ -1,6 +0,0 @@ ---- -"livekit-plugins-openai": patch -"livekit-agents": patch ---- - -Expose multimodal agent metrics diff --git a/.changeset/warm-zoos-lie.md b/.changeset/warm-zoos-lie.md deleted file mode 100644 index 06f55c2b7..000000000 --- a/.changeset/warm-zoos-lie.md +++ /dev/null @@ -1,5 +0,0 @@ ---- -"livekit-plugins-google": patch ---- - -google STT - use the baseclass resampler diff --git a/.changeset/witty-months-train.md b/.changeset/witty-months-train.md deleted file mode 100644 index d309b9962..000000000 --- a/.changeset/witty-months-train.md +++ /dev/null @@ -1,6 +0,0 @@ ---- -"livekit-plugins-openai": patch -"livekit-agents": patch ---- - -vertex ai support with openai library diff --git a/.github/workflows/build-package.yml b/.github/workflows/build-package.yml index a305c9f4a..f0f721f72 100644 --- a/.github/workflows/build-package.yml +++ b/.github/workflows/build-package.yml @@ -23,22 +23,7 @@ on: jobs: build_plugins: runs-on: ubuntu-latest - if: | - inputs.package == 'livekit-agents' || - inputs.package == 'livekit-plugins-assemblyai' || - inputs.package == 'livekit-plugins-azure' || - inputs.package == 'livekit-plugins-cartesia' || - inputs.package == 'livekit-plugins-deepgram' || - inputs.package == 'livekit-plugins-elevenlabs' || - inputs.package == 'livekit-plugins-google' || - inputs.package == 'livekit-plugins-minimal' || - inputs.package == 'livekit-plugins-nltk' || - inputs.package == 'livekit-plugins-openai' || - inputs.package == 'livekit-plugins-rag' || - inputs.package == 'livekit-plugins-silero' || - inputs.package == 'livekit-plugins-anthropic' || - inputs.package == 'livekit-plugins-llama-index' - + if: inputs.package != 'livekit-plugins-browser' defaults: run: working-directory: "${{ startsWith(inputs.package, 'livekit-plugin') && 'livekit-plugins/' || '' }}${{ inputs.package }}" @@ -62,7 +47,7 @@ jobs: run: python -m build - name: Upload distribution package - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: ${{ 
inputs.artifact_name }} path: "${{ startsWith(inputs.package, 'livekit-plugin') && 'livekit-plugins/' || '' }}${{ inputs.package }}/dist/" @@ -97,7 +82,7 @@ jobs: CIBW_BUILD_VERBOSITY: 3 - name: Upload distribution package - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: ${{ inputs.artifact_name }} path: livekit-plugins/livekit-plugins-browser/dist/ diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bda2441ef..9eb72c55c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -50,19 +50,8 @@ jobs: - name: Install all packages run: | - pip install ./livekit-agents \ - ./livekit-plugins/livekit-plugins-openai \ - ./livekit-plugins/livekit-plugins-deepgram \ - ./livekit-plugins/livekit-plugins-google \ - ./livekit-plugins/livekit-plugins-nltk \ - ./livekit-plugins/livekit-plugins-silero \ - ./livekit-plugins/livekit-plugins-elevenlabs \ - ./livekit-plugins/livekit-plugins-cartesia \ - ./livekit-plugins/livekit-plugins-rag \ - ./livekit-plugins/livekit-plugins-azure \ - ./livekit-plugins/livekit-plugins-anthropic \ - ./livekit-plugins/livekit-plugins-llama-index \ - ./livekit-plugins/livekit-plugins-fal + pip install ./livekit-agents + ./livekit-plugins/install_local.sh - name: Install stub packages run: | @@ -90,5 +79,6 @@ jobs: -p livekit.plugins.rag \ -p livekit.plugins.azure \ -p livekit.plugins.anthropic \ - -p livekit.plugins.fal - + -p livekit.plugins.fal \ + -p livekit.plugins.playai \ + -p livekit.plugins.assemblyai diff --git a/.github/workflows/publish-docs.yml b/.github/workflows/publish-docs.yml index d6ced9dc8..934f54c90 100644 --- a/.github/workflows/publish-docs.yml +++ b/.github/workflows/publish-docs.yml @@ -19,7 +19,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: "3.12" # Specify the Python version you want to use + python-version: "3.12" - name: Create and activate virtual environment run: | @@ -34,18 +34,8 @@ jobs: - name: Install package run: | source venv/bin/activate - python -m pip install ./livekit-agents \ - ./livekit-plugins/livekit-plugins-anthropic \ - ./livekit-plugins/livekit-plugins-azure \ - ./livekit-plugins/livekit-plugins-cartesia \ - ./livekit-plugins/livekit-plugins-deepgram \ - ./livekit-plugins/livekit-plugins-elevenlabs \ - ./livekit-plugins/livekit-plugins-google \ - ./livekit-plugins/livekit-plugins-nltk \ - ./livekit-plugins/livekit-plugins-openai \ - ./livekit-plugins/livekit-plugins-rag \ - ./livekit-plugins/livekit-plugins-silero \ - ./livekit-plugins/livekit-plugins-llama-index + pip install ./livekit-agents + ./livekit-plugins/install_local.sh - name: Build Docs run: | @@ -60,3 +50,11 @@ jobs: AWS_ACCESS_KEY_ID: ${{ secrets.DOCS_DEPLOY_AWS_ACCESS_KEY }} AWS_SECRET_ACCESS_KEY: ${{ secrets.DOCS_DEPLOY_AWS_API_SECRET }} AWS_DEFAULT_REGION: "us-east-1" + + - name: Expire cloudfront cache + run: | + aws cloudfront create-invalidation --distribution-id EJJ40KLJ3TRY9 --paths "/python/*" + env: + AWS_ACCESS_KEY_ID: ${{ secrets.DOCS_DEPLOY_AWS_ACCESS_KEY }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.DOCS_DEPLOY_AWS_API_SECRET }} + AWS_DEFAULT_REGION: "us-east-1" diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index bd5b5b59f..61692429e 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -26,8 +26,7 @@ jobs: with: submodules: true lfs: true - env: - GITHUB_TOKEN: ${{ secrets.CHANGESETS_PUSH_PAT }} + ssh-key: ${{ secrets.CHANGESETS_PUSH_DEPLOY_KEY }} - uses: 
pnpm/action-setup@v4 - name: Use Node.js 20 @@ -84,7 +83,7 @@ jobs: uses: livekit/agents/.github/workflows/build-package.yml@main with: package: ${{ matrix.package.name }} - artifact_name: python-package-distributions + artifact_name: python-package-dist-${{matrix.package.name}} publish: needs: @@ -96,10 +95,11 @@ jobs: steps: - name: Download all the dists - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: - name: python-package-distributions - path: dist/ + path: dist + pattern: python-package-dist-* + merge-multiple: true - name: Publish package uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 @@ -108,9 +108,9 @@ jobs: password: ${{ secrets.PYPI_API_TOKEN }} trigger-docs-publish: - name: Publish Docs + name: Publish Docs needs: publish uses: ./.github/workflows/publish-docs.yml secrets: - DOCS_DEPLOY_AWS_ACCESS_KEY: ${{ secrets.DOCS_DEPLOY_AWS_ACCESS_KEY }} - DOCS_DEPLOY_AWS_API_SECRET: ${{ secrets.DOCS_DEPLOY_AWS_API_SECRET }} + DOCS_DEPLOY_AWS_ACCESS_KEY: ${{ secrets.DOCS_DEPLOY_AWS_ACCESS_KEY }} + DOCS_DEPLOY_AWS_API_SECRET: ${{ secrets.DOCS_DEPLOY_AWS_API_SECRET }} diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0c7790813..25f72cc33 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -18,19 +18,26 @@ jobs: strategy: fail-fast: false matrix: - os: [macos-14-large, macos-14, windows-2019, ubuntu-20.04, namespace-profile-default-arm64] + os: [ + # disabled Intel Macs due to pytorch 2.3+ not supporting it + # macos-14-large, + macos-14, + windows-2019, + ubuntu-20.04, + namespace-profile-default-arm64, + ] python_version: ["3.9", "3.12"] test_group: ["base"] include: # Include llm, stt, and tts tests only on Ubuntu 20.04 with Python 3.9 - os: ubuntu-20.04 - python_version: "3.9" + python_version: "3.12" test_group: llm - os: ubuntu-20.04 - python_version: "3.9" + python_version: "3.12" test_group: stt - os: ubuntu-20.04 - python_version: "3.9" + python_version: "3.12" test_group: tts runs-on: ${{ matrix.os }} @@ -54,7 +61,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: "3.9" + python-version: ${{ matrix.python_version }} cache: "pip" - name: Install ffmpeg (Linux) @@ -70,6 +77,7 @@ jobs: sudo dpkg -i libssl1.1_1.1.1-1ubuntu2.1_arm64.deb sudo dpkg -i libssl-dev_1.1.1-1ubuntu2.1_arm64.deb + - name: Install ffmpeg (macOS) if: ${{ startsWith(matrix.os, 'macos') }} run: brew install ffmpeg @@ -81,20 +89,9 @@ jobs: - name: Install packages shell: bash run: | - pip3 install pytest pytest-asyncio pytest-timeout './livekit-agents[codecs]' psutil - pip3 install -r ./tests/test-requirements.txt - pip3 install ./livekit-agents \ - ./livekit-plugins/livekit-plugins-openai \ - ./livekit-plugins/livekit-plugins-deepgram \ - ./livekit-plugins/livekit-plugins-google \ - ./livekit-plugins/livekit-plugins-nltk \ - ./livekit-plugins/livekit-plugins-silero \ - ./livekit-plugins/livekit-plugins-elevenlabs \ - ./livekit-plugins/livekit-plugins-cartesia \ - ./livekit-plugins/livekit-plugins-azure \ - ./livekit-plugins/livekit-plugins-anthropic \ - ./livekit-plugins/livekit-plugins-assemblyai \ - ./livekit-plugins/livekit-plugins-fal + pip install pytest pytest-asyncio pytest-timeout './livekit-agents[codecs]' psutil + pip install -r ./tests/test-requirements.txt + ./livekit-plugins/install_local.sh - name: Run tests shell: bash @@ -110,8 +107,11 @@ jobs: AZURE_SPEECH_REGION: ${{ secrets.AZURE_SPEECH_REGION }} # nit: doesn't have to be secret GOOGLE_CREDENTIALS_JSON: ${{ 
secrets.GOOGLE_CREDENTIALS_JSON }} ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }} ASSEMBLYAI_API_KEY: ${{ secrets.ASSEMBLYAI_API_KEY }} FAL_KEY: ${{ secrets.FAL_KEY }} + PLAYHT_API_KEY: ${{ secrets.PLAYHT_API_KEY }} + PLAYHT_USER_ID: ${{ secrets.PLAYHT_USER_ID }} GOOGLE_APPLICATION_CREDENTIALS: google.json PYTEST_ADDOPTS: "--color=yes" working-directory: tests @@ -120,7 +120,7 @@ jobs: case "${{ matrix.test_group }}" in base) - test_files="test_aio.py test_tokenizer.py test_vad.py test_ipc.py test_fallback.py" + test_files="test_aio.py test_tokenizer.py test_vad.py test_ipc.py test_tts_fallback.py test_stt_fallback.py test_message_change.py test_build_func_desc.py test_create_func.py" ;; llm) test_files="test_llm.py" @@ -136,4 +136,4 @@ jobs: exit 1 ;; esac - pytest --asyncio-mode=auto --timeout=60 $test_files + pytest $test_files diff --git a/README.md b/README.md index 2b0ef510a..11664a68f 100644 --- a/README.md +++ b/README.md @@ -9,27 +9,33 @@ +

Looking for the JS/TS library? Check out [AgentsJS](https://github.com/livekit/agents-js) -## ✨ [NEW] OpenAI Realtime API support +## ✨ NEW ✨ -We are partnering with OpenAI on a new `MultimodalAgent` API in the Agents framework. This class completely wraps OpenAI’s Realtime API, abstracts away the raw wire protocol, and provide an ultra-low latency WebRTC transport between GPT-4o and your users’ devices. This same stack powers Advanced Voice in the ChatGPT app. +### Google Gemini 2.0 support -- Try the Realtime API in our [playground](https://playground.livekit.io/) [[code](https://github.com/livekit-examples/realtime-playground)] -- Check out our [guide](https://docs.livekit.io/agents/openai) to building your first app with this new API +Introducing support for the new Gemini 2.0 model. Here's an example voice agent running Google STT, TTS, and Gemini 2.0 Flash: [code](./examples/voice-pipeline-agent/gemini_voice_agent.py) + +### In-house phrase endpointing model + +We’ve trained a new, open weights phrase endpointing model that significantly improves end-of-turn detection and conversational flow between voice agents and users by reducing agent interruptions. Optimized to run on CPUs, it’s available via [livekit-plugins-turn-detector](https://pypi.org/project/livekit-plugins-turn-detector/) package. ## What is Agents? -The Agents framework allows you to build AI-driven server programs that can see, hear, and speak in realtime. Your agent connects with end user devices through a LiveKit session. During that session, your agent can process text, audio, images, or video streaming from a user's device, and have an AI model generate any combination of those same modalities as output, and stream them back to the user. +The **Agents framework** enables you to build AI-driven server programs that can see, hear, and speak in realtime. It offers a fully open-source platform for creating realtime, agentic applications. ## Features -- Plugins for popular LLMs, transcription and text-to-speech services, and RAG databases -- High-level abstractions for building voice agents or assistants with automatic turn detection, interruption handling, function calling, and transcriptions -- Compatible with LiveKit's [telephony stack](https://github.com/livekit/sip), allowing your agent to make calls to or receive calls from phones -- Integrated load balancing system that manages pools of agents with edge-based dispatch, monitoring, and transparent failover -- Running your agents is identical across localhost, [self-hosted](https://github.com/livekit/livekit), and [LiveKit Cloud](https://cloud.livekit.io) environments +- **Flexible integrations**: A comprehensive ecosystem to mix and match the right models for each use case. +- **AI voice agents**: `VoicePipelineAgent` and `MultimodalAgent` help orchestrate the conversation flow using LLMs and other AI models. +- **Integrated job scheduling**: Built-in task scheduling and distribution with [dispatch APIs](https://docs.livekit.io/agents/build/dispatch/) to connect end users to agents. +- **Realtime media transport**: Stream audio, video, and data over WebRTC and SIP with client SDKs for most platforms. +- **Telephony integration**: Works seamlessly with LiveKit's [telephony stack](https://docs.livekit.io/sip/), allowing your agent to make calls to or receive calls from phones. 
+- **Exchange data with clients**: Use [RPCs](https://docs.livekit.io/home/client/data/rpc/) and other [Data APIs](https://docs.livekit.io/home/client/data/) to seamlessly exchange data with clients. +- **Open-source**: Fully open-source, allowing you to run the entire stack on your own servers, including [LiveKit server](https://github.com/livekit/livekit), one of the most widely used WebRTC media servers. @@ -41,7 +47,7 @@ To install the core Agents library: pip install livekit-agents ``` -## Plugins +## Integrations The framework includes a variety of plugins that make it easy to process streaming input or generate output. For example, there are plugins for converting text-to-speech or running inference with popular LLMs. Here's how you can install a plugin: @@ -49,22 +55,60 @@ The framework includes a variety of plugins that make it easy to process streami pip install livekit-plugins-openai ``` -The following plugins are available today: - -| Plugin | Features | -| ---------------------------------------------------------------------------------- | ------------------------------------------- | -| [livekit-plugins-anthropic](https://pypi.org/project/livekit-plugins-anthropic/) | LLM | -| [livekit-plugins-assemblyai](https://pypi.org/project/livekit-plugins-assemblyai/) | STT | -| [livekit-plugins-azure](https://pypi.org/project/livekit-plugins-azure/) | STT, TTS | -| [livekit-plugins-deepgram](https://pypi.org/project/livekit-plugins-deepgram/) | STT | -| [livekit-plugins-cartesia](https://pypi.org/project/livekit-plugins-cartesia/) | TTS | -| [livekit-plugins-elevenlabs](https://pypi.org/project/livekit-plugins-elevenlabs/) | TTS | -| [livekit-plugins-playht](https://pypi.org/project/livekit-plugins-playht/) | TTS | -| [livekit-plugins-google](https://pypi.org/project/livekit-plugins-google/) | STT, TTS | -| [livekit-plugins-nltk](https://pypi.org/project/livekit-plugins-nltk/) | Utilities for working with text | -| [livekit-plugins-rag](https://pypi.org/project/livekit-plugins-rag/) | Utilities for performing RAG | -| [livekit-plugins-openai](https://pypi.org/project/livekit-plugins-openai/) | LLM, STT, TTS, Assistants API, Realtime API | -| [livekit-plugins-silero](https://pypi.org/project/livekit-plugins-silero/) | VAD | +### Realtime API + +We've partnered with OpenAI on a new `MultimodalAgent` API in the Agents framework. This class completely wraps OpenAI’s Realtime API, abstracts away the raw wire protocol, and provide an ultra-low latency WebRTC transport between GPT-4o and your users’ devices. This same stack powers Advanced Voice in the ChatGPT app. 
+ +- Try the Realtime API in our [playground](https://playground.livekit.io/) [[code](https://github.com/livekit-examples/realtime-playground)] +- Check out our [guide](https://docs.livekit.io/agents/openai) to building your first app with this new API + +### LLM + +| Provider | Package | Usage | +| --------------- | ------------------------- | --------------------------------------------------------------------------------------------------------------------------------- | +| OpenAI | livekit-plugins-openai | [openai.LLM()](https://docs.livekit.io/python/livekit/plugins/openai/index.html#livekit.plugins.openai.LLM) | +| Azure OpenAI | livekit-plugins-openai | [openai.LLM.with_azure()](https://docs.livekit.io/python/livekit/plugins/openai/index.html#livekit.plugins.openai.LLM.with_azure) | +| Anthropic | livekit-plugins-anthropic | [anthropic.LLM()](https://docs.livekit.io/python/livekit/plugins/anthropic/index.html#livekit.plugins.anthropic.LLM) | +| Google (Gemini) | livekit-plugins-openai | [openai.LLM.with_vertex()](https://docs.livekit.io/python/livekit/plugins/openai/#livekit.plugins.openai.LLM.with_vertex) | +| Cerebras | livekit-plugins-openai | [openai.LLM.with_cerebras()](https://docs.livekit.io/python/livekit/plugins/openai/#livekit.plugins.openai.LLM.with_cerebras) | +| Groq | livekit-plugins-openai | [openai.LLM.with_groq()](https://docs.livekit.io/python/livekit/plugins/openai/#livekit.plugins.openai.LLM.with_groq) | +| Ollama | livekit-plugins-openai | [openai.LLM.with_ollama()](https://docs.livekit.io/python/livekit/plugins/openai/#livekit.plugins.openai.LLM.with_ollama) | +| Perplexity | livekit-plugins-openai | [openai.LLM.with_perplexity()](https://docs.livekit.io/python/livekit/plugins/openai/#livekit.plugins.openai.LLM.with_perplexity) | +| Together.ai | livekit-plugins-openai | [openai.LLM.with_together()](https://docs.livekit.io/python/livekit/plugins/openai/#livekit.plugins.openai.LLM.with_together) | +| X.ai (Groq) | livekit-plugins-openai | [openai.LLM.with_x_ai()](https://docs.livekit.io/python/livekit/plugins/openai/#livekit.plugins.openai.LLM.with_x_ai) | + +### STT + +| Provider | Package | Streaming | Usage | +| ---------------- | -------------------------- | --------- | ----------------------------------------------------------------------------------------------------------------------- | +| Azure | livekit-plugins-azure | ✅ | [azure.STT()](https://docs.livekit.io/python/livekit/plugins/azure/index.html#livekit.plugins.azure.STT) | +| Deepgram | livekit-plugins-deepgram | ✅ | [deepgram.STT()](https://docs.livekit.io/python/livekit/plugins/deepgram/index.html#livekit.plugins.deepgram.STT) | +| OpenAI (Whisper) | livekit-plugins-openai | | [openai.STT()](https://docs.livekit.io/python/livekit/plugins/openai/index.html#livekit.plugins.openai.STT) | +| Google | livekit-plugins-google | ✅ | [google.STT()](https://docs.livekit.io/python/livekit/plugins/google/index.html#livekit.plugins.google.STT) | +| AssemblyAI | livekit-plugins-assemblyai | | [assemblyai.STT()](https://docs.livekit.io/python/livekit/plugins/assemblyai/index.html#livekit.plugins.assemblyai.STT) | +| Groq (Whisper) | livekit-plugins-openai | | [openai.STT.with_groq()](https://docs.livekit.io/python/livekit/plugins/openai/#livekit.plugins.openai.STT.with_groq) | +| FAL (Whizper) | livekit-plugins-fal | | [fal.STT()](https://docs.livekit.io/python/livekit/plugins/fal/index.html#livekit.plugins.fal.STT) | + +### TTS + +| Provider | Package | Streaming | Voice Cloning | Usage | +| ------------ | 
-------------------------- | --------- | ------------- | ----------------------------------------------------------------------------------------------------------------------- | +| Cartesia | livekit-plugins-cartesia | ✅ | ✅ | [cartesia.TTS()](https://docs.livekit.io/python/livekit/plugins/cartesia/index.html#livekit.plugins.cartesia.TTS) | +| ElevenLabs | livekit-plugins-elevenlabs | ✅ | ✅ | [elevenlabs.TTS()](https://docs.livekit.io/python/livekit/plugins/elevenlabs/index.html#livekit.plugins.elevenlabs.TTS) | +| OpenAI | livekit-plugins-openai | | | [openai.TTS()](https://docs.livekit.io/python/livekit/plugins/openai/index.html#livekit.plugins.openai.TTS) | +| Azure OpenAI | livekit-plugins-openai | | | [openai.TTS.with_azure()](https://docs.livekit.io/python/livekit/plugins/openai/#livekit.plugins.openai.TTS.with_azure) | +| Google | livekit-plugins-google | ✅ | ✅ | [google.TTS()](https://docs.livekit.io/python/livekit/plugins/google/index.html#livekit.plugins.google.TTS) | +| Deepgram | livekit-plugins-deepgram | ✅ | | [deepgram.TTS()](https://docs.livekit.io/python/livekit/plugins/deepgram/index.html#livekit.plugins.deepgram.TTS) | + +### Other plugins + +| Plugin | Description | +| ----------------------------- | ----------------------------------- | +| livekit-plugins-rag | Annoy based simple RAG | +| livekit-plugins-llama-index | RAG with LlamaIndex | +| livekit-plugins-nltk | Utilities for working with text | +| livekit-plugins-vad | Voice activity detection | +| livekit-plugins-turn-detector | Conversational turn detection model | ## Documentation and guides @@ -72,27 +116,30 @@ Documentation on the framework and how to use it can be found [here](https://doc ## Example agents -| Description | Demo Link | Code Link | -|---------------------------------------------------------------------------------------------------------------|-------------------------------------------------|-----------------------------------------------------------------------------------------------| -| A basic voice agent using a pipeline of STT, LLM, and TTS | [demo](https://kitt.livekit.io) | [code](https://github.com/livekit/agents/blob/main/examples/voice-pipeline-agent/minimal_assistant.py) | -| Voice agent using the new OpenAI Realtime API | [demo](https://playground.livekit.io) | [code](https://github.com/livekit-examples/realtime-playground) | -| Super fast voice agent using Cerebras hosted Llama 3.1 | [demo](https://cerebras.vercel.app) | [code](https://github.com/dsa/fast-voice-assistant/) | -| Voice agent using Cartesia's Sonic model | [demo](https://cartesia-assistant.vercel.app/) | N/A | -| Agent that looks up the current weather via function call | N/A | [code](https://github.com/livekit/agents/blob/main/examples/voice-pipeline-agent/function_calling_weather.py) | -| Voice agent that performs a RAG-based lookup | N/A | [code](https://github.com/livekit/agents/tree/main/examples/voice-pipeline-agent/simple-rag) | -| Video agent that publishes a stream of RGB frames | N/A | [code](https://github.com/livekit/agents/tree/main/examples/simple-color) | -| Transcription agent that generates text captions from a user's speech | N/A | [code](https://github.com/livekit/agents/tree/main/examples/speech-to-text) | -| A chat agent you can text who will respond back with generated speech | N/A | [code](https://github.com/livekit/agents/tree/main/examples/text-to-speech) | -| Localhost multi-agent conference call | N/A | [code](https://github.com/dsa/multi-agent-meeting) | -| Moderation agent that uses 
Hive to detect spam/abusive video | N/A | [code](https://github.com/dsa/livekit-agents/tree/main/hive-moderation-agent) | - +| Description | Demo Link | Code Link | +| --------------------------------------------------------------------- | ---------------------------------------------- | ------------------------------------------------------------------------------------------------------ | +| A basic voice agent using a pipeline of STT, LLM, and TTS | [demo](https://kitt.livekit.io) | [code](https://github.com/livekit/agents/blob/main/examples/voice-pipeline-agent/minimal_assistant.py) | +| Voice agent using the new OpenAI Realtime API | [demo](https://playground.livekit.io) | [code](https://github.com/livekit-examples/realtime-playground) | +| Super fast voice agent using Cerebras hosted Llama 3.1 | [demo](https://cerebras.vercel.app) | [code](https://github.com/dsa/fast-voice-assistant/) | +| Voice agent using Cartesia's Sonic model | [demo](https://cartesia-assistant.vercel.app/) | [code](https://github.com/livekit-examples/cartesia-voice-agent) | +| Agent that looks up the current weather via function call | N/A | [code](https://github.com/livekit/agents/blob/main/examples/voice-pipeline-agent/function_calling_weather.py) | +| Voice Agent using Gemini 2.0 Flash | N/A | [code](https://github.com/livekit-examples/voice-pipeline-agent/gemini_voice_agent.py) | +| Voice agent with custom turn-detection model | N/A | [code](https://github.com/livekit/agents/blob/main/examples/voice-pipeline-agent/turn_detector.py) | +| Voice agent that performs a RAG-based lookup | N/A | [code](https://github.com/livekit/agents/tree/main/examples/voice-pipeline-agent/simple-rag) | +| Video agent that publishes a stream of RGB frames | N/A | [code](https://github.com/livekit/agents/tree/main/examples/simple-color) | +| Transcription agent that generates text captions from a user's speech | N/A | [code](https://github.com/livekit/agents/tree/main/examples/speech-to-text) | +| A chat agent you can text who will respond back with generated speech | N/A | [code](https://github.com/livekit/agents/tree/main/examples/text-to-speech) | +| Localhost multi-agent conference call | N/A | [code](https://github.com/dsa/multi-agent-meeting) | +| Moderation agent that uses Hive to detect spam/abusive video | N/A | [code](https://github.com/dsa/livekit-agents/tree/main/hive-moderation-agent) | ## Contributing The Agents framework is under active development in a rapidly evolving field. We welcome and appreciate contributions of any kind, be it feedback, bugfixes, features, new plugins and tools, or better documentation. You can file issues under this repo, open a PR, or chat with us in LiveKit's [Slack community](https://livekit.io/join-slack). +
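To make the integration tables above concrete, here is a minimal sketch of how the plugins compose into a `VoicePipelineAgent`, condensed from the voice-pipeline-agent examples elsewhere in this diff. It is illustrative only: the system prompt, logger name, and the particular STT/LLM/TTS providers are assumptions, and any row from the tables above can be swapped in.

```python
import logging

from dotenv import load_dotenv
from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli, llm
from livekit.agents.pipeline import VoicePipelineAgent
from livekit.plugins import deepgram, openai, silero

load_dotenv()
logger = logging.getLogger("minimal-voice-agent")  # hypothetical logger name


async def entrypoint(ctx: JobContext):
    # seed the conversation with a system prompt
    initial_ctx = llm.ChatContext().append(
        role="system",
        text="You are a voice assistant created by LiveKit. Keep responses short and concise.",
    )

    await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
    participant = await ctx.wait_for_participant()

    # mix and match providers: any STT/LLM/TTS plugin from the tables above
    agent = VoicePipelineAgent(
        vad=silero.VAD.load(),
        stt=deepgram.STT(),
        llm=openai.LLM(),
        tts=openai.TTS(),
        chat_ctx=initial_ctx,
    )
    agent.start(ctx.room, participant)
    await agent.say("Hey, how can I help you today?", allow_interruptions=True)


if __name__ == "__main__":
    cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
```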
+ diff --git a/examples/conversation_persistor.py b/examples/conversation_persistor.py new file mode 100644 index 000000000..0d9909b63 --- /dev/null +++ b/examples/conversation_persistor.py @@ -0,0 +1,213 @@ +import asyncio +import logging +from dataclasses import dataclass +from datetime import datetime +from typing import Union + +import aiofiles +from dotenv import load_dotenv +from livekit.agents import ( + AutoSubscribe, + JobContext, + WorkerOptions, + cli, + multimodal, + utils, +) +from livekit.agents.llm import ChatMessage +from livekit.agents.multimodal.multimodal_agent import EventTypes +from livekit.plugins import openai + + +@dataclass +class EventLog: + eventname: str | None + """name of recorded event""" + time: str = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] + """time the event is recorded""" + + +@dataclass +class TranscriptionLog: + role: str | None + """role of the speaker""" + transcription: str | None + """transcription of speech""" + time: str = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] + """time the event is recorded""" + + +class ConversationPersistor(utils.EventEmitter[EventTypes]): + def __init__( + self, + *, + model: multimodal.MultimodalAgent | None, + log: str | None, + transcriptions_only: bool = False, + ): + """ + Initializes a ConversationPersistor instance which records the events and transcriptions of a MultimodalAgent. + + Args: + model (multimodal.MultimodalAgent): an instance of a MultiModalAgent + log (str): name of the external file to record events in + transcriptions_only (bool): a boolean variable to determine if only transcriptions will be recorded, False by default + user_transcriptions (arr): list of user transcriptions + agent_transcriptions (arr): list of agent transcriptions + events (arr): list of all events + log_q (asyncio.Queue): a queue of EventLog and TranscriptionLog + + """ + super().__init__() + + self._model = model + self._log = log + self._transcriptions_only = transcriptions_only + + self._user_transcriptions = [] + self._agent_transcriptions = [] + self._events = [] + + self._log_q = asyncio.Queue[Union[EventLog, TranscriptionLog, None]]() + + @property + def log(self) -> str | None: + return self._log + + @property + def model(self) -> multimodal.MultimodalAgent | None: + return self._model + + @property + def user_transcriptions(self) -> dict: + return self._user_transcriptions + + @property + def agent_transcriptions(self) -> dict: + return self._agent_transcriptions + + @property + def events(self) -> dict: + return self._events + + @log.setter + def log(self, newlog: str | None) -> None: + self._log = newlog + + async def _main_atask(self) -> None: + # Writes to file asynchronously + while True: + log = await self._log_q.get() + + if log is None: + break + + async with aiofiles.open(self._log, "a") as file: + if type(log) is EventLog and not self._transcriptions_only: + self._events.append(log) + await file.write("\n" + log.time + " " + log.eventname) + + if type(log) is TranscriptionLog: + if log.role == "user": + self._user_transcriptions.append(log) + else: + self._agent_transcriptions.append(log) + + await file.write( + "\n" + log.time + " " + log.role + " " + log.transcription + ) + + async def aclose(self) -> None: + # Exits + self._log_q.put_nowait(None) + await self._main_task + + def start(self) -> None: + # Listens for emitted MultimodalAgent events + self._main_task = asyncio.create_task(self._main_atask()) + + @self._model.on("user_started_speaking") + def _user_started_speaking(): + 
event = EventLog(eventname="user_started_speaking") + self._log_q.put_nowait(event) + + @self._model.on("user_stopped_speaking") + def _user_stopped_speaking(): + event = EventLog(eventname="user_stopped_speaking") + self._log_q.put_nowait(event) + + @self._model.on("agent_started_speaking") + def _agent_started_speaking(): + event = EventLog(eventname="agent_started_speaking") + self._log_q.put_nowait(event) + + @self._model.on("agent_stopped_speaking") + def _agent_stopped_speaking(): + transcription = TranscriptionLog( + role="agent", + transcription=(self._model._playing_handle._tr_fwd.played_text)[1:], + ) + self._log_q.put_nowait(transcription) + + event = EventLog(eventname="agent_stopped_speaking") + self._log_q.put_nowait(event) + + @self._model.on("user_speech_committed") + def _user_speech_committed(user_msg: ChatMessage): + transcription = TranscriptionLog( + role="user", transcription=user_msg.content + ) + self._log_q.put_nowait(transcription) + + event = EventLog(eventname="user_speech_committed") + self._log_q.put_nowait(event) + + @self._model.on("agent_speech_committed") + def _agent_speech_committed(): + event = EventLog(eventname="agent_speech_committed") + self._log_q.put_nowait(event) + + @self._model.on("agent_speech_interrupted") + def _agent_speech_interrupted(): + event = EventLog(eventname="agent_speech_interrupted") + self._log_q.put_nowait(event) + + @self._model.on("function_calls_collected") + def _function_calls_collected(): + event = EventLog(eventname="function_calls_collected") + self._log_q.put_nowait(event) + + @self._model.on("function_calls_finished") + def _function_calls_finished(): + event = EventLog(eventname="function_calls_finished") + self._log_q.put_nowait(event) + + +load_dotenv() + +logger = logging.getLogger("my-worker") +logger.setLevel(logging.INFO) + + +async def entrypoint(ctx: JobContext): + agent = multimodal.MultimodalAgent( + model=openai.realtime.RealtimeModel( + voice="alloy", + temperature=0.8, + instructions="You are a helpful assistant.", + turn_detection=openai.realtime.ServerVadOptions( + threshold=0.6, prefix_padding_ms=200, silence_duration_ms=500 + ), + ), + ) + + cp = ConversationPersistor(model=agent, log="log.txt") + cp.start() + + await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) + participant = await ctx.wait_for_participant() + agent.start(ctx.room, participant) + + +if __name__ == "__main__": + cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint)) diff --git a/examples/hive-moderation-agent/README.md b/examples/hive-moderation-agent/README.md new file mode 100644 index 000000000..8f48218bb --- /dev/null +++ b/examples/hive-moderation-agent/README.md @@ -0,0 +1,41 @@ +# LiveKit realtime moderation agent using Hive + +This is an agent that performs visual moderation of every participant's video in a room. It does this moderation using the Visual Content Moderation model from [Hive](https://thehive.ai) [[docs](https://docs.thehive.ai/docs/visual-content-moderation#visual-content-moderation)]. + +## Prerequisites + +Before running this agent, you'll need: + +1. A LiveKit Cloud project (or a self-hosted LiveKit server). +2. An API key from Hive to access the above mentioned model. + +## Configuration + +Currently, this agent is configured entirely from the `agent.py` source code and the environment. 
+ +### Environment Variables + +| configuration | description | example value | +|---------------|-------------|---------------| +| `LIVEKIT_URL` | Your LiveKit URL | `wss://test-abc123de.livekit.cloud` | +| `LIVEKIT_API_KEY` | Your LiveKit API key | | +| `LIVEKIT_API_SECRET` | Your LiveKit API secret | | +| `HIVE_API_KEY` | The API key from Hive to access the `Visual Content Moderation` model | `abc1deFgHIjK23KLMNOp45QrsTuv6wx8` | + +### Code + +| configuration | description | example value | +|---------------|-------------|---------------| +| `MOD_FRAME_INTERVAL` | Minimum number of seconds to wait between frames | 5.0 | +| `HIVE_HEADERS` | The headers to send with every request to the Hive API | `{}` | +| `CONFIDENCE_THRESHOLD` | The minimum score Hive's moderation class must meet before it is considered a problem | 0.9 | + +## Running + +Run this code like you would any other [LiveKit agent](https://docs.livekit.io/agents/build/anatomy/#starting-the-worker): + +``` +python3 agent.py start +``` + +Once running, the agent will join all new LiveKit rooms by default and begin moderation. diff --git a/examples/hive-moderation-agent/agent.py b/examples/hive-moderation-agent/agent.py new file mode 100644 index 000000000..bf0b23b07 --- /dev/null +++ b/examples/hive-moderation-agent/agent.py @@ -0,0 +1,163 @@ +""" +LiveKit agent that connects to a room and performs visual moderation on the video +of all participants using the Visual Content Moderation model from Hive +(https://docs.thehive.ai/docs/visual-content-moderation#visual-content-moderation). + +The agent periodically sends a frame from the participant's video to Hive's API +for a moderation check. If the results of that check show a confidence score +of 0.9 or higher for any of the positive classes, it logs the result and adds a +message to the room's chat. This can easily be extended to take additional +actions like removing a participant or ending a livestream, etc. +""" + +import asyncio +import logging +import os +import time +from io import BytesIO + +import aiohttp +from dotenv import load_dotenv +from hive_data_classes import HiveResponse, from_dict +from livekit import agents, rtc +from PIL import Image + +load_dotenv() + +MOD_FRAME_INTERVAL = 5.0 # check 1 frame every 5 seconds +""" +How often to check a frame (in seconds) +""" + +HIVE_HEADERS = { + "Authorization": f"Token {os.getenv('HIVE_API_KEY')}", + "accept": "application/json", +} +""" +The default headers included with every request to thehive.ai +""" + +CONFIDENCE_THRESHOLD = 0.9 +""" +THe threshold level for scores returned by thehive.ai. See details in this doc: +https://docs.thehive.ai/docs/visual-content-moderation#choosing-thresholds-for-visual-moderation +""" + + +logger = logging.getLogger("hive-moderation-agent") +logger.setLevel(logging.INFO) + + +async def request_fnc(req: agents.JobRequest): + """ + The request handler for the agent. We use this to set the name of the + agent that is displayed to users + """ + # accept the job request and name the agent participant so users know what this is + await req.accept( + name="Moderator", + identity="hive-moderator", + ) + + +async def entrypoint(ctx: agents.JobContext): + """ + The entrypoint of the agent. This is called every time the moderator + agent joins a room. 
+ """ + + # connect to the room and automatically subscribe to all participants' video + await ctx.connect(auto_subscribe=agents.AutoSubscribe.VIDEO_ONLY) + chat = rtc.ChatManager(ctx.room) + + @ctx.room.on("track_subscribed") + def on_track_subscribed( + track: rtc.Track, + _publication: rtc.TrackPublication, + participant: rtc.RemoteParticipant, + ): + """ + Event handler for video tracks. We automatically subscribe to all video + tracks when a participant joins the room. This event is triggered + once we have completed subscription to that video track. + This creates a backgrond task to process frames from each track + """ + asyncio.create_task(process_track(participant, track)) + + async def process_track(participant: rtc.RemoteParticipant, track: rtc.VideoTrack): + """ + This function is running in a background task once for each video track + (i.e., once for each participant). It handles processing a frame + from the video once every MOD_FRAME INTERVAL seconds. + """ + + video_stream = rtc.VideoStream(track) + last_processed_time = 0 + async for frame in video_stream: + current_time = time.time() + if (current_time - last_processed_time) >= MOD_FRAME_INTERVAL: + last_processed_time = current_time + await check_frame(participant, frame) + + async def check_frame(participant: rtc.RemoteParticipant, frame: rtc.VideoFrame): + """ + Uses thehive.ai API to check the frame for any classifications we care about + """ + + # get the current frame and convert to png format + argb_frame = frame.frame.convert(rtc.VideoBufferType.RGBA) + image = Image.frombytes( + "RGBA", (argb_frame.width, argb_frame.height), argb_frame.data + ) + buffer = BytesIO() + image.save(buffer, format="PNG") + buffer.seek(0) # reset buffer position to beginning after writing + + data = aiohttp.FormData() + data.add_field("image", buffer, filename="image.png", content_type="image/png") + + # submit the image to Hive + logger.info("submitting image to hive") + async with aiohttp.ClientSession() as session: + async with session.post( + "https://api.thehive.ai/api/v2/task/sync", + headers=HIVE_HEADERS, + data=data, + ) as response: + response.raise_for_status() + response_dict = await response.json() + hive_response: HiveResponse = from_dict(HiveResponse, response_dict) + if ( + hive_response.code == 200 + and len(hive_response.status) > 0 + and len(hive_response.status[0].response.output) > 0 + ): + results = hive_response.status[0].response.output[0].classes + # filter to anything with a confidence score > threshold + for mod_class in results: + if mod_class.class_[0:4] == "yes_": + # TODO: should also include "general_nsfw" class + if mod_class.score >= CONFIDENCE_THRESHOLD: + class_name = mod_class.class_[4:] + message = ( + 'FOUND %s for participant "%s" (confidence score: %0.3f)' + % ( + class_name, + participant.identity, + mod_class.score, + ) + ) + logger.info(message) + await chat.send_message(message) + + await ctx.wait_for_participant() + await chat.send_message( + "I'm a moderation agent," + "I will detect and notify you of all inappropriate material in your video stream" + ) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + + agents.cli.run_app(agents.WorkerOptions(entrypoint, request_fnc=request_fnc)) diff --git a/examples/hive-moderation-agent/hive_data_classes.py b/examples/hive-moderation-agent/hive_data_classes.py new file mode 100644 index 000000000..a1773435d --- /dev/null +++ b/examples/hive-moderation-agent/hive_data_classes.py @@ -0,0 +1,95 @@ +from dataclasses import 
dataclass, is_dataclass +from typing import List, get_type_hints + + +def from_dict(cls, data): + if is_dataclass(cls) and isinstance(data, dict): + # Get type hints for all fields in the dataclass + field_types = get_type_hints(cls) + # Special handling for reserved words like 'class' + reserved_word_mappings = {"class": "class_"} # Map 'class' to 'class_' + processed_data = {} + for key, value in data.items(): + # Check if the key is a reserved word and map it accordingly + field_name = reserved_word_mappings.get(key, key) + # Only include keys that have corresponding fields in the dataclass + if field_name in field_types: + field_type = field_types[field_name] + # Determine if the field_type is itself a dataclass + if is_dataclass(field_type): + processed_value = from_dict(field_type, value) + elif hasattr(field_type, "__origin__") and issubclass( + field_type.__origin__, List + ): + # Handle List fields, assuming all elements are of the same type + item_type = field_type.__args__[0] + processed_value = [from_dict(item_type, item) for item in value] + else: + processed_value = value + processed_data[field_name] = processed_value + return cls(**processed_data) + elif isinstance(data, list): + # This assumes that the function was called with a list type as `cls`, + # which might not work as expected without context on the list's element type. + # A better approach might be needed for handling lists of dataclasses. + return [ + from_dict(cls.__args__[0], item) if hasattr(cls, "__args__") else item + for item in data + ] + else: + return data + + +@dataclass +class Status: + code: str + message: str + + +@dataclass +class ModInput: + id: str + charge: float + config_tag: SyntaxWarning + config_version: float + created_on: str + model: str + model_type: str + model_version: float + project_id: int + user_id: int + + +@dataclass +class ModClass: + class_: str + score: float + + +@dataclass +class ModOutput: + time: int + classes: List[ModClass] + + +@dataclass +class Response: + input: ModInput + output: List[ModOutput] + + +@dataclass +class ModResponse: + status: Status + response: Response + + +@dataclass +class HiveResponse: + id: str + code: int + project_id: int + user_id: int + created_on: str + status: List[ModResponse] + from_cache: bool diff --git a/examples/hive-moderation-agent/requirements.txt b/examples/hive-moderation-agent/requirements.txt new file mode 100644 index 000000000..517a8283f --- /dev/null +++ b/examples/hive-moderation-agent/requirements.txt @@ -0,0 +1,5 @@ +livekit +livekit-agents +python-dotenv +Pillow +aiohttp \ No newline at end of file diff --git a/examples/multimodal-agent/gemini_agent.py b/examples/multimodal-agent/gemini_agent.py new file mode 100644 index 000000000..81a474609 --- /dev/null +++ b/examples/multimodal-agent/gemini_agent.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import logging +from typing import Annotated + +import aiohttp +from dotenv import load_dotenv +from livekit.agents import ( + AutoSubscribe, + JobContext, + WorkerOptions, + WorkerType, + cli, + llm, + multimodal, +) +from livekit.plugins import google + +load_dotenv() + +logger = logging.getLogger("my-worker") +logger.setLevel(logging.INFO) + + +async def entrypoint(ctx: JobContext): + logger.info("starting entrypoint") + + fnc_ctx = llm.FunctionContext() + + @fnc_ctx.ai_callable() + async def get_weather( + location: Annotated[ + str, llm.TypeInfo(description="The location to get the weather for") + ], + ): + """Called when the user asks about the weather. 
This function will return the weather for the given location.""" + logger.info(f"getting weather for {location}") + url = f"https://wttr.in/{location}?format=%C+%t" + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status == 200: + weather_data = await response.text() + # # response from the function call is returned to the LLM + return f"The weather in {location} is {weather_data}." + else: + raise Exception( + f"Failed to get weather data, status code: {response.status}" + ) + + await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) + participant = await ctx.wait_for_participant() + + chat_ctx = llm.ChatContext() + + agent = multimodal.MultimodalAgent( + model=google.beta.realtime.RealtimeModel( + voice="Charon", + temperature=0.8, + instructions="You are a helpful assistant", + ), + fnc_ctx=fnc_ctx, + chat_ctx=chat_ctx, + ) + agent.start(ctx.room, participant) + + +if __name__ == "__main__": + cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, worker_type=WorkerType.ROOM)) diff --git a/examples/multimodal_agent.py b/examples/multimodal-agent/openai_agent.py similarity index 100% rename from examples/multimodal_agent.py rename to examples/multimodal-agent/openai_agent.py diff --git a/examples/participant-entrypoint/requirements.txt b/examples/participant-entrypoint/requirements.txt index f1a7906e6..77c8959d1 100644 --- a/examples/participant-entrypoint/requirements.txt +++ b/examples/participant-entrypoint/requirements.txt @@ -1,2 +1,2 @@ -livekit-agents>=0.11.3 +livekit-agents>=0.12.6 python-dotenv~=1.0 diff --git a/examples/simple-color/requirements.txt b/examples/simple-color/requirements.txt index f1a7906e6..77c8959d1 100644 --- a/examples/simple-color/requirements.txt +++ b/examples/simple-color/requirements.txt @@ -1,2 +1,2 @@ -livekit-agents>=0.11.3 +livekit-agents>=0.12.6 python-dotenv~=1.0 diff --git a/examples/speech-to-text/requirements.txt b/examples/speech-to-text/requirements.txt index 5e8abfa56..b9f8e9fb0 100644 --- a/examples/speech-to-text/requirements.txt +++ b/examples/speech-to-text/requirements.txt @@ -1,3 +1,3 @@ -livekit-agents>=0.11.3 -livekit-plugins-deepgram>=0.6.11 +livekit-agents>=0.12.6 +livekit-plugins-deepgram>=0.6.16 python-dotenv~=1.0 diff --git a/examples/text-to-speech/requirements.txt b/examples/text-to-speech/requirements.txt index 3411c8661..f03f7fa49 100644 --- a/examples/text-to-speech/requirements.txt +++ b/examples/text-to-speech/requirements.txt @@ -1,5 +1,5 @@ -livekit-agents>=0.11.3 -livekit-plugins-openai>=0.10.7 -livekit-plugins-cartesia>=0.4.3 -livekit-plugins-elevenlabs>=0.7.7 +livekit-agents>=0.12.6 +livekit-plugins-openai>=0.10.13 +livekit-plugins-cartesia>=0.4.5 +livekit-plugins-elevenlabs>=0.7.9 python-dotenv~=1.0 diff --git a/examples/voice-pipeline-agent/README.md b/examples/voice-pipeline-agent/README.md index 6f7e176fb..a8fb69410 100644 --- a/examples/voice-pipeline-agent/README.md +++ b/examples/voice-pipeline-agent/README.md @@ -34,7 +34,10 @@ export OPENAI_API_KEY= ### Install requirments: -`pip install -r requirements.txt` +``` +pip install -r requirements.txt +python minimal_assistant.py download-files +``` ### Run the agent worker: diff --git a/examples/voice-pipeline-agent/fallback_adapter.py b/examples/voice-pipeline-agent/fallback_adapter.py new file mode 100644 index 000000000..ff171e939 --- /dev/null +++ b/examples/voice-pipeline-agent/fallback_adapter.py @@ -0,0 +1,86 @@ +import logging + +from dotenv import load_dotenv +from livekit.agents import ( + 
AutoSubscribe, + JobContext, + JobProcess, + WorkerOptions, + cli, + llm, + stt, + tts, +) +from livekit.agents.pipeline import VoicePipelineAgent +from livekit.plugins import cartesia, deepgram, elevenlabs, openai, silero + +load_dotenv() +logger = logging.getLogger("fallback-adapter-example") + + +def prewarm(proc: JobProcess): + proc.userdata["vad"] = silero.VAD.load() + + +async def entrypoint(ctx: JobContext): + initial_ctx = llm.ChatContext().append( + role="system", + text=( + "You are a voice assistant created by LiveKit. Your interface with users will be voice. " + "You should use short and concise responses, and avoiding usage of unpronouncable punctuation." + ), + ) + + await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) + + # wait for the first participant to connect + participant = await ctx.wait_for_participant() + logger.info(f"starting voice assistant for participant {participant.identity}") + + vad: silero.VAD = ctx.proc.userdata["vad"] + + # fallback to OpenAI if Deepgram goes down + fallback_stt = stt.FallbackAdapter( + [ + deepgram.STT(), + stt.StreamAdapter(stt=openai.STT(), vad=vad), + ] + ) + + # fallback to Azure if OpenAI goes down + fallback_llm = llm.FallbackAdapter( + [ + openai.LLM(), + openai.LLM.with_azure(), + ] + ) + + # fallback to 11labs if Cartesia goes down + # you can keep the same voice by using their voice cloning feature + fallback_tts = tts.FallbackAdapter( + [ + cartesia.TTS(), + elevenlabs.TTS(), + ] + ) + + agent = VoicePipelineAgent( + vad=vad, + stt=fallback_stt, + llm=fallback_llm, + tts=fallback_tts, + chat_ctx=initial_ctx, + ) + + agent.start(ctx.room, participant) + + await agent.say("Hey, how can I help you today?", allow_interruptions=True) + + +if __name__ == "__main__": + cli.run_app( + WorkerOptions( + entrypoint_fnc=entrypoint, + prewarm_fnc=prewarm, + ), + ) diff --git a/examples/voice-pipeline-agent/function_calling_weather.py b/examples/voice-pipeline-agent/function_calling_weather.py index f5bc3135b..f39705f17 100644 --- a/examples/voice-pipeline-agent/function_calling_weather.py +++ b/examples/voice-pipeline-agent/function_calling_weather.py @@ -1,4 +1,7 @@ import logging +import random +import re +import urllib from typing import Annotated import aiohttp @@ -11,7 +14,7 @@ cli, llm, ) -from livekit.agents.pipeline import VoicePipelineAgent +from livekit.agents.pipeline import AgentCallContext, VoicePipelineAgent from livekit.plugins import deepgram, openai, silero load_dotenv() @@ -33,19 +36,53 @@ async def get_weather( ], ): """Called when the user asks about the weather. 
This function will return the weather for the given location.""" + # Clean the location string of special characters + location = re.sub(r"[^a-zA-Z0-9]+", " ", location).strip() + + # When a function call is running, there are a couple of options to inform the user + # that it might take awhile: + # Option 1: you can use .say filler message immediately after the call is triggered + # Option 2: you can prompt the agent to return a text response when it's making a function call + agent = AgentCallContext.get_current().agent + + if ( + not agent.chat_ctx.messages + or agent.chat_ctx.messages[-1].role != "assistant" + ): + # skip if assistant already said something + filler_messages = [ + "Let me check the weather in {location} for you.", + "Let me see what the weather is like in {location} right now.", + # LLM will complete this sentence if it is added to the end of the chat context + "The current weather in {location} is ", + ] + message = random.choice(filler_messages).format(location=location) + logger.info(f"saying filler message: {message}") + + # NOTE: set add_to_chat_ctx=True will add the message to the end + # of the chat context of the function call for answer synthesis + speech_handle = await agent.say(message, add_to_chat_ctx=True) # noqa: F841 + logger.info(f"getting weather for {location}") - url = f"https://wttr.in/{location}?format=%C+%t" + url = f"https://wttr.in/{urllib.parse.quote(location)}?format=%C+%t" + weather_data = "" async with aiohttp.ClientSession() as session: async with session.get(url) as response: if response.status == 200: - weather_data = await response.text() # response from the function call is returned to the LLM - return f"The weather in {location} is {weather_data}." + weather_data = ( + f"The weather in {location} is {await response.text()}." + ) + logger.info(f"weather data: {weather_data}") else: raise Exception( f"Failed to get weather data, status code: {response.status}" ) + # (optional) To wait for the speech to finish before giving results of the function call + # await speech_handle.join() + return weather_data + def prewarm_process(proc: JobProcess): # preload silero VAD in memory to speed up session start @@ -58,7 +95,11 @@ async def entrypoint(ctx: JobContext): initial_chat_ctx = llm.ChatContext().append( text=( "You are a weather assistant created by LiveKit. Your interface with users will be voice. " - "You will provide weather information for a given location." + "You will provide weather information for a given location. " + # when using option 1, you can suppress from the agent with prompt + "do not return any text while calling the function." + # uncomment this to use option 2 + # "when performing function calls, let user know that you are checking the weather." ), role="system", ) @@ -71,6 +112,7 @@ async def entrypoint(ctx: JobContext): fnc_ctx=fnc_ctx, chat_ctx=initial_chat_ctx, ) + # Start the assistant. This will automatically publish a microphone track and listen to the participant. 
agent.start(ctx.room, participant) await agent.say( diff --git a/examples/voice-pipeline-agent/gemini_voice_agent.py b/examples/voice-pipeline-agent/gemini_voice_agent.py new file mode 100644 index 000000000..bb3641c6b --- /dev/null +++ b/examples/voice-pipeline-agent/gemini_voice_agent.py @@ -0,0 +1,90 @@ +import logging + +from dotenv import load_dotenv +from livekit.agents import ( + AutoSubscribe, + JobContext, + JobProcess, + WorkerOptions, + cli, + llm, + metrics, +) +from livekit.agents.pipeline import VoicePipelineAgent +from livekit.plugins import google, openai, silero + +load_dotenv() +logger = logging.getLogger("voice-assistant") + + +def prewarm(proc: JobProcess): + proc.userdata["vad"] = silero.VAD.load() + + +# An example Voice Agent using Google STT, Gemini 2.0 Flash, and Google TTS. +# Prerequisites: +# 1. livekit-plugins-openai[vertex] package installed +# 2. save your service account credentials and set the following environments: +# * GOOGLE_APPLICATION_CREDENTIALS to the path of the service account key file +# * GOOGLE_CLOUD_PROJECT to your Google Cloud project ID +# 3. the following services are enabled on your Google Cloud project: +# * Vertex AI +# * Cloud Speech-to-Text API +# * Cloud Text-to-Speech API + +# Read more about authentication with Google: https://cloud.google.com/docs/authentication/application-default-credentials + + +async def entrypoint(ctx: JobContext): + initial_ctx = llm.ChatContext().append( + role="system", + text=( + "You are a voice assistant created by LiveKit. Your interface with users will be voice. " + "You should use short and concise responses, and avoiding usage of unpronouncable punctuation." + ), + ) + + logger.info(f"connecting to room {ctx.room.name}") + await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) + + # wait for the first participant to connect + participant = await ctx.wait_for_participant() + logger.info(f"starting voice assistant for participant {participant.identity}") + + agent = VoicePipelineAgent( + vad=ctx.proc.userdata["vad"], + stt=google.STT(), + llm=openai.LLM.with_vertex(model="google/gemini-2.0-flash-exp"), + tts=google.TTS( + voice_name="en-US-Journey-D", + ), + chat_ctx=initial_ctx, + ) + + agent.start(ctx.room, participant) + + usage_collector = metrics.UsageCollector() + + @agent.on("metrics_collected") + def _on_metrics_collected(mtrcs: metrics.AgentMetrics): + metrics.log_metrics(mtrcs) + usage_collector.collect(mtrcs) + + async def log_usage(): + summary = usage_collector.get_summary() + logger.info(f"Usage: ${summary}") + + ctx.add_shutdown_callback(log_usage) + + await agent.say( + "Hi there, this is Gemini, how can I help you today?", allow_interruptions=False + ) + + +if __name__ == "__main__": + cli.run_app( + WorkerOptions( + entrypoint_fnc=entrypoint, + prewarm_fnc=prewarm, + ), + ) diff --git a/examples/voice-pipeline-agent/requirements.txt b/examples/voice-pipeline-agent/requirements.txt index 480526b8e..cf97c8314 100644 --- a/examples/voice-pipeline-agent/requirements.txt +++ b/examples/voice-pipeline-agent/requirements.txt @@ -1,7 +1,8 @@ -livekit-agents>=0.11.3 -livekit-plugins-openai>=0.10.7 -livekit-plugins-deepgram>=0.6.11 -livekit-plugins-silero>=0.7.3 -livekit-plugins-rag>=0.2.2 +livekit-agents>=0.12.6 +livekit-plugins-deepgram>=0.6.16 +livekit-plugins-google>=0.9.0 +livekit-plugins-openai[vertex]>=0.10.10 +livekit-plugins-silero>=0.7.4 +livekit-plugins-rag>=0.2.3 python-dotenv~=1.0 aiofile~=3.8.8 diff --git a/examples/voice-pipeline-agent/turn_detector.py 
b/examples/voice-pipeline-agent/turn_detector.py new file mode 100644 index 000000000..898ac9cc3 --- /dev/null +++ b/examples/voice-pipeline-agent/turn_detector.py @@ -0,0 +1,76 @@ +import logging + +from dotenv import load_dotenv +from livekit.agents import ( + AutoSubscribe, + JobContext, + JobProcess, + WorkerOptions, + cli, + llm, + metrics, +) +from livekit.agents.pipeline import VoicePipelineAgent +from livekit.plugins import deepgram, openai, silero, turn_detector + +load_dotenv() +logger = logging.getLogger("voice-assistant") + + +def prewarm(proc: JobProcess): + proc.userdata["vad"] = silero.VAD.load() + + +# This example uses our open-weight turn detection model to detect when the user is +# done speaking. This approach is more accurate than the default VAD model, reducing +# false positive interruptions by the agent. +async def entrypoint(ctx: JobContext): + initial_ctx = llm.ChatContext().append( + role="system", + text=( + "You are a voice assistant created by LiveKit. Your interface with users will be voice. " + "You should use short and concise responses, and avoiding usage of unpronouncable punctuation." + ), + ) + + logger.info(f"connecting to room {ctx.room.name}") + await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) + + # wait for the first participant to connect + participant = await ctx.wait_for_participant() + logger.info(f"starting voice assistant for participant {participant.identity}") + + agent = VoicePipelineAgent( + vad=ctx.proc.userdata["vad"], + stt=deepgram.STT(), + llm=openai.LLM(), + tts=openai.TTS(), + chat_ctx=initial_ctx, + turn_detector=turn_detector.EOUModel(), + ) + + agent.start(ctx.room, participant) + + usage_collector = metrics.UsageCollector() + + @agent.on("metrics_collected") + def _on_metrics_collected(mtrcs: metrics.AgentMetrics): + metrics.log_metrics(mtrcs) + usage_collector.collect(mtrcs) + + async def log_usage(): + summary = usage_collector.get_summary() + logger.info(f"Usage: ${summary}") + + ctx.add_shutdown_callback(log_usage) + + await agent.say("Hey, how can I help you today?", allow_interruptions=True) + + +if __name__ == "__main__": + cli.run_app( + WorkerOptions( + entrypoint_fnc=entrypoint, + prewarm_fnc=prewarm, + ), + ) diff --git a/livekit-agents/CHANGELOG.md b/livekit-agents/CHANGELOG.md index 7d0c2450f..d9c3770d4 100644 --- a/livekit-agents/CHANGELOG.md +++ b/livekit-agents/CHANGELOG.md @@ -1,5 +1,113 @@ # livekit-agents +## 0.12.6 + +### Patch Changes + +- expose worker_id in jobcontext - [#1307](https://github.com/livekit/agents/pull/1307) ([@s-hamdananwar](https://github.com/s-hamdananwar)) + +- improved handling of LLM errors, do not retry if already began - [#1298](https://github.com/livekit/agents/pull/1298) ([@davidzhao](https://github.com/davidzhao)) + +- Do not pass function context if at max depth - [#1306](https://github.com/livekit/agents/pull/1306) ([@martin-purplefish](https://github.com/martin-purplefish)) + +- avoid warnings when function depth matches limit - [#1316](https://github.com/livekit/agents/pull/1316) ([@davidzhao](https://github.com/davidzhao)) + +- improve interruption handling, avoid agent from getting stuck - [#1290](https://github.com/livekit/agents/pull/1290) ([@davidzhao](https://github.com/davidzhao)) + +- add manual interrupt method for pipeline agent - [#1294](https://github.com/livekit/agents/pull/1294) ([@longcw](https://github.com/longcw)) + +- make multimodal class generic and support gemini live api - [#1240](https://github.com/livekit/agents/pull/1240) 
([@jayeshp19](https://github.com/jayeshp19)) + +## 0.12.5 + +### Patch Changes + +- make max_endpoint_delay configurable - [#1277](https://github.com/livekit/agents/pull/1277) ([@davidzhao](https://github.com/davidzhao)) + +- set USE_DOCSTRING as default for ai_callable - [#1266](https://github.com/livekit/agents/pull/1266) ([@longcw](https://github.com/longcw)) + +- fix: do not log process warning when process not found - [#1281](https://github.com/livekit/agents/pull/1281) ([@davidzhao](https://github.com/davidzhao)) + +- fix context when functions have been called - [#1279](https://github.com/livekit/agents/pull/1279) ([@jmugicagonz](https://github.com/jmugicagonz)) + +## 0.12.4 + +### Patch Changes + +- avoid duplicated chat ctx for function calls with messages - [#1254](https://github.com/livekit/agents/pull/1254) ([@longcw](https://github.com/longcw)) + +## 0.12.3 + +### Patch Changes + +- Moved create_ai_function_info to function_context.py for better reusability and reduce repetation - [#1260](https://github.com/livekit/agents/pull/1260) ([@jayeshp19](https://github.com/jayeshp19)) + +- added streaming audio decoder for compressed audio. - [#1236](https://github.com/livekit/agents/pull/1236) ([@davidzhao](https://github.com/davidzhao)) + +- Add JPEG quality param to image encoder - [#1249](https://github.com/livekit/agents/pull/1249) ([@bcherry](https://github.com/bcherry)) + +- Add support for OpenAI's "detail" parameter to ChatImage - [#1213](https://github.com/livekit/agents/pull/1213) ([@bcherry](https://github.com/bcherry)) + + Add support for data URLs on ChatImage in the Anthropic plugin. + +- fix: correctly parse function argument types - [#1221](https://github.com/livekit/agents/pull/1221) ([@jayeshp19](https://github.com/jayeshp19)) + +- Fix center_aspect_fit bug, add scale_aspect_fit and scale_aspect_fill resizing options. - [#1222](https://github.com/livekit/agents/pull/1222) ([@bcherry](https://github.com/bcherry)) + + Make scale_aspect_fit the new default resizing option for video frames. + +## 0.12.2 + +### Patch Changes + +- improvements to endpointing latency - [#1212](https://github.com/livekit/agents/pull/1212) ([@davidzhao](https://github.com/davidzhao)) + +- Improvements to end of turn plugin, ensure STT language settings. 
- [#1195](https://github.com/livekit/agents/pull/1195) ([@davidzhao](https://github.com/davidzhao)) + +- fix duplicated agent speech commit for message with function call - [#1192](https://github.com/livekit/agents/pull/1192) ([@longcw](https://github.com/longcw)) + +- fix: Handle optional func args in tool calls when set to `None` - [#1211](https://github.com/livekit/agents/pull/1211) ([@jayeshp19](https://github.com/jayeshp19)) + +## 0.12.1 + +### Patch Changes + +- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom)) + +## 0.12.0 + +### Minor Changes + +- add nested speech handles, now agent.say works during a function call - [#1130](https://github.com/livekit/agents/pull/1130) ([@longcw](https://github.com/longcw)) + +### Patch Changes + +- feat: stt retry & stt.FallbackAdapter - [#1114](https://github.com/livekit/agents/pull/1114) ([@theomonnom](https://github.com/theomonnom)) + +- expose LiveKitAPI from the JobContext - [#1159](https://github.com/livekit/agents/pull/1159) ([@theomonnom](https://github.com/theomonnom)) + +- add extra chat messages to the end of the function call outputs - [#1165](https://github.com/livekit/agents/pull/1165) ([@longcw](https://github.com/longcw)) + +- Add retries to recover from text mode to audio model for realtime API - [#1121](https://github.com/livekit/agents/pull/1121) ([@longcw](https://github.com/longcw)) + +- prepare for release - [#1160](https://github.com/livekit/agents/pull/1160) ([@theomonnom](https://github.com/theomonnom)) + +- add max_job_memory_usage and will kill the job if it exceeds the limit - [#1136](https://github.com/livekit/agents/pull/1136) ([@longcw](https://github.com/longcw)) + +- support for custom tool use in LLMs - [#1102](https://github.com/livekit/agents/pull/1102) ([@jayeshp19](https://github.com/jayeshp19)) + +- feat: tts retry & tts.FallbackAdapter - [#1074](https://github.com/livekit/agents/pull/1074) ([@theomonnom](https://github.com/theomonnom)) + +- Expose multimodal agent metrics - [#1080](https://github.com/livekit/agents/pull/1080) ([@longcw](https://github.com/longcw)) + +- preload mp3 decoder for TTS plugins - [#1129](https://github.com/livekit/agents/pull/1129) ([@jayeshp19](https://github.com/jayeshp19)) + +- feat: llm retry & llm.FallbackAdapter - [#1132](https://github.com/livekit/agents/pull/1132) ([@theomonnom](https://github.com/theomonnom)) + +- feat: inference process & end of utterance plugin - [#1133](https://github.com/livekit/agents/pull/1133) ([@theomonnom](https://github.com/theomonnom)) + +- vertex ai support with openai library - [#1084](https://github.com/livekit/agents/pull/1084) ([@jayeshp19](https://github.com/jayeshp19)) + ## 0.11.3 ### Patch Changes diff --git a/livekit-agents/livekit/agents/_exceptions.py b/livekit-agents/livekit/agents/_exceptions.py index 128efacee..74a1ab3c1 100644 --- a/livekit-agents/livekit/agents/_exceptions.py +++ b/livekit-agents/livekit/agents/_exceptions.py @@ -23,16 +23,22 @@ class APIError(Exception): body: object | None """The API response body, if available. - + If the API returned valid JSON, the body will contain the decoded result.
""" - def __init__(self, message: str, *, body: object | None) -> None: + retryable: bool = False + """Whether the error can be retried.""" + + def __init__( + self, message: str, *, body: object | None, retryable: bool = True + ) -> None: super().__init__(message) self.message = message self.body = body + self.retryable = retryable class APIStatusError(APIError): @@ -48,11 +54,18 @@ def __init__( self, message: str, *, - status_code: int, - request_id: str | None, - body: object | None, + status_code: int = -1, + request_id: str | None = None, + body: object | None = None, + retryable: bool | None = None, ) -> None: - super().__init__(message, body=body) + if retryable is None: + retryable = True + # 4xx errors are not retryable + if status_code >= 400 and status_code < 500: + retryable = False + + super().__init__(message, body=body, retryable=retryable) self.status_code = status_code self.request_id = request_id @@ -61,12 +74,16 @@ def __init__( class APIConnectionError(APIError): """Raised when an API request failed due to a connection error.""" - def __init__(self, message: str = "Connection error.") -> None: - super().__init__(message, body=None) + def __init__( + self, message: str = "Connection error.", *, retryable: bool = True + ) -> None: + super().__init__(message, body=None, retryable=retryable) class APITimeoutError(APIConnectionError): """Raised when an API request timed out.""" - def __init__(self, message: str = "Request timed out.") -> None: - super().__init__(message) + def __init__( + self, message: str = "Request timed out.", *, retryable: bool = True + ) -> None: + super().__init__(message, retryable=retryable) diff --git a/livekit-agents/livekit/agents/cli/log.py b/livekit-agents/livekit/agents/cli/log.py index dc16bfdfa..c4b5e5e52 100644 --- a/livekit-agents/livekit/agents/cli/log.py +++ b/livekit-agents/livekit/agents/cli/log.py @@ -18,6 +18,7 @@ "openai", "watchfiles", "anthropic", + "websockets.client", ] diff --git a/livekit-agents/livekit/agents/cli/proto.py b/livekit-agents/livekit/agents/cli/proto.py index f7753c579..761690783 100644 --- a/livekit-agents/livekit/agents/cli/proto.py +++ b/livekit-agents/livekit/agents/cli/proto.py @@ -52,6 +52,7 @@ def write(self, b: io.BytesIO) -> None: channel.write_string(b, accept_args.metadata) channel.write_string(b, running_job.url) channel.write_string(b, running_job.token) + channel.write_string(b, running_job.worker_id) channel.write_int(b, self.reload_count) @@ -69,6 +70,7 @@ def read(self, b: io.BytesIO) -> None: job=job, url=channel.read_string(b), token=channel.read_string(b), + worker_id=channel.read_string(b), ) ) diff --git a/livekit-agents/livekit/agents/inference_runner.py b/livekit-agents/livekit/agents/inference_runner.py new file mode 100644 index 000000000..646a03bdd --- /dev/null +++ b/livekit-agents/livekit/agents/inference_runner.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import threading +from abc import ABC, abstractmethod +from typing import ClassVar, Protocol, Type + + +class _RunnerMeta(Protocol): + INFERENCE_METHOD: ClassVar[str] + + +_RunnersDict = dict[str, Type["_InferenceRunner"]] + + +# kept private until we stabilize the API (only used for EOU today) +class _InferenceRunner(ABC, _RunnerMeta): + registered_runners: _RunnersDict = {} + + @classmethod + def register_runner(cls, runner_class: Type["_InferenceRunner"]) -> None: + if threading.current_thread() != threading.main_thread(): + raise RuntimeError("InferenceRunner must be registered on the main thread") + + if 
runner_class.INFERENCE_METHOD in cls.registered_runners: + raise ValueError( + f"InferenceRunner {runner_class.INFERENCE_METHOD} already registered" + ) + + cls.registered_runners[runner_class.INFERENCE_METHOD] = runner_class + + @abstractmethod + def initialize(self) -> None: + """Initialize the runner. This is used to load models, etc.""" + ... + + @abstractmethod + def run(self, data: bytes) -> bytes | None: + """Run inference on the given data.""" + ... diff --git a/livekit-agents/livekit/agents/ipc/__init__.py b/livekit-agents/livekit/agents/ipc/__init__.py index ab04d6b5e..589936600 100644 --- a/livekit-agents/livekit/agents/ipc/__init__.py +++ b/livekit-agents/livekit/agents/ipc/__init__.py @@ -1,17 +1,19 @@ from . import ( channel, + inference_proc_executor, job_executor, - proc_job_executor, + job_proc_executor, + job_thread_executor, proc_pool, proto, - thread_job_executor, ) __all__ = [ "proto", "channel", "proc_pool", - "proc_job_executor", - "thread_job_executor", + "job_proc_executor", + "job_thread_executor", + "inference_proc_executor", "job_executor", ] diff --git a/livekit-agents/livekit/agents/ipc/inference_executor.py b/livekit-agents/livekit/agents/ipc/inference_executor.py new file mode 100644 index 000000000..c83aee64d --- /dev/null +++ b/livekit-agents/livekit/agents/ipc/inference_executor.py @@ -0,0 +1,7 @@ +from __future__ import annotations + +from typing import Protocol + + +class InferenceExecutor(Protocol): + async def do_inference(self, method: str, data: bytes) -> bytes | None: ... diff --git a/livekit-agents/livekit/agents/ipc/inference_proc_executor.py b/livekit-agents/livekit/agents/ipc/inference_proc_executor.py new file mode 100644 index 000000000..fa9625213 --- /dev/null +++ b/livekit-agents/livekit/agents/ipc/inference_proc_executor.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import asyncio +import contextlib +import multiprocessing as mp +import socket +from multiprocessing.context import BaseContext + +from ..inference_runner import _RunnersDict +from ..log import logger +from ..utils import aio, log_exceptions, shortuuid +from . 
import channel, proto +from .inference_proc_lazy_main import ProcStartArgs, proc_main +from .supervised_proc import SupervisedProc + + +class InferenceProcExecutor(SupervisedProc): + def __init__( + self, + *, + runners: _RunnersDict, + initialize_timeout: float, + close_timeout: float, + memory_warn_mb: float, + memory_limit_mb: float, + ping_interval: float, + ping_timeout: float, + high_ping_threshold: float, + mp_ctx: BaseContext, + loop: asyncio.AbstractEventLoop, + ) -> None: + super().__init__( + initialize_timeout=initialize_timeout, + close_timeout=close_timeout, + memory_warn_mb=memory_warn_mb, + memory_limit_mb=memory_limit_mb, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + high_ping_threshold=high_ping_threshold, + mp_ctx=mp_ctx, + loop=loop, + ) + + self._runners = runners + self._active_requests: dict[str, asyncio.Future[proto.InferenceResponse]] = {} + + def _create_process(self, cch: socket.socket, log_cch: socket.socket) -> mp.Process: + proc_args = ProcStartArgs( + log_cch=log_cch, + mp_cch=cch, + runners=self._runners, + ) + + return self._mp_ctx.Process( # type: ignore + target=proc_main, + args=(proc_args,), + name="inference_proc", + ) + + @log_exceptions(logger=logger) + async def _main_task(self, ipc_ch: aio.ChanReceiver[channel.Message]) -> None: + async for msg in ipc_ch: + if isinstance(msg, proto.InferenceResponse): + fut = self._active_requests.pop(msg.request_id, None) + if fut is None: + logger.warning( + "received unexpected inference response", + extra={"request_id": msg.request_id}, + ) + return + + with contextlib.suppress(asyncio.InvalidStateError): + fut.set_result(msg) + + async def do_inference(self, method: str, data: bytes) -> bytes | None: + if not self.started: + raise RuntimeError("process not started") + + request_id = shortuuid("inference_req_") + fut = asyncio.Future[proto.InferenceResponse]() + + await channel.asend_message( + self._pch, + proto.InferenceRequest(request_id=request_id, method=method, data=data), + ) + + self._active_requests[request_id] = fut + + inf_resp = await fut + if inf_resp.error: + raise RuntimeError(f"inference of {method} failed: {inf_resp.error}") + + return inf_resp.data + + def logging_extra(self): + extra = super().logging_extra() + extra["inference"] = True + return extra diff --git a/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py b/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py new file mode 100644 index 000000000..c4e949d58 --- /dev/null +++ b/livekit-agents/livekit/agents/ipc/inference_proc_lazy_main.py @@ -0,0 +1,108 @@ +from multiprocessing import current_process + +if current_process().name == "inference_proc": + import signal + import sys + + # ignore signals in the inference process (the parent process will handle them) + signal.signal(signal.SIGINT, signal.SIG_IGN) + signal.signal(signal.SIGTERM, signal.SIG_IGN) + + def _no_traceback_excepthook(exc_type, exc_val, traceback): + if isinstance(exc_val, KeyboardInterrupt): + return + sys.__excepthook__(exc_type, exc_val, traceback) + + sys.excepthook = _no_traceback_excepthook + + +import asyncio +import socket +from dataclasses import dataclass + +from ..inference_runner import _RunnersDict +from ..log import logger +from ..utils import aio, log_exceptions +from . 
import proto +from .channel import Message +from .proc_client import _ProcClient + + +@dataclass +class ProcStartArgs: + log_cch: socket.socket + mp_cch: socket.socket + runners: _RunnersDict + + +def proc_main(args: ProcStartArgs) -> None: + from .proc_client import _ProcClient + + inf_proc = _InferenceProc(args.runners) + + client = _ProcClient( + args.mp_cch, + args.log_cch, + inf_proc.initialize, + inf_proc.entrypoint, + ) + + client.initialize_logger() + + pid = current_process().pid + logger.info("initializing inference process", extra={"pid": pid}) + client.initialize() + logger.info("inference process initialized", extra={"pid": pid}) + + client.run() + + +class _InferenceProc: + def __init__(self, runners: _RunnersDict) -> None: + # create an instance of each runner (the constructor must not require any arguments) + self._runners = {name: runner() for name, runner in runners.items()} + + def initialize( + self, init_req: proto.InitializeRequest, client: _ProcClient + ) -> None: + self._client = client + + for runner in self._runners.values(): + logger.debug( + "initializing inference runner", + extra={"runner": runner.__class__.INFERENCE_METHOD}, + ) + runner.initialize() + + @log_exceptions(logger=logger) + async def entrypoint(self, cch: aio.ChanReceiver[Message]) -> None: + async for msg in cch: + if isinstance(msg, proto.InferenceRequest): + await self._handle_inference_request(msg) + + if isinstance(msg, proto.ShutdownRequest): + await self._client.send(proto.Exiting(reason=msg.reason)) + break + + async def _handle_inference_request(self, msg: proto.InferenceRequest) -> None: + loop = asyncio.get_running_loop() + + if msg.method not in self._runners: + logger.warning("unknown inference method", extra={"method": msg.method}) + + try: + data = await loop.run_in_executor( + None, self._runners[msg.method].run, msg.data + ) + await self._client.send( + proto.InferenceResponse( + request_id=msg.request_id, + data=data, + ) + ) + + except Exception as e: + logger.exception("error running inference") + await self._client.send( + proto.InferenceResponse(request_id=msg.request_id, error=str(e)) + ) diff --git a/livekit-agents/livekit/agents/ipc/job_executor.py b/livekit-agents/livekit/agents/ipc/job_executor.py index 19704791a..dccf1831d 100644 --- a/livekit-agents/livekit/agents/ipc/job_executor.py +++ b/livekit-agents/livekit/agents/ipc/job_executor.py @@ -11,19 +11,16 @@ class JobExecutor(Protocol): def started(self) -> bool: ... @property - def start_arguments(self) -> Any | None: ... + def user_arguments(self) -> Any | None: ... - @start_arguments.setter - def start_arguments(self, value: Any | None) -> None: ... + @user_arguments.setter + def user_arguments(self, value: Any | None) -> None: ... @property def running_job(self) -> RunningJobInfo | None: ... @property - def run_status(self) -> RunStatus: ... - - @property - def exception(self) -> Exception | None: ... + def status(self) -> JobStatus: ... async def start(self) -> None: ... @@ -36,25 +33,7 @@ async def aclose(self) -> None: ... async def launch_job(self, info: RunningJobInfo) -> None: ...
-class RunStatus(Enum): - STARTING = "STARTING" - WAITING_FOR_JOB = "WAITING_FOR_JOB" - RUNNING_JOB = "RUNNING_JOB" - FINISHED_FAILED = "FINISHED_FAILED" - FINISHED_CLEAN = "FINISHED_CLEAN" - - -class JobExecutorError(Exception): - pass - - -class JobExecutorError_ShutdownTimeout(JobExecutorError): - pass - - -class JobExecutorError_Unresponsive(JobExecutorError): - pass - - -class JobExecutorError_Runtime(JobExecutorError): - pass +class JobStatus(Enum): + RUNNING = "running" + FAILED = "failed" + SUCCESS = "success" diff --git a/livekit-agents/livekit/agents/ipc/job_main.py b/livekit-agents/livekit/agents/ipc/job_main.py deleted file mode 100644 index 4e7519400..000000000 --- a/livekit-agents/livekit/agents/ipc/job_main.py +++ /dev/null @@ -1,315 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import copy -import logging -import pickle -import queue -import socket -import sys -import threading -from dataclasses import dataclass -from typing import Any, Callable, Optional - -from livekit import rtc - -from .. import utils -from ..job import JobContext, JobProcess -from ..log import logger -from ..utils.aio import duplex_unix -from . import channel, proto - - -class LogQueueHandler(logging.Handler): - _sentinal = None - - def __init__(self, duplex: utils.aio.duplex_unix._Duplex) -> None: - super().__init__() - self._duplex = duplex - self._send_q = queue.SimpleQueue[Optional[bytes]]() - self._send_thread = threading.Thread( - target=self._forward_logs, name="ipc_log_forwarder" - ) - self._send_thread.start() - - def _forward_logs(self): - while True: - serialized_record = self._send_q.get() - if serialized_record is None: - break - - try: - self._duplex.send_bytes(serialized_record) - except duplex_unix.DuplexClosed: - break - - self._duplex.close() - - def emit(self, record: logging.LogRecord) -> None: - try: - # Check if Python is shutting down - if sys.is_finalizing(): - return - - # from https://github.com/python/cpython/blob/91b7f2e7f6593acefda4fa860250dd87d6f849bf/Lib/logging/handlers.py#L1453 - msg = self.format(record) - record = copy.copy(record) - record.message = msg - record.msg = msg - record.args = None - record.exc_info = None - record.exc_text = None - record.stack_info = None - - # https://websockets.readthedocs.io/en/stable/topics/logging.html#logging-to-json - # webosckets library add "websocket" attribute to log records, which is not pickleable - if hasattr(record, "websocket"): - record.websocket = None - - self._send_q.put_nowait(pickle.dumps(record)) - - except Exception: - self.handleError(record) - - def close(self) -> None: - super().close() - self._send_q.put_nowait(self._sentinal) - - -@dataclass -class _ShutdownInfo: - user_initiated: bool - reason: str - - -@dataclass -class JobTask: - job_ctx: JobContext - task: asyncio.Task - shutdown_fut: asyncio.Future[_ShutdownInfo] - - -def _start_job( - proc: JobProcess, - job_entrypoint_fnc: Callable[[JobContext], Any], - start_req: proto.StartJobRequest, - exit_proc_fut: asyncio.Event, - cch: utils.aio.duplex_unix._AsyncDuplex, -) -> JobTask: - # used to warn users if none of connect/shutdown is called inside the job_entry - ctx_connect, ctx_shutdown = False, False - room = rtc.Room() - request_shutdown_fut = asyncio.Future[_ShutdownInfo]() - - @room.on("disconnected") - def _on_room_disconnected(*args): - with contextlib.suppress(asyncio.InvalidStateError): - request_shutdown_fut.set_result( - _ShutdownInfo(user_initiated=False, reason="room disconnected") - ) - - def 
_on_ctx_connect() -> None: - nonlocal ctx_connect - ctx_connect = True - - def _on_ctx_shutdown(reason: str) -> None: - nonlocal ctx_shutdown - ctx_shutdown = True - - with contextlib.suppress(asyncio.InvalidStateError): - request_shutdown_fut.set_result( - _ShutdownInfo(user_initiated=True, reason=reason) - ) - - info = start_req.running_job - room._info.name = info.job.room.name - job_ctx = JobContext( - proc=proc, - info=info, - room=room, - on_connect=_on_ctx_connect, - on_shutdown=_on_ctx_shutdown, - ) - - @utils.log_exceptions(logger=logger) - async def _run_job_task() -> None: - utils.http_context._new_session_ctx() - job_entry_task = asyncio.create_task( - job_entrypoint_fnc(job_ctx), name="job_entrypoint" - ) - - async def _warn_not_connected_task(): - await asyncio.sleep(10) - if not ctx_connect and not ctx_shutdown: - logger.warn( - ( - "room not connected after job_entry was called after 10 seconds, " - "did you forget to call job_ctx.connect()?" - ) - ) - - warn_unconnected_task = asyncio.create_task(_warn_not_connected_task()) - job_entry_task.add_done_callback(lambda _: warn_unconnected_task.cancel()) - - def log_exception(t: asyncio.Task) -> None: - if not t.cancelled() and t.exception(): - logger.error( - "unhandled exception while running the job task", - exc_info=t.exception(), - ) - elif not ctx_connect and not ctx_shutdown: - logger.warn("job task completed without connecting or shutting down") - - job_entry_task.add_done_callback(log_exception) - - shutdown_info = await request_shutdown_fut - logger.debug( - "shutting down job task", - extra={ - "reason": shutdown_info.reason, - "user_initiated": shutdown_info.user_initiated, - }, - ) - await channel.asend_message(cch, proto.Exiting(reason=shutdown_info.reason)) - await room.disconnect() - - try: - shutdown_tasks = [] - for callback in job_ctx._shutdown_callbacks: - shutdown_tasks.append( - asyncio.create_task(callback(), name="job_shutdown_callback") - ) - - await asyncio.gather(*shutdown_tasks) - except Exception: - logger.exception("error while shutting down the job") - - await utils.http_context._close_http_ctx() - exit_proc_fut.set() - - task = asyncio.create_task(_run_job_task()) - job_task = JobTask(job_ctx=job_ctx, task=task, shutdown_fut=request_shutdown_fut) - return job_task - - -async def _async_main( - proc: JobProcess, - job_entrypoint_fnc: Callable[[JobContext], Any], - mp_cch: socket.socket, -) -> None: - cch = await duplex_unix._AsyncDuplex.open(mp_cch) - - job_task: JobTask | None = None - exit_proc_fut = asyncio.Event() - no_msg_timeout = utils.aio.sleep(proto.PING_INTERVAL * 5) # missing 5 pings - - @utils.log_exceptions(logger=logger) - async def _read_ipc_task(): - nonlocal job_task - while True: - try: - msg = await channel.arecv_message(cch, proto.IPC_MESSAGES) - except duplex_unix.DuplexClosed: - break - - with contextlib.suppress(utils.aio.SleepFinished): - no_msg_timeout.reset() - - if isinstance(msg, proto.PingRequest): - pong = proto.PongResponse( - last_timestamp=msg.timestamp, timestamp=utils.time_ms() - ) - await channel.asend_message(cch, pong) - - if isinstance(msg, proto.StartJobRequest): - assert job_task is None, "job task already running" - job_task = _start_job(proc, job_entrypoint_fnc, msg, exit_proc_fut, cch) - - if isinstance(msg, proto.ShutdownRequest): - if job_task is None: - # there is no running job, we can exit immediately - break - - with contextlib.suppress(asyncio.InvalidStateError): - job_task.shutdown_fut.set_result( - _ShutdownInfo(reason=msg.reason, 
user_initiated=False) - ) - - async def _self_health_check(): - await no_msg_timeout - print("worker process is not responding.. worker crashed?") - with contextlib.suppress(asyncio.CancelledError): - exit_proc_fut.set() - - read_task = asyncio.create_task(_read_ipc_task(), name="ipc_read") - health_check_task = asyncio.create_task(_self_health_check(), name="health_check") - - def _done_cb(task: asyncio.Task) -> None: - with contextlib.suppress(asyncio.InvalidStateError): - exit_proc_fut.set() - - read_task.add_done_callback(_done_cb) - - await exit_proc_fut.wait() - await utils.aio.gracefully_cancel(read_task, health_check_task) - - with contextlib.suppress(duplex_unix.DuplexClosed): - await cch.aclose() - - -@dataclass -class ProcStartArgs: - initialize_process_fnc: Callable[[JobProcess], Any] - job_entrypoint_fnc: Callable[[JobContext], Any] - log_cch: socket.socket - mp_cch: socket.socket - asyncio_debug: bool - user_arguments: Any | None = None - - -@dataclass -class ThreadStartArgs: - mp_cch: socket.socket - initialize_process_fnc: Callable[[JobProcess], Any] - job_entrypoint_fnc: Callable[[JobContext], Any] - user_arguments: Any | None - asyncio_debug: bool - join_fnc: Callable[[], None] - - -def thread_main( - args: ThreadStartArgs, -) -> None: - """main function for the job process when using the ThreadedJobRunner""" - tid = threading.get_native_id() - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.set_debug(args.asyncio_debug) - loop.slow_callback_duration = 0.1 # 100ms - - cch = duplex_unix._Duplex.open(args.mp_cch) - try: - init_req = channel.recv_message(cch, proto.IPC_MESSAGES) - assert isinstance( - init_req, proto.InitializeRequest - ), "first message must be InitializeRequest" - job_proc = JobProcess(start_arguments=args.user_arguments) - - logger.debug("initializing job runner", extra={"tid": tid}) - args.initialize_process_fnc(job_proc) - logger.debug("job runner initialized", extra={"tid": tid}) - channel.send_message(cch, proto.InitializeResponse()) - - main_task = loop.create_task( - _async_main(job_proc, args.job_entrypoint_fnc, cch.detach()), - name="job_proc_main", - ) - loop.run_until_complete(main_task) - except duplex_unix.DuplexClosed: - pass - except Exception: - logger.exception("error while running job process", extra={"tid": tid}) - finally: - args.join_fnc() - loop.run_until_complete(loop.shutdown_default_executor()) diff --git a/livekit-agents/livekit/agents/ipc/job_proc_executor.py b/livekit-agents/livekit/agents/ipc/job_proc_executor.py new file mode 100644 index 000000000..84a89766a --- /dev/null +++ b/livekit-agents/livekit/agents/ipc/job_proc_executor.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import asyncio +import multiprocessing as mp +import socket +from multiprocessing.context import BaseContext +from typing import Any, Awaitable, Callable + +from ..job import JobContext, JobProcess, RunningJobInfo +from ..log import logger +from ..utils import aio, log_exceptions +from . 
import channel, proto +from .inference_executor import InferenceExecutor +from .job_executor import JobStatus +from .job_proc_lazy_main import ProcStartArgs, proc_main +from .supervised_proc import SupervisedProc + + +class ProcJobExecutor(SupervisedProc): + def __init__( + self, + *, + initialize_process_fnc: Callable[[JobProcess], Any], + job_entrypoint_fnc: Callable[[JobContext], Awaitable[None]], + inference_executor: InferenceExecutor | None, + initialize_timeout: float, + close_timeout: float, + memory_warn_mb: float, + memory_limit_mb: float, + ping_interval: float, + ping_timeout: float, + high_ping_threshold: float, + mp_ctx: BaseContext, + loop: asyncio.AbstractEventLoop, + ) -> None: + super().__init__( + initialize_timeout=initialize_timeout, + close_timeout=close_timeout, + memory_warn_mb=memory_warn_mb, + memory_limit_mb=memory_limit_mb, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + high_ping_threshold=high_ping_threshold, + mp_ctx=mp_ctx, + loop=loop, + ) + + self._user_args: Any | None = None + self._job_status: JobStatus | None = None + self._running_job: RunningJobInfo | None = None + self._initialize_process_fnc = initialize_process_fnc + self._job_entrypoint_fnc = job_entrypoint_fnc + self._inference_executor = inference_executor + self._inference_tasks: list[asyncio.Task[None]] = [] + + @property + def status(self) -> JobStatus: + if self._job_status is None: + raise RuntimeError("job status not available") + + return self._job_status + + @property + def user_arguments(self) -> Any | None: + return self._user_args + + @user_arguments.setter + def user_arguments(self, value: Any | None) -> None: + self._user_args = value + + @property + def running_job(self) -> RunningJobInfo | None: + return self._running_job + + def _create_process(self, cch: socket.socket, log_cch: socket.socket) -> mp.Process: + proc_args = ProcStartArgs( + initialize_process_fnc=self._initialize_process_fnc, + job_entrypoint_fnc=self._job_entrypoint_fnc, + log_cch=log_cch, + mp_cch=cch, + user_arguments=self._user_args, + ) + + return self._mp_ctx.Process( # type: ignore + target=proc_main, + args=(proc_args,), + name="job_proc", + ) + + @log_exceptions(logger=logger) + async def _main_task(self, ipc_ch: aio.ChanReceiver[channel.Message]) -> None: + try: + async for msg in ipc_ch: + if isinstance(msg, proto.InferenceRequest): + self._inference_tasks.append( + asyncio.create_task(self._do_inference_task(msg)) + ) + finally: + await aio.gracefully_cancel(*self._inference_tasks) + + self._job_status = ( + JobStatus.SUCCESS if self.exitcode == 0 else JobStatus.FAILED + ) + + async def _do_inference_task(self, inf_req: proto.InferenceRequest) -> None: + if self._inference_executor is None: + logger.warning("inference request received but no inference executor") + await channel.asend_message( + self._pch, + proto.InferenceResponse( + request_id=inf_req.request_id, error="no inference executor" + ), + ) + return + + try: + inf_res = await self._inference_executor.do_inference( + inf_req.method, inf_req.data + ) + await channel.asend_message( + self._pch, + proto.InferenceResponse(request_id=inf_req.request_id, data=inf_res), + ) + except Exception as e: + await channel.asend_message( + self._pch, + proto.InferenceResponse(request_id=inf_req.request_id, error=str(e)), + ) + + async def launch_job(self, info: RunningJobInfo) -> None: + """start/assign a job to the process""" + if self._running_job is not None: + raise RuntimeError("process already has a running job") + + if not 
self._initialize_fut.done(): + raise RuntimeError("process not initialized") + + self._job_status = JobStatus.RUNNING + self._running_job = info + + start_req = proto.StartJobRequest() + start_req.running_job = info + await channel.asend_message(self._pch, start_req) + + def logging_extra(self): + extra = super().logging_extra() + + if self._running_job: + extra["job_id"] = self._running_job.job.id + + return extra diff --git a/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py b/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py new file mode 100644 index 000000000..531dd7a36 --- /dev/null +++ b/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py @@ -0,0 +1,315 @@ +from __future__ import annotations + +from multiprocessing import current_process + +if current_process().name == "job_proc": + import signal + import sys + + # ignore signals in the jobs process (the parent process will handle them) + signal.signal(signal.SIGINT, signal.SIG_IGN) + signal.signal(signal.SIGTERM, signal.SIG_IGN) + + def _no_traceback_excepthook(exc_type, exc_val, traceback): + if isinstance(exc_val, KeyboardInterrupt): + return + sys.__excepthook__(exc_type, exc_val, traceback) + + sys.excepthook = _no_traceback_excepthook + + +import asyncio +import contextlib +import socket +import threading +from dataclasses import dataclass +from typing import Any, Callable + +from livekit import rtc + +from ..job import JobContext, JobProcess, _JobContextVar +from ..log import logger +from ..utils import aio, http_context, log_exceptions, shortuuid +from .channel import Message +from .inference_executor import InferenceExecutor +from .proc_client import _ProcClient +from .proto import ( + Exiting, + InferenceRequest, + InferenceResponse, + InitializeRequest, + ShutdownRequest, + StartJobRequest, +) + + +@dataclass +class ProcStartArgs: + initialize_process_fnc: Callable[[JobProcess], Any] + job_entrypoint_fnc: Callable[[JobContext], Any] + mp_cch: socket.socket + log_cch: socket.socket + user_arguments: Any | None = None + + +def proc_main(args: ProcStartArgs) -> None: + from .proc_client import _ProcClient + + job_proc = _JobProc( + args.initialize_process_fnc, args.job_entrypoint_fnc, args.user_arguments + ) + + client = _ProcClient( + args.mp_cch, + args.log_cch, + job_proc.initialize, + job_proc.entrypoint, + ) + + client.initialize_logger() + + pid = current_process().pid + logger.info("initializing job process", extra={"pid": pid}) + client.initialize() + logger.info("job process initialized", extra={"pid": pid}) + + client.run() + + +class _InfClient(InferenceExecutor): + def __init__(self, proc_client: _ProcClient) -> None: + self._client = proc_client + self._active_requests: dict[str, asyncio.Future[InferenceResponse]] = {} + + async def do_inference(self, method: str, data: bytes) -> bytes | None: + request_id = shortuuid("inference_job_") + fut = asyncio.Future[InferenceResponse]() + + await self._client.send( + InferenceRequest(request_id=request_id, method=method, data=data), + ) + + self._active_requests[request_id] = fut + + inf_resp = await fut + if inf_resp.error: + raise RuntimeError(f"inference of {method} failed: {inf_resp.error}") + + return inf_resp.data + + def _on_inference_response(self, resp: InferenceResponse) -> None: + fut = self._active_requests.pop(resp.request_id, None) + if fut is None: + logger.warning( + "received unexpected inference response", extra={"resp": resp} + ) + return + + with contextlib.suppress(asyncio.InvalidStateError): + fut.set_result(resp) + + +@dataclass 
+class _ShutdownInfo: + user_initiated: bool + reason: str + + +class _JobProc: + def __init__( + self, + initialize_process_fnc: Callable[[JobProcess], Any], + job_entrypoint_fnc: Callable[[JobContext], Any], + user_arguments: Any | None = None, + ) -> None: + self._initialize_process_fnc = initialize_process_fnc + self._job_entrypoint_fnc = job_entrypoint_fnc + self._job_proc = JobProcess(user_arguments=user_arguments) + self._job_task: asyncio.Task | None = None + + # used to warn users if both connect and shutdown are not called inside the job_entry + self._ctx_connect_called = False + self._ctx_shutdown_called = False + + @property + def has_running_job(self) -> bool: + return self._job_task is not None + + def initialize(self, init_req: InitializeRequest, client: _ProcClient) -> None: + self._client = client + self._inf_client = _InfClient(client) + self._initialize_process_fnc(self._job_proc) + + @log_exceptions(logger=logger) + async def entrypoint(self, cch: aio.ChanReceiver[Message]) -> None: + self._exit_proc_flag = asyncio.Event() + self._shutdown_fut: asyncio.Future[_ShutdownInfo] = asyncio.Future() + + @log_exceptions(logger=logger) + async def _read_ipc_task(): + async for msg in cch: + if isinstance(msg, StartJobRequest): + if self.has_running_job: + logger.warning( + "trying to start a new job while one is already running" + ) + continue + + self._start_job(msg) + if isinstance(msg, ShutdownRequest): + if not self.has_running_job: + self._exit_proc_flag.set() + break # exit immediately + + with contextlib.suppress(asyncio.InvalidStateError): + self._shutdown_fut.set_result( + _ShutdownInfo(reason=msg.reason, user_initiated=False) + ) + + if isinstance(msg, InferenceResponse): + self._inf_client._on_inference_response(msg) + + read_task = asyncio.create_task(_read_ipc_task(), name="job_ipc_read") + + await self._exit_proc_flag.wait() + await aio.gracefully_cancel(read_task) + + def _start_job(self, msg: StartJobRequest) -> None: + self._room = rtc.Room() + + @self._room.on("disconnected") + def _on_room_disconnected(*args): + with contextlib.suppress(asyncio.InvalidStateError): + self._shutdown_fut.set_result( + _ShutdownInfo(user_initiated=False, reason="room disconnected") + ) + + def _on_ctx_connect() -> None: + self._ctx_connect_called = True + + def _on_ctx_shutdown(reason: str) -> None: + self._ctx_shutdown_called = True + + with contextlib.suppress(asyncio.InvalidStateError): + self._shutdown_fut.set_result( + _ShutdownInfo(user_initiated=True, reason=reason) + ) + + self._room._info.name = msg.running_job.job.room.name + + self._job_ctx = JobContext( + proc=self._job_proc, + info=msg.running_job, + room=self._room, + on_connect=_on_ctx_connect, + on_shutdown=_on_ctx_shutdown, + inference_executor=self._inf_client, + ) + + self._job_task = asyncio.create_task(self._run_job_task(), name="job_task") + + def _exit_proc_cb(_: asyncio.Task) -> None: + self._exit_proc_flag.set() + + self._job_task.add_done_callback(_exit_proc_cb) + + async def _run_job_task(self) -> None: + http_context._new_session_ctx() + job_ctx_token = _JobContextVar.set(self._job_ctx) + + job_entry_task = asyncio.create_task( + self._job_entrypoint_fnc(self._job_ctx), name="job_user_entrypoint" + ) + + async def _warn_not_connected_task(): + await asyncio.sleep(10) + if not self._ctx_connect_called and not self._ctx_shutdown_called: + logger.warning( + ( + "The room connection was not established within 10 seconds after calling job_entry. " + "This may indicate that job_ctx.connect() was not called. 
" + ) + ) + + warn_unconnected_task = asyncio.create_task(_warn_not_connected_task()) + job_entry_task.add_done_callback(lambda _: warn_unconnected_task.cancel()) + + def log_exception(t: asyncio.Task) -> None: + if not t.cancelled() and t.exception(): + logger.error( + "unhandled exception while running the job task", + exc_info=t.exception(), + ) + elif not self._ctx_connect_called and not self._ctx_shutdown_called: + logger.warning( + ( + "The job task completed without establishing a connection or performing a proper shutdown. " + "Ensure that job_ctx.connect()/job_ctx.shutdown() is called and the job is correctly finalized." + ) + ) + + job_entry_task.add_done_callback(log_exception) + + shutdown_info = await self._shutdown_fut + logger.debug( + "shutting down job task", + extra={ + "reason": shutdown_info.reason, + "user_initiated": shutdown_info.user_initiated, + }, + ) + + await self._client.send(Exiting(reason=shutdown_info.reason)) + await self._room.disconnect() + + try: + shutdown_tasks = [] + for callback in self._job_ctx._shutdown_callbacks: + shutdown_tasks.append( + asyncio.create_task(callback(), name="job_shutdown_callback") + ) + + await asyncio.gather(*shutdown_tasks) + except Exception: + logger.exception("error while shutting down the job") + + await http_context._close_http_ctx() + _JobContextVar.reset(job_ctx_token) + + +@dataclass +class ThreadStartArgs: + initialize_process_fnc: Callable[[JobProcess], Any] + job_entrypoint_fnc: Callable[[JobContext], Any] + join_fnc: Callable[[], None] + mp_cch: socket.socket + user_arguments: Any | None + + +def thread_main( + args: ThreadStartArgs, +) -> None: + """main function for the job process when using the ThreadedJobRunner""" + tid = threading.get_native_id() + + try: + from .proc_client import _ProcClient + + job_proc = _JobProc( + args.initialize_process_fnc, args.job_entrypoint_fnc, args.user_arguments + ) + + client = _ProcClient( + args.mp_cch, + None, + job_proc.initialize, + job_proc.entrypoint, + ) + + logger.info("initializing job runner", extra={"tid": tid}) + client.initialize() + logger.info("job runner initialized", extra={"tid": tid}) + + client.run() + finally: + args.join_fnc() diff --git a/livekit-agents/livekit/agents/ipc/thread_job_executor.py b/livekit-agents/livekit/agents/ipc/job_thread_executor.py similarity index 70% rename from livekit-agents/livekit/agents/ipc/thread_job_executor.py rename to livekit-agents/livekit/agents/ipc/job_thread_executor.py index b6908669d..6705422ab 100644 --- a/livekit-agents/livekit/agents/ipc/thread_job_executor.py +++ b/livekit-agents/livekit/agents/ipc/job_thread_executor.py @@ -11,12 +11,9 @@ from ..job import JobContext, JobProcess, RunningJobInfo from ..log import logger from ..utils.aio import duplex_unix -from . import channel, job_main, proto -from .job_executor import ( - JobExecutorError_ShutdownTimeout, - JobExecutorError_Unresponsive, - RunStatus, -) +from . 
import channel, job_proc_lazy_main, proto +from .inference_executor import InferenceExecutor +from .job_executor import JobStatus @dataclass @@ -25,6 +22,8 @@ class _ProcOpts: job_entrypoint_fnc: Callable[[JobContext], Awaitable[None]] initialize_timeout: float close_timeout: float + ping_interval: float + high_ping_threshold: float class ThreadJobExecutor: @@ -33,8 +32,11 @@ def __init__( *, initialize_process_fnc: Callable[[JobProcess], Any], job_entrypoint_fnc: Callable[[JobContext], Awaitable[None]], + inference_executor: InferenceExecutor | None, initialize_timeout: float, close_timeout: float, + ping_interval: float, + high_ping_threshold: float, loop: asyncio.AbstractEventLoop, ) -> None: self._loop = loop @@ -43,57 +45,45 @@ def __init__( job_entrypoint_fnc=job_entrypoint_fnc, initialize_timeout=initialize_timeout, close_timeout=close_timeout, + ping_interval=ping_interval, + high_ping_threshold=high_ping_threshold, ) self._user_args: Any | None = None + self._job_status: JobStatus | None = None self._running_job: RunningJobInfo | None = None - self._exception: Exception | None = None self._main_atask: asyncio.Task[None] | None = None - self._closing = False self._initialize_fut = asyncio.Future[None]() - + self._closing = False self._lock = asyncio.Lock() + self._inference_executor = inference_executor + self._inference_tasks: list[asyncio.Task[None]] = [] + + @property + def status(self) -> JobStatus: + if self._job_status is None: + raise RuntimeError("job status not available") + + return self._job_status + @property def started(self) -> bool: return self._main_atask is not None @property - def start_arguments(self) -> Any | None: + def user_arguments(self) -> Any | None: return self._user_args - @start_arguments.setter - def start_arguments(self, value: Any | None) -> None: + @user_arguments.setter + def user_arguments(self, value: Any | None) -> None: self._user_args = value @property def running_job(self) -> RunningJobInfo | None: return self._running_job - @property - def exception(self) -> Exception | None: - return self._exception - - @property - def run_status(self) -> RunStatus: - if not self._running_job: - if self.started: - return RunStatus.WAITING_FOR_JOB - else: - return RunStatus.STARTING - - if not self._main_atask: - return RunStatus.STARTING - - if self._main_atask.done(): - if self.exception: - return RunStatus.FINISHED_FAILED - else: - return RunStatus.FINISHED_CLEAN - else: - return RunStatus.RUNNING_JOB - async def start(self) -> None: if self.started: raise RuntimeError("runner already started") @@ -116,17 +106,16 @@ def _on_join() -> None: with contextlib.suppress(RuntimeError): self._loop.call_soon_threadsafe(self._join_fut.set_result, None) - targs = job_main.ThreadStartArgs( + targs = job_proc_lazy_main.ThreadStartArgs( mp_cch=mp_cch, initialize_process_fnc=self._opts.initialize_process_fnc, job_entrypoint_fnc=self._opts.job_entrypoint_fnc, user_arguments=self._user_args, - asyncio_debug=self._loop.get_debug(), join_fnc=_on_join, ) self._thread = t = threading.Thread( - target=job_main.thread_main, + target=job_proc_lazy_main.thread_main, args=(targs,), name="job_thread_runner", ) @@ -187,7 +176,6 @@ async def aclose(self) -> None: asyncio.shield(self._main_atask), timeout=self._opts.close_timeout ) except asyncio.TimeoutError: - self._exception = JobExecutorError_ShutdownTimeout() logger.error( "job shutdown is taking too much time..", extra=self.logging_extra() ) @@ -196,12 +184,42 @@ async def aclose(self) -> None: if self._main_atask: await 
asyncio.shield(self._main_atask) + async def _do_inference_task(self, inf_req: proto.InferenceRequest) -> None: + if self._inference_executor is None: + logger.warning("inference request received but no inference executor") + await channel.asend_message( + self._pch, + proto.InferenceResponse( + request_id=inf_req.request_id, error="no inference executor" + ), + ) + return + + try: + inf_res = await self._inference_executor.do_inference( + inf_req.method, inf_req.data + ) + await channel.asend_message( + self._pch, + proto.InferenceResponse(request_id=inf_req.request_id, data=inf_res), + ) + except Exception as e: + await channel.asend_message( + self._pch, + proto.InferenceResponse(request_id=inf_req.request_id, error=str(e)), + ) + async def launch_job(self, info: RunningJobInfo) -> None: """start/assign a job to the executor""" if self._running_job is not None: raise RuntimeError("executor already has a running job") + if not self._initialize_fut.done(): + raise RuntimeError("executor not initialized") + self._running_job = info + self._job_status = JobStatus.RUNNING + start_req = proto.StartJobRequest() start_req.running_job = info await channel.asend_message(self._pch, start_req) @@ -215,18 +233,20 @@ async def _main_task(self) -> None: except Exception: pass # initialization failed - pong_timeout = utils.aio.sleep(proto.PING_TIMEOUT) - ping_task = asyncio.create_task(self._ping_pong_task(pong_timeout)) - monitor_task = asyncio.create_task(self._monitor_task(pong_timeout)) + ping_task = asyncio.create_task(self._ping_task()) + monitor_task = asyncio.create_task(self._monitor_task()) await self._join_fut await utils.aio.gracefully_cancel(ping_task, monitor_task) + await utils.aio.gracefully_cancel(*self._inference_tasks) with contextlib.suppress(duplex_unix.DuplexClosed): await self._pch.aclose() + self._job_status = JobStatus.SUCCESS + @utils.log_exceptions(logger=logger) - async def _monitor_task(self, pong_timeout: utils.aio.Sleep) -> None: + async def _monitor_task(self) -> None: while True: try: msg = await channel.arecv_message(self._pch, proto.IPC_MESSAGES) @@ -235,47 +255,33 @@ async def _monitor_task(self, pong_timeout: utils.aio.Sleep) -> None: if isinstance(msg, proto.PongResponse): delay = utils.time_ms() - msg.timestamp - if delay > proto.HIGH_PING_THRESHOLD * 1000: + if delay > self._opts.high_ping_threshold * 1000: logger.warning( "job executor is unresponsive", extra={"delay": delay, **self.logging_extra()}, ) - with contextlib.suppress(utils.aio.SleepFinished): - pong_timeout.reset() - if isinstance(msg, proto.Exiting): logger.debug( "job exiting", extra={"reason": msg.reason, **self.logging_extra()} ) + if isinstance(msg, proto.InferenceRequest): + self._inference_tasks.append( + asyncio.create_task(self._do_inference_task(msg)) + ) + @utils.log_exceptions(logger=logger) - async def _ping_pong_task(self, pong_timeout: utils.aio.Sleep) -> None: - ping_interval = utils.aio.interval(proto.PING_INTERVAL) - - async def _send_ping_co(): - while True: - await ping_interval.tick() - try: - await channel.asend_message( - self._pch, proto.PingRequest(timestamp=utils.time_ms()) - ) - except utils.aio.duplex_unix.DuplexClosed: - break - - async def _pong_timeout_co(): - await pong_timeout - self._exception = JobExecutorError_Unresponsive() - logger.error("job is unresponsive..", extra=self.logging_extra()) - - tasks = [ - asyncio.create_task(_send_ping_co()), - asyncio.create_task(_pong_timeout_co()), - ] - try: - await asyncio.gather(*tasks) - finally: - await 
utils.aio.gracefully_cancel(*tasks) + async def _ping_task(self) -> None: + ping_interval = utils.aio.interval(self._opts.ping_interval) + while True: + await ping_interval.tick() + try: + await channel.asend_message( + self._pch, proto.PingRequest(timestamp=utils.time_ms()) + ) + except utils.aio.duplex_unix.DuplexClosed: + break def logging_extra(self): extra: dict[str, Any] = { diff --git a/livekit-agents/livekit/agents/ipc/log_queue.py b/livekit-agents/livekit/agents/ipc/log_queue.py new file mode 100644 index 000000000..38115cff1 --- /dev/null +++ b/livekit-agents/livekit/agents/ipc/log_queue.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import copy +import logging +import pickle +import queue +import sys +import threading +from typing import Callable, Optional + +from .. import utils +from ..utils.aio import duplex_unix + + +class LogQueueListener: + def __init__( + self, + duplex: utils.aio.duplex_unix._Duplex, + prepare_fnc: Callable[[logging.LogRecord], None], + ): + self._thread: threading.Thread | None = None + self._duplex = duplex + self._prepare_fnc = prepare_fnc + + def start(self) -> None: + self._thread = threading.Thread(target=self._monitor, name="ipc_log_listener") + self._thread.start() + + def stop(self) -> None: + if self._thread is None: + return + + self._duplex.close() + self._thread.join() + self._thread = None + + def handle(self, record: logging.LogRecord) -> None: + self._prepare_fnc(record) + + lger = logging.getLogger(record.name) + if not lger.isEnabledFor(record.levelno): + return + + lger.callHandlers(record) + + def _monitor(self): + while True: + try: + data = self._duplex.recv_bytes() + except utils.aio.duplex_unix.DuplexClosed: + break + + record = pickle.loads(data) + self.handle(record) + + +class LogQueueHandler(logging.Handler): + _sentinal = None + + def __init__(self, duplex: utils.aio.duplex_unix._Duplex) -> None: + super().__init__() + self._duplex = duplex + self._send_q = queue.SimpleQueue[Optional[bytes]]() + self._send_thread = threading.Thread( + target=self._forward_logs, name="ipc_log_forwarder" + ) + self._send_thread.start() + + def _forward_logs(self): + while True: + serialized_record = self._send_q.get() + if serialized_record is None: + break + + try: + self._duplex.send_bytes(serialized_record) + except duplex_unix.DuplexClosed: + break + + self._duplex.close() + + def emit(self, record: logging.LogRecord) -> None: + try: + # Check if Python is shutting down + if sys.is_finalizing(): + return + + # from https://github.com/python/cpython/blob/91b7f2e7f6593acefda4fa860250dd87d6f849bf/Lib/logging/handlers.py#L1453 + msg = self.format(record) + record = copy.copy(record) + record.message = msg + record.msg = msg + record.args = None + record.exc_info = None + record.exc_text = None + record.stack_info = None + + # https://websockets.readthedocs.io/en/stable/topics/logging.html#logging-to-json + # webosckets library add "websocket" attribute to log records, which is not pickleable + if hasattr(record, "websocket"): + record.websocket = None + + self._send_q.put_nowait(pickle.dumps(record)) + + except Exception: + self.handleError(record) + + def close(self) -> None: + super().close() + self._send_q.put_nowait(self._sentinal) diff --git a/livekit-agents/livekit/agents/ipc/proc_client.py b/livekit-agents/livekit/agents/ipc/proc_client.py new file mode 100644 index 000000000..76b77fb88 --- /dev/null +++ b/livekit-agents/livekit/agents/ipc/proc_client.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +import asyncio 
+import contextlib +import logging +import socket +import sys +from typing import Callable, Coroutine + +from ..log import logger +from ..utils import aio, log_exceptions, time_ms +from .channel import Message, arecv_message, asend_message, recv_message, send_message +from .log_queue import LogQueueHandler +from .proto import ( + IPC_MESSAGES, + InitializeRequest, + InitializeResponse, + PingRequest, + PongResponse, +) + + +class _ProcClient: + def __init__( + self, + mp_cch: socket.socket, + log_cch: socket.socket | None, + initialize_fnc: Callable[[InitializeRequest, "_ProcClient"], None], + main_task_fnc: Callable[ + [aio.ChanReceiver[Message]], Coroutine[None, None, None] + ], + ) -> None: + self._mp_cch = mp_cch + self._log_cch = log_cch + self._initialize_fnc = initialize_fnc + self._main_task_fnc = main_task_fnc + self._initialized = False + self._log_handler: LogQueueHandler | None = None + + def initialize_logger(self) -> None: + if self._log_cch is None: + raise RuntimeError("cannot initialize logger without log channel") + + root_logger = logging.getLogger() + root_logger.setLevel(logging.NOTSET) + + log_cch = aio.duplex_unix._Duplex.open(self._log_cch) + self._log_handler = LogQueueHandler(log_cch) + root_logger.addHandler(self._log_handler) + + def initialize(self) -> None: + try: + cch = aio.duplex_unix._Duplex.open(self._mp_cch) + first_req = recv_message(cch, IPC_MESSAGES) + + assert isinstance( + first_req, InitializeRequest + ), "first message must be proto.InitializeRequest" + + self._init_req = first_req + self._initialize_fnc(self._init_req, self) + send_message(cch, InitializeResponse()) + self._initialized = True + cch.detach() + except aio.duplex_unix.DuplexClosed as e: + raise RuntimeError("failed to initialize proc_client") from e + + def run(self) -> None: + if not self._initialized: + raise RuntimeError("proc_client not initialized") + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.set_debug(self._init_req.asyncio_debug) + loop.slow_callback_duration = 0.1 # 100ms + aio.debug.hook_slow_callbacks(2.0) + + try: + self._task = loop.create_task(self._monitor_task(), name="proc_client_main") + while not self._task.done(): + try: + loop.run_until_complete(self._task) + except KeyboardInterrupt: + # ignore the keyboard interrupt, we handle the process shutdown ourselves on the worker process + # (See proto.ShutdownRequest) + pass + except KeyboardInterrupt: + pass + finally: + if self._log_handler is not None: + self._log_handler.close() + + loop.run_until_complete(loop.shutdown_default_executor()) + + async def send(self, msg: Message) -> None: + await asend_message(self._acch, msg) + + async def _monitor_task(self) -> None: + self._acch = await aio.duplex_unix._AsyncDuplex.open(self._mp_cch) + try: + exit_flag = asyncio.Event() + ping_timeout = aio.sleep(self._init_req.ping_timeout) + + ipc_ch = aio.Chan[Message]() + + @log_exceptions(logger=logger) + async def _read_ipc_task(): + while True: + try: + msg = await arecv_message(self._acch, IPC_MESSAGES) + except aio.duplex_unix.DuplexClosed: + break + + with contextlib.suppress(aio.SleepFinished): + ping_timeout.reset() + + if isinstance(msg, PingRequest): + await asend_message( + self._acch, + PongResponse( + last_timestamp=msg.timestamp, timestamp=time_ms() + ), + ) + + ipc_ch.send_nowait(msg) + + @log_exceptions(logger=logger) + async def _self_health_check(): + await ping_timeout + print( + "worker process is not responding.. 
worker crashed?", + file=sys.stderr, + ) + + read_task = asyncio.create_task(_read_ipc_task(), name="ipc_read") + health_check_task: asyncio.Task | None = None + if self._init_req.ping_interval > 0: + health_check_task = asyncio.create_task( + _self_health_check(), name="health_check" + ) + main_task = asyncio.create_task( + self._main_task_fnc(ipc_ch), name="main_task_entrypoint" + ) + + def _done_cb(_: asyncio.Task) -> None: + with contextlib.suppress(asyncio.InvalidStateError): + exit_flag.set() + + ipc_ch.close() + + read_task.add_done_callback(_done_cb) + if health_check_task is not None: + health_check_task.add_done_callback(_done_cb) + + main_task.add_done_callback(_done_cb) + + await exit_flag.wait() + await aio.gracefully_cancel(read_task, main_task) + if health_check_task is not None: + await aio.gracefully_cancel(health_check_task) + finally: + await self._acch.aclose() diff --git a/livekit-agents/livekit/agents/ipc/proc_job_executor.py b/livekit-agents/livekit/agents/ipc/proc_job_executor.py deleted file mode 100644 index 2a956d947..000000000 --- a/livekit-agents/livekit/agents/ipc/proc_job_executor.py +++ /dev/null @@ -1,413 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import logging -import pickle -import socket -import sys -import threading -from dataclasses import dataclass -from multiprocessing.context import BaseContext -from typing import Any, Awaitable, Callable - -from .. import utils -from ..job import JobContext, JobProcess, RunningJobInfo -from ..log import logger -from ..utils.aio import duplex_unix -from . import channel, job_main, proc_lazy_main, proto -from .job_executor import ( - JobExecutorError_Runtime, - JobExecutorError_ShutdownTimeout, - JobExecutorError_Unresponsive, - RunStatus, -) - - -class LogQueueListener: - def __init__( - self, - duplex: utils.aio.duplex_unix._Duplex, - prepare_fnc: Callable[[logging.LogRecord], None], - ): - self._thread: threading.Thread | None = None - self._duplex = duplex - self._prepare_fnc = prepare_fnc - - def start(self) -> None: - self._thread = threading.Thread(target=self._monitor, name="ipc_log_listener") - self._thread.start() - - def stop(self) -> None: - if self._thread is None: - return - - self._duplex.close() - self._thread.join() - self._thread = None - - def handle(self, record: logging.LogRecord) -> None: - self._prepare_fnc(record) - - lger = logging.getLogger(record.name) - if not lger.isEnabledFor(record.levelno): - return - - lger.callHandlers(record) - - def _monitor(self): - while True: - try: - data = self._duplex.recv_bytes() - except utils.aio.duplex_unix.DuplexClosed: - break - - record = pickle.loads(data) - self.handle(record) - - -@dataclass -class _ProcOpts: - initialize_process_fnc: Callable[[JobProcess], Any] - job_entrypoint_fnc: Callable[[JobContext], Awaitable[None]] - mp_ctx: BaseContext - initialize_timeout: float - close_timeout: float - - -class ProcJobExecutor: - def __init__( - self, - *, - initialize_process_fnc: Callable[[JobProcess], Any], - job_entrypoint_fnc: Callable[[JobContext], Awaitable[None]], - initialize_timeout: float, - close_timeout: float, - mp_ctx: BaseContext, - loop: asyncio.AbstractEventLoop, - ) -> None: - self._loop = loop - self._opts = _ProcOpts( - initialize_process_fnc=initialize_process_fnc, - job_entrypoint_fnc=job_entrypoint_fnc, - initialize_timeout=initialize_timeout, - close_timeout=close_timeout, - mp_ctx=mp_ctx, - ) - - self._user_args: Any | None = None - self._running_job: RunningJobInfo | None = None - 
self._exitcode: int | None = None - self._pid: int | None = None - self._exception: Exception | None = None - - self._main_atask: asyncio.Task[None] | None = None - self._closing = False - self._kill_sent = False - self._initialize_fut = asyncio.Future[None]() - - self._lock = asyncio.Lock() - - @property - def exitcode(self) -> int | None: - return self._exitcode - - @property - def killed(self) -> bool: - return self._kill_sent - - @property - def pid(self) -> int | None: - return self._pid - - @property - def started(self) -> bool: - return self._main_atask is not None - - @property - def start_arguments(self) -> Any | None: - return self._user_args - - @start_arguments.setter - def start_arguments(self, value: Any | None) -> None: - self._user_args = value - - @property - def running_job(self) -> RunningJobInfo | None: - return self._running_job - - @property - def exception(self) -> Exception | None: - return self._exception - - @property - def run_status(self) -> RunStatus: - if not self._running_job: - if self.started: - return RunStatus.WAITING_FOR_JOB - else: - return RunStatus.STARTING - - if not self._main_atask: - return RunStatus.STARTING - - if self._main_atask.done(): - if self.exception: - return RunStatus.FINISHED_FAILED - else: - return RunStatus.FINISHED_CLEAN - else: - return RunStatus.RUNNING_JOB - - async def start(self) -> None: - """start the job process""" - if self.started: - raise RuntimeError("process already started") - - if self._closing: - raise RuntimeError("process is closed") - - await asyncio.shield(self._start()) - - async def _start(self) -> None: - def _add_proc_ctx_log(record: logging.LogRecord) -> None: - extra = self.logging_extra() - for key, value in extra.items(): - setattr(record, key, value) - - async with self._lock: - mp_pch, mp_cch = socket.socketpair() - mp_log_pch, mp_log_cch = socket.socketpair() - - self._pch = await duplex_unix._AsyncDuplex.open(mp_pch) - - log_pch = duplex_unix._Duplex.open(mp_log_pch) - log_listener = LogQueueListener(log_pch, _add_proc_ctx_log) - log_listener.start() - - self._proc_args = job_main.ProcStartArgs( - initialize_process_fnc=self._opts.initialize_process_fnc, - job_entrypoint_fnc=self._opts.job_entrypoint_fnc, - log_cch=mp_log_cch, - mp_cch=mp_cch, - asyncio_debug=self._loop.get_debug(), - user_arguments=self._user_args, - ) - - self._proc = self._opts.mp_ctx.Process( # type: ignore - target=proc_lazy_main.proc_main, - args=(self._proc_args,), - name="job_proc", - ) - - self._proc.start() - mp_log_cch.close() - mp_cch.close() - - self._pid = self._proc.pid - self._join_fut = asyncio.Future[None]() - - def _sync_run(): - self._proc.join() - log_listener.stop() - try: - self._loop.call_soon_threadsafe(self._join_fut.set_result, None) - except RuntimeError: - pass - - thread = threading.Thread(target=_sync_run, name="proc_join_thread") - thread.start() - self._main_atask = asyncio.create_task(self._main_task()) - - async def join(self) -> None: - """wait for the job process to finish""" - if not self.started: - raise RuntimeError("process not started") - - async with self._lock: - if self._main_atask: - await asyncio.shield(self._main_atask) - - async def initialize(self) -> None: - """initialize the job process, this is calling the user provided initialize_process_fnc - raise asyncio.TimeoutError if initialization times out""" - await channel.asend_message(self._pch, proto.InitializeRequest()) - - # wait for the process to become ready - try: - init_res = await asyncio.wait_for( - 
channel.arecv_message(self._pch, proto.IPC_MESSAGES), - timeout=self._opts.initialize_timeout, - ) - assert isinstance( - init_res, proto.InitializeResponse - ), "first message must be InitializeResponse" - except asyncio.TimeoutError: - self._initialize_fut.set_exception( - asyncio.TimeoutError("process initialization timed out") - ) - logger.error( - "initialization timed out, killing job", extra=self.logging_extra() - ) - self._send_kill_signal() - raise - except Exception as e: # should be channel.ChannelClosed most of the time - self._exception = JobExecutorError_Runtime() - self._initialize_fut.set_exception(e) - raise - else: - self._initialize_fut.set_result(None) - - async def aclose(self) -> None: - """attempt to gracefully close the job process""" - if not self.started: - return - - self._closing = True - with contextlib.suppress(utils.aio.duplex_unix.DuplexClosed): - await channel.asend_message(self._pch, proto.ShutdownRequest()) - - try: - if self._main_atask: - await asyncio.wait_for( - asyncio.shield(self._main_atask), timeout=self._opts.close_timeout - ) - except asyncio.TimeoutError: - logger.error( - "process did not exit in time, killing job", extra=self.logging_extra() - ) - self._exception = JobExecutorError_ShutdownTimeout() - self._send_kill_signal() - - async with self._lock: - if self._main_atask: - await asyncio.shield(self._main_atask) - - async def kill(self) -> None: - """forcefully kill the job process""" - if not self.started: - raise RuntimeError("process not started") - - self._closing = True - self._send_kill_signal() - - async with self._lock: - if self._main_atask: - await asyncio.shield(self._main_atask) - - async def launch_job(self, info: RunningJobInfo) -> None: - """start/assign a job to the process""" - if self._running_job is not None: - raise RuntimeError("process already has a running job") - - self._running_job = info - start_req = proto.StartJobRequest() - start_req.running_job = info - await channel.asend_message(self._pch, start_req) - - def _send_kill_signal(self) -> None: - """forcefully kill the job process""" - try: - if not self._proc.is_alive(): - return - except ValueError: - return - - logger.info("killing job process", extra=self.logging_extra()) - if sys.platform == "win32": - self._proc.terminate() - else: - self._proc.kill() - - self._kill_sent = True - - @utils.log_exceptions(logger=logger) - async def _main_task(self) -> None: - try: - await self._initialize_fut - except asyncio.TimeoutError: - pass # this happens when the initialization takes longer than self._initialize_timeout - except Exception: - pass # initialization failed - - # the process is killed if it doesn't respond to ping requests - pong_timeout = utils.aio.sleep(proto.PING_TIMEOUT) - ping_task = asyncio.create_task(self._ping_pong_task(pong_timeout)) - monitor_task = asyncio.create_task(self._monitor_task(pong_timeout)) - - await self._join_fut - self._exitcode = self._proc.exitcode - self._proc.close() - await utils.aio.gracefully_cancel(ping_task, monitor_task) - - with contextlib.suppress(duplex_unix.DuplexClosed): - await self._pch.aclose() - - if self._exitcode != 0 and not self._kill_sent: - self._exception = JobExecutorError_Runtime() - logger.error( - f"job process exited with non-zero exit code {self.exitcode}", - extra=self.logging_extra(), - ) - - @utils.log_exceptions(logger=logger) - async def _monitor_task(self, pong_timeout: utils.aio.Sleep) -> None: - while True: - try: - msg = await channel.arecv_message(self._pch, proto.IPC_MESSAGES) - except 
utils.aio.duplex_unix.DuplexClosed: - break - - if isinstance(msg, proto.PongResponse): - delay = utils.time_ms() - msg.timestamp - if delay > proto.HIGH_PING_THRESHOLD * 1000: - logger.warning( - "job process is unresponsive", - extra={"delay": delay, **self.logging_extra()}, - ) - - with contextlib.suppress(utils.aio.SleepFinished): - pong_timeout.reset() - - if isinstance(msg, proto.Exiting): - logger.info( - "job exiting", extra={"reason": msg.reason, **self.logging_extra()} - ) - - @utils.log_exceptions(logger=logger) - async def _ping_pong_task(self, pong_timeout: utils.aio.Sleep) -> None: - ping_interval = utils.aio.interval(proto.PING_INTERVAL) - - async def _send_ping_co(): - while True: - await ping_interval.tick() - try: - await channel.asend_message( - self._pch, proto.PingRequest(timestamp=utils.time_ms()) - ) - except utils.aio.duplex_unix.DuplexClosed: - break - - async def _pong_timeout_co(): - await pong_timeout - logger.error("job is unresponsive, killing job", extra=self.logging_extra()) - self._exception = JobExecutorError_Unresponsive() - self._send_kill_signal() - - tasks = [ - asyncio.create_task(_send_ping_co()), - asyncio.create_task(_pong_timeout_co()), - ] - try: - await asyncio.gather(*tasks) - finally: - await utils.aio.gracefully_cancel(*tasks) - - def logging_extra(self): - extra: dict[str, Any] = { - "pid": self.pid, - } - if self._running_job: - extra["job_id"] = self._running_job.job.id - - return extra diff --git a/livekit-agents/livekit/agents/ipc/proc_lazy_main.py b/livekit-agents/livekit/agents/ipc/proc_lazy_main.py deleted file mode 100644 index be09e7f5a..000000000 --- a/livekit-agents/livekit/agents/ipc/proc_lazy_main.py +++ /dev/null @@ -1,72 +0,0 @@ -import multiprocessing - -if multiprocessing.current_process().name == "job_proc": - import signal - import sys - - # ignore signals in the jobs process (the parent process will handle them) - signal.signal(signal.SIGINT, signal.SIG_IGN) - signal.signal(signal.SIGTERM, signal.SIG_IGN) - - def _no_traceback_excepthook(exc_type, exc_val, traceback): - if isinstance(exc_val, KeyboardInterrupt): - return - sys.__excepthook__(exc_type, exc_val, traceback) - - sys.excepthook = _no_traceback_excepthook - - -def proc_main(args) -> None: - """main function for the job process when using the ProcessJobRunner""" - - # import every package lazily - import asyncio - import logging - - from .. import utils - from ..job import JobProcess - from ..log import logger - from . 
import channel, job_main, proto - - root_logger = logging.getLogger() - root_logger.setLevel(logging.NOTSET) - - log_cch = utils.aio.duplex_unix._Duplex.open(args.log_cch) - log_handler = job_main.LogQueueHandler(log_cch) - root_logger.addHandler(log_handler) - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.set_debug(args.asyncio_debug) - loop.slow_callback_duration = 0.1 # 100ms - utils.aio.debug.hook_slow_callbacks(2.0) - - cch = utils.aio.duplex_unix._Duplex.open(args.mp_cch) - try: - init_req = channel.recv_message(cch, proto.IPC_MESSAGES) - - assert isinstance( - init_req, proto.InitializeRequest - ), "first message must be InitializeRequest" - - job_proc = JobProcess(start_arguments=args.user_arguments) - logger.info("initializing process", extra={"pid": job_proc.pid}) - args.initialize_process_fnc(job_proc) - logger.info("process initialized", extra={"pid": job_proc.pid}) - channel.send_message(cch, proto.InitializeResponse()) - - main_task = loop.create_task( - job_main._async_main(job_proc, args.job_entrypoint_fnc, cch.detach()), - name="job_proc_main", - ) - while not main_task.done(): - try: - loop.run_until_complete(main_task) - except KeyboardInterrupt: - # ignore the keyboard interrupt, we handle the process shutdown ourselves on the worker process - pass - except (utils.aio.duplex_unix.DuplexClosed, KeyboardInterrupt): - pass - finally: - log_handler.close() - loop.run_until_complete(loop.shutdown_default_executor()) diff --git a/livekit-agents/livekit/agents/ipc/proc_pool.py b/livekit-agents/livekit/agents/ipc/proc_pool.py index d707987ab..25a395f53 100644 --- a/livekit-agents/livekit/agents/ipc/proc_pool.py +++ b/livekit-agents/livekit/agents/ipc/proc_pool.py @@ -8,7 +8,7 @@ from ..job import JobContext, JobExecutorType, JobProcess, RunningJobInfo from ..log import logger from ..utils import aio -from . import proc_job_executor, thread_job_executor +from . 
import inference_executor, job_proc_executor, job_thread_executor from .job_executor import JobExecutor EventTypes = Literal[ @@ -31,8 +31,11 @@ def __init__( num_idle_processes: int, initialize_timeout: float, close_timeout: float, + inference_executor: inference_executor.InferenceExecutor | None, job_executor_type: JobExecutorType, mp_ctx: BaseContext, + memory_warn_mb: float, + memory_limit_mb: float, loop: asyncio.AbstractEventLoop, ) -> None: super().__init__() @@ -41,9 +44,11 @@ def __init__( self._initialize_process_fnc = initialize_process_fnc self._job_entrypoint_fnc = job_entrypoint_fnc self._close_timeout = close_timeout + self._inf_executor = inference_executor self._initialize_timeout = initialize_timeout self._loop = loop - + self._memory_limit_mb = memory_limit_mb + self._memory_warn_mb = memory_warn_mb self._num_idle_processes = num_idle_processes self._init_sem = asyncio.Semaphore(MAX_CONCURRENT_INITIALIZATIONS) self._proc_needed_sem = asyncio.Semaphore(num_idle_processes) @@ -95,21 +100,30 @@ async def launch_job(self, info: RunningJobInfo) -> None: async def _proc_watch_task(self) -> None: proc: JobExecutor if self._job_executor_type == JobExecutorType.THREAD: - proc = thread_job_executor.ThreadJobExecutor( + proc = job_thread_executor.ThreadJobExecutor( initialize_process_fnc=self._initialize_process_fnc, job_entrypoint_fnc=self._job_entrypoint_fnc, initialize_timeout=self._initialize_timeout, close_timeout=self._close_timeout, + inference_executor=self._inf_executor, + ping_interval=2.5, + high_ping_threshold=0.5, loop=self._loop, ) elif self._job_executor_type == JobExecutorType.PROCESS: - proc = proc_job_executor.ProcJobExecutor( + proc = job_proc_executor.ProcJobExecutor( initialize_process_fnc=self._initialize_process_fnc, job_entrypoint_fnc=self._job_entrypoint_fnc, initialize_timeout=self._initialize_timeout, close_timeout=self._close_timeout, + inference_executor=self._inf_executor, mp_ctx=self._mp_ctx, loop=self._loop, + ping_interval=2.5, + ping_timeout=60, + high_ping_threshold=0.5, + memory_warn_mb=self._memory_warn_mb, + memory_limit_mb=self._memory_limit_mb, ) else: raise ValueError(f"unsupported job executor: {self._job_executor_type}") diff --git a/livekit-agents/livekit/agents/ipc/proto.py b/livekit-agents/livekit/agents/ipc/proto.py index 7dd7c29e3..509964b55 100644 --- a/livekit-agents/livekit/agents/ipc/proto.py +++ b/livekit-agents/livekit/agents/ipc/proto.py @@ -9,11 +9,6 @@ from ..job import JobAcceptArguments, RunningJobInfo from . 
import channel -PING_INTERVAL = 2.5 -PING_TIMEOUT = 90 -HIGH_PING_THRESHOLD = 0.5 -NO_MESSAGE_TIMEOUT = 15.0 - @dataclass class InitializeRequest: @@ -21,6 +16,25 @@ class InitializeRequest: MSG_ID: ClassVar[int] = 0 + asyncio_debug: bool = False + ping_interval: float = 0 + ping_timeout: float = 0 # if no response, process is considered dead + high_ping_threshold: float = ( + 0 # if ping is higher than this, process is considered unresponsive + ) + + def write(self, b: io.BytesIO) -> None: + channel.write_bool(b, self.asyncio_debug) + channel.write_float(b, self.ping_interval) + channel.write_float(b, self.ping_timeout) + channel.write_float(b, self.high_ping_threshold) + + def read(self, b: io.BytesIO) -> None: + self.asyncio_debug = channel.read_bool(b) + self.ping_interval = channel.read_float(b) + self.ping_timeout = channel.read_float(b) + self.high_ping_threshold = channel.read_float(b) + @dataclass class InitializeResponse: @@ -76,6 +90,7 @@ def write(self, b: io.BytesIO) -> None: channel.write_string(b, accept_args.metadata) channel.write_string(b, self.running_job.url) channel.write_string(b, self.running_job.token) + channel.write_string(b, self.running_job.worker_id) def read(self, b: io.BytesIO) -> None: job = agent.Job() @@ -89,6 +104,7 @@ def read(self, b: io.BytesIO) -> None: job=job, url=channel.read_string(b), token=channel.read_string(b), + worker_id=channel.read_string(b), ) @@ -121,6 +137,50 @@ def read(self, b: io.BytesIO) -> None: self.reason = channel.read_string(b) +@dataclass +class InferenceRequest: + """sent by a subprocess to the main process to request inference""" + + MSG_ID: ClassVar[int] = 7 + method: str = "" + request_id: str = "" + data: bytes = b"" + + def write(self, b: io.BytesIO) -> None: + channel.write_string(b, self.method) + channel.write_string(b, self.request_id) + channel.write_bytes(b, self.data) + + def read(self, b: io.BytesIO) -> None: + self.method = channel.read_string(b) + self.request_id = channel.read_string(b) + self.data = channel.read_bytes(b) + + +@dataclass +class InferenceResponse: + """response to an InferenceRequest""" + + MSG_ID: ClassVar[int] = 8 + request_id: str = "" + data: bytes | None = None + error: str = "" + + def write(self, b: io.BytesIO) -> None: + channel.write_string(b, self.request_id) + channel.write_bool(b, self.data is not None) + if self.data is not None: + channel.write_bytes(b, self.data) + channel.write_string(b, self.error) + + def read(self, b: io.BytesIO) -> None: + self.request_id = channel.read_string(b) + has_data = channel.read_bool(b) + if has_data: + self.data = channel.read_bytes(b) + self.error = channel.read_string(b) + + IPC_MESSAGES = { InitializeRequest.MSG_ID: InitializeRequest, InitializeResponse.MSG_ID: InitializeResponse, @@ -129,4 +189,6 @@ def read(self, b: io.BytesIO) -> None: StartJobRequest.MSG_ID: StartJobRequest, ShutdownRequest.MSG_ID: ShutdownRequest, Exiting.MSG_ID: Exiting, + InferenceRequest.MSG_ID: InferenceRequest, + InferenceResponse.MSG_ID: InferenceResponse, } diff --git a/livekit-agents/livekit/agents/ipc/supervised_proc.py b/livekit-agents/livekit/agents/ipc/supervised_proc.py new file mode 100644 index 000000000..e56119876 --- /dev/null +++ b/livekit-agents/livekit/agents/ipc/supervised_proc.py @@ -0,0 +1,407 @@ +from __future__ import annotations + +import asyncio +import contextlib +import logging +import multiprocessing as mp +import socket +import sys +import threading +from abc import ABC, abstractmethod +from dataclasses import dataclass +from 
multiprocessing.context import BaseContext +from typing import Any + +import psutil + +from ..log import logger +from ..utils import aio, log_exceptions, time_ms +from ..utils.aio import duplex_unix +from . import channel, proto +from .log_queue import LogQueueListener + + +@dataclass +class _ProcOpts: + initialize_timeout: float + close_timeout: float + memory_warn_mb: float + memory_limit_mb: float + ping_interval: float + ping_timeout: float + high_ping_threshold: float + + +class SupervisedProc(ABC): + def __init__( + self, + *, + initialize_timeout: float, + close_timeout: float, + memory_warn_mb: float, + memory_limit_mb: float, + ping_interval: float, + ping_timeout: float, + high_ping_threshold: float, + mp_ctx: BaseContext, + loop: asyncio.AbstractEventLoop, + ) -> None: + self._loop = loop + self._mp_ctx = mp_ctx + self._opts = _ProcOpts( + initialize_timeout=initialize_timeout, + close_timeout=close_timeout, + memory_warn_mb=memory_warn_mb, + memory_limit_mb=memory_limit_mb, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + high_ping_threshold=high_ping_threshold, + ) + + self._exitcode: int | None = None + self._pid: int | None = None + + self._supervise_atask: asyncio.Task[None] | None = None + self._closing = False + self._kill_sent = False + self._initialize_fut = asyncio.Future[None]() + self._lock = asyncio.Lock() + + @abstractmethod + def _create_process( + self, cch: socket.socket, log_cch: socket.socket + ) -> mp.Process: ... + + @abstractmethod + async def _main_task(self, ipc_ch: aio.ChanReceiver[channel.Message]) -> None: ... + + @property + def exitcode(self) -> int | None: + return self._exitcode + + @property + def killed(self) -> bool: + return self._kill_sent + + @property + def pid(self) -> int | None: + return self._pid + + @property + def started(self) -> bool: + return self._supervise_atask is not None + + async def start(self) -> None: + """start the supervised process""" + if self.started: + raise RuntimeError("process already started") + + if self._closing: + raise RuntimeError("process is closed") + + await asyncio.shield(self._start()) + + async def _start(self) -> None: + def _add_proc_ctx_log(record: logging.LogRecord) -> None: + extra = self.logging_extra() + for key, value in extra.items(): + setattr(record, key, value) + + async with self._lock: + mp_pch, mp_cch = socket.socketpair() + mp_log_pch, mp_log_cch = socket.socketpair() + + self._pch = await duplex_unix._AsyncDuplex.open(mp_pch) + + log_pch = duplex_unix._Duplex.open(mp_log_pch) + log_listener = LogQueueListener(log_pch, _add_proc_ctx_log) + log_listener.start() + + self._proc = self._create_process(mp_cch, mp_log_cch) + self._proc.start() + mp_log_cch.close() + mp_cch.close() + + self._pid = self._proc.pid + self._join_fut = asyncio.Future[None]() + + def _sync_run(): + self._proc.join() + log_listener.stop() + try: + self._loop.call_soon_threadsafe(self._join_fut.set_result, None) + except RuntimeError: + pass + + thread = threading.Thread(target=_sync_run, name="proc_join_thread") + thread.start() + self._supervise_atask = asyncio.create_task(self._supervise_task()) + + async def join(self) -> None: + """wait for the process to finish""" + if not self.started: + raise RuntimeError("process not started") + + async with self._lock: + if self._supervise_atask: + await asyncio.shield(self._supervise_atask) + + async def initialize(self) -> None: + """initialize the process, this is sending a InitializeRequest message and waiting for a + InitializeResponse with a timeout""" + 
await channel.asend_message( + self._pch, + proto.InitializeRequest( + asyncio_debug=self._loop.get_debug(), + ping_interval=self._opts.ping_interval, + ping_timeout=self._opts.ping_timeout, + high_ping_threshold=self._opts.high_ping_threshold, + ), + ) + + # wait for the process to become ready + try: + init_res = await asyncio.wait_for( + channel.arecv_message(self._pch, proto.IPC_MESSAGES), + timeout=self._opts.initialize_timeout, + ) + assert isinstance( + init_res, proto.InitializeResponse + ), "first message must be InitializeResponse" + except asyncio.TimeoutError: + self._initialize_fut.set_exception( + asyncio.TimeoutError("process initialization timed out") + ) + logger.error( + "initialization timed out, killing process", extra=self.logging_extra() + ) + self._send_kill_signal() + raise + except Exception as e: # should be channel.ChannelClosed most of the time + self._initialize_fut.set_exception(e) + raise + else: + self._initialize_fut.set_result(None) + + async def aclose(self) -> None: + """attempt to gracefully close the supervised process""" + if not self.started: + return + + self._closing = True + with contextlib.suppress(duplex_unix.DuplexClosed): + await channel.asend_message(self._pch, proto.ShutdownRequest()) + + try: + if self._supervise_atask: + await asyncio.wait_for( + asyncio.shield(self._supervise_atask), + timeout=self._opts.close_timeout, + ) + except asyncio.TimeoutError: + logger.error( + "process did not exit in time, killing process", + extra=self.logging_extra(), + ) + self._send_kill_signal() + + async with self._lock: + if self._supervise_atask: + await asyncio.shield(self._supervise_atask) + + async def kill(self) -> None: + """forcefully kill the supervised process""" + if not self.started: + raise RuntimeError("process not started") + + self._closing = True + self._send_kill_signal() + + async with self._lock: + if self._supervise_atask: + await asyncio.shield(self._supervise_atask) + + def _send_kill_signal(self) -> None: + """forcefully kill the process""" + try: + if not self._proc.is_alive(): + return + except ValueError: + return + + logger.info("killing process", extra=self.logging_extra()) + if sys.platform == "win32": + self._proc.terminate() + else: + self._proc.kill() + + self._kill_sent = True + + @log_exceptions(logger=logger) + async def _supervise_task(self) -> None: + try: + await self._initialize_fut + except asyncio.TimeoutError: + pass # this happens when the initialization takes longer than self._initialize_timeout + except Exception: + pass # initialization failed + + # the process is killed if it doesn't respond to ping requests + pong_timeout = aio.sleep(self._opts.ping_timeout) + + ipc_ch = aio.Chan[channel.Message]() + + main_task = asyncio.create_task(self._main_task(ipc_ch)) + read_ipc_task = asyncio.create_task(self._read_ipc_task(ipc_ch, pong_timeout)) + ping_task = asyncio.create_task(self._ping_pong_task(pong_timeout)) + read_ipc_task.add_done_callback(lambda _: ipc_ch.close()) + + memory_monitor_task: asyncio.Task[None] | None = None + if self._opts.memory_limit_mb > 0 or self._opts.memory_warn_mb > 0: + memory_monitor_task = asyncio.create_task(self._memory_monitor_task()) + + await self._join_fut + self._exitcode = self._proc.exitcode + self._proc.close() + await aio.gracefully_cancel(ping_task, read_ipc_task, main_task) + + if memory_monitor_task is not None: + await aio.gracefully_cancel(memory_monitor_task) + + with contextlib.suppress(duplex_unix.DuplexClosed): + await self._pch.aclose() + + if self._exitcode != 
0 and not self._kill_sent: + logger.error( + f"process exited with non-zero exit code {self.exitcode}", + extra=self.logging_extra(), + ) + + @log_exceptions(logger=logger) + async def _read_ipc_task( + self, ipc_ch: aio.Chan[channel.Message], pong_timeout: aio.Sleep + ) -> None: + while True: + try: + msg = await channel.arecv_message(self._pch, proto.IPC_MESSAGES) + except duplex_unix.DuplexClosed: + break + + if isinstance(msg, proto.PongResponse): + delay = time_ms() - msg.timestamp + if delay > self._opts.high_ping_threshold * 1000: + logger.warning( + "process is unresponsive", + extra={"delay": delay, **self.logging_extra()}, + ) + + with contextlib.suppress(aio.SleepFinished): + pong_timeout.reset() + + if isinstance(msg, proto.Exiting): + logger.info( + "process exiting", + extra={"reason": msg.reason, **self.logging_extra()}, + ) + + ipc_ch.send_nowait(msg) + + @log_exceptions(logger=logger) + async def _ping_pong_task(self, pong_timeout: aio.Sleep) -> None: + ping_interval = aio.interval(self._opts.ping_interval) + + async def _send_ping_co(): + while True: + await ping_interval.tick() + try: + await channel.asend_message( + self._pch, proto.PingRequest(timestamp=time_ms()) + ) + except duplex_unix.DuplexClosed: + break + + async def _pong_timeout_co(): + await pong_timeout + logger.error( + "process is unresponsive, killing process", extra=self.logging_extra() + ) + self._send_kill_signal() + + tasks = [ + asyncio.create_task(_send_ping_co()), + asyncio.create_task(_pong_timeout_co()), + ] + try: + await asyncio.gather(*tasks) + finally: + await aio.gracefully_cancel(*tasks) + + @log_exceptions(logger=logger) + async def _memory_monitor_task(self) -> None: + """Monitor memory usage and kill the process if it exceeds the limit.""" + while not self._closing and not self._kill_sent: + try: + if not self._pid: + await asyncio.sleep(5) + continue + + # get process memory info + process = psutil.Process(self._pid) + memory_info = process.memory_info() + memory_mb = memory_info.rss / (1024 * 1024) # Convert to MB + + if ( + self._opts.memory_limit_mb > 0 + and memory_mb > self._opts.memory_limit_mb + ): + logger.error( + "process exceeded memory limit, killing process", + extra={ + "memory_usage_mb": memory_mb, + "memory_limit_mb": self._opts.memory_limit_mb, + **self.logging_extra(), + }, + ) + self._send_kill_signal() + elif ( + self._opts.memory_warn_mb > 0 + and memory_mb > self._opts.memory_warn_mb + ): + logger.warning( + "process memory usage is high", + extra={ + "memory_usage_mb": memory_mb, + "memory_warn_mb": self._opts.memory_warn_mb, + "memory_limit_mb": self._opts.memory_limit_mb, + **self.logging_extra(), + }, + ) + + except (psutil.NoSuchProcess, psutil.AccessDenied) as e: + if self._closing or self._kill_sent: + return + + logger.warning( + "Failed to get memory info for process", + extra=self.logging_extra(), + exc_info=e, + ) + # don't bother rechecking if we cannot get process info + return + except Exception: + if self._closing or self._kill_sent: + return + + logger.exception( + "Error in memory monitoring task", + extra=self.logging_extra(), + ) + + await asyncio.sleep(5) # check every 5 seconds + + def logging_extra(self): + extra: dict[str, Any] = { + "pid": self.pid, + } + + return extra diff --git a/livekit-agents/livekit/agents/job.py b/livekit-agents/livekit/agents/job.py index 471ed86c6..b54f8358c 100644 --- a/livekit-agents/livekit/agents/job.py +++ b/livekit-agents/livekit/agents/job.py @@ -15,16 +15,31 @@ from __future__ import annotations import 
asyncio +import contextvars +import functools import multiprocessing as mp from dataclasses import dataclass from enum import Enum, unique from typing import Any, Callable, Coroutine, Tuple -from livekit import rtc +from livekit import api, rtc from livekit.protocol import agent, models +from .ipc.inference_executor import InferenceExecutor from .log import logger +_JobContextVar = contextvars.ContextVar["JobContext"]("agents_job_context") + + +def get_current_job_context() -> JobContext: + ctx = _JobContextVar.get(None) + if ctx is None: + raise RuntimeError( + "no job context found, are you running this code inside a job entrypoint?" + ) + + return ctx + @unique class JobExecutorType(Enum): @@ -53,6 +68,7 @@ class RunningJobInfo: job: agent.Job url: str token: str + worker_id: str DEFAULT_PARTICIPANT_KINDS: list[rtc.ParticipantKind.ValueType] = [ @@ -70,6 +86,7 @@ def __init__( room: rtc.Room, on_connect: Callable[[], None], on_shutdown: Callable[[str], None], + inference_executor: InferenceExecutor, ) -> None: self._proc = proc self._info = info @@ -87,6 +104,15 @@ def __init__( ] = [] self._participant_tasks = dict[Tuple[str, Callable], asyncio.Task[None]]() self._room.on("participant_connected", self._participant_available) + self._inf_executor = inference_executor + + @property + def inference_executor(self) -> InferenceExecutor: + return self._inf_executor + + @functools.cached_property + def api(self) -> api.LiveKitAPI: + return api.LiveKitAPI() @property def proc(self) -> JobProcess: @@ -98,6 +124,11 @@ def job(self) -> agent.Job: """Returns the current job that the worker is executing.""" return self._info.job + @property + def worker_id(self) -> str: + """Returns the id of the worker.""" + return self._info.worker_id + @property def room(self) -> rtc.Room: """The Room object is the main interface that the worker should interact with. 
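The hunk above adds a module-level `contextvars.ContextVar` together with a `get_current_job_context()` helper, so that code running inside a job entrypoint can look up the active `JobContext` without it being passed through every call. A minimal sketch of that pattern follows; the `run_entrypoint` plumbing and the `job_id` field are hypothetical stand-ins, and in the real framework it is the worker, not user code, that sets and resets the variable around the entrypoint:

```python
import contextvars


class JobContext:
    """Hypothetical stand-in for the real JobContext."""

    def __init__(self, job_id: str) -> None:
        self.job_id = job_id


_job_ctx_var: contextvars.ContextVar["JobContext"] = contextvars.ContextVar("job_context")


def get_current_job_context() -> JobContext:
    ctx = _job_ctx_var.get(None)
    if ctx is None:
        raise RuntimeError("no job context found, are you inside a job entrypoint?")
    return ctx


def run_entrypoint(ctx: JobContext, entrypoint) -> None:
    # hypothetical framework plumbing: set the variable before user code runs,
    # reset it afterwards so nothing leaks between jobs
    token = _job_ctx_var.set(ctx)
    try:
        entrypoint()
    finally:
        _job_ctx_var.reset(token)


def my_entrypoint() -> None:
    # user code can now recover the active context from anywhere in the call stack
    print(get_current_job_context().job_id)


run_entrypoint(JobContext("job-123"), my_entrypoint)
```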
@@ -249,10 +280,14 @@ def on_track_published(pub: rtc.RemoteTrackPublication, _: rtc.RemoteParticipant class JobProcess: - def __init__(self, *, start_arguments: Any | None = None) -> None: + def __init__( + self, + *, + user_arguments: Any | None = None, + ) -> None: self._mp_proc = mp.current_process() self._userdata: dict[str, Any] = {} - self._start_arguments = start_arguments + self._user_arguments = user_arguments @property def pid(self) -> int | None: @@ -263,8 +298,8 @@ def userdata(self) -> dict: return self._userdata @property - def start_arguments(self) -> Any | None: - return self._start_arguments + def user_arguments(self) -> Any | None: + return self._user_arguments class JobRequest: diff --git a/livekit-agents/livekit/agents/llm/__init__.py b/livekit-agents/livekit/agents/llm/__init__.py index 7ba714ce6..d3a06f520 100644 --- a/livekit-agents/livekit/agents/llm/__init__.py +++ b/livekit-agents/livekit/agents/llm/__init__.py @@ -6,6 +6,7 @@ ChatMessage, ChatRole, ) +from .fallback_adapter import AvailabilityChangedEvent, FallbackAdapter from .function_context import ( USE_DOCSTRING, CalledFunction, @@ -14,6 +15,7 @@ FunctionContext, FunctionInfo, TypeInfo, + _create_ai_function_info, ai_callable, ) from .llm import ( @@ -24,6 +26,7 @@ CompletionUsage, LLMCapabilities, LLMStream, + ToolChoice, ) __all__ = [ @@ -49,4 +52,8 @@ "CalledFunction", "USE_DOCSTRING", "LLMCapabilities", + "FallbackAdapter", + "AvailabilityChangedEvent", + "ToolChoice", + "_create_ai_function_info", ] diff --git a/livekit-agents/livekit/agents/llm/chat_context.py b/livekit-agents/livekit/agents/llm/chat_context.py index 3cbbcdef9..ccde86bba 100644 --- a/livekit-agents/livekit/agents/llm/chat_context.py +++ b/livekit-agents/livekit/agents/llm/chat_context.py @@ -26,11 +26,59 @@ @dataclass class ChatImage: + """ + ChatImage is used to input images into the ChatContext on supported LLM providers / plugins. + + You may need to consult your LLM provider's documentation on supported URL types. + + ```python + # Pass a VideoFrame directly, which will be automatically converted to a JPEG data URL internally + async for event in rtc.VideoStream(video_track): + chat_image = ChatImage(image=event.frame) + # this instance is now available for your ChatContext + + # Encode your VideoFrame yourself for more control, and pass the result as a data URL (see EncodeOptions for more details) + from livekit.agents.utils.images import encode, EncodeOptions, ResizeOptions + + image_bytes = encode( + event.frame, + EncodeOptions( + format="PNG", + resize_options=ResizeOptions( + width=512, height=512, strategy="scale_aspect_fit" + ), + ), + ) + chat_image = ChatImage( + image=f"data:image/png;base64,{base64.b64encode(image_bytes).decode('utf-8')}" + ) + + # With an external URL + chat_image = ChatImage(image="https://example.com/image.jpg") + ``` + """ + image: str | rtc.VideoFrame + """ + Either a string URL or a VideoFrame object + """ inference_width: int | None = None + """ + Resizing parameter for rtc.VideoFrame inputs (ignored for URL images) + """ inference_height: int | None = None + """ + Resizing parameter for rtc.VideoFrame inputs (ignored for URL images) + """ + inference_detail: Literal["auto", "high", "low"] = "auto" + """ + Detail parameter for LLM provider, if supported. 
+ + Currently only supported by OpenAI (see https://platform.openai.com/docs/guides/vision?lang=node#low-or-high-fidelity-image-understanding) + """ _cache: dict[Any, Any] = field(default_factory=dict, repr=False, init=False) - """_cache is used by LLM implementations to store a processed version of the image + """ + _cache is used internally by LLM implementations to store a processed version of the image for later use. """ diff --git a/livekit-agents/livekit/agents/llm/fallback_adapter.py b/livekit-agents/livekit/agents/llm/fallback_adapter.py new file mode 100644 index 000000000..fd5242e4d --- /dev/null +++ b/livekit-agents/livekit/agents/llm/fallback_adapter.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import asyncio +import dataclasses +import time +from dataclasses import dataclass +from typing import AsyncIterable, Literal, Union + +from livekit.agents._exceptions import APIConnectionError, APIError + +from ..log import logger +from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions +from .chat_context import ChatContext +from .function_context import FunctionContext +from .llm import LLM, ChatChunk, LLMStream, ToolChoice + +DEFAULT_FALLBACK_API_CONNECT_OPTIONS = APIConnectOptions( + max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout +) + + +@dataclass +class _LLMStatus: + available: bool + recovering_task: asyncio.Task | None + + +@dataclass +class AvailabilityChangedEvent: + llm: LLM + available: bool + + +class FallbackAdapter( + LLM[Literal["llm_availability_changed"]], +): + def __init__( + self, + llm: list[LLM], + *, + attempt_timeout: float = 10.0, + max_retry_per_llm: int = 1, + retry_interval: float = 5, + ) -> None: + if len(llm) < 1: + raise ValueError("at least one LLM instance must be provided.") + + super().__init__() + + self._llm_instances = llm + self._attempt_timeout = attempt_timeout + self._max_retry_per_llm = max_retry_per_llm + self._retry_interval = retry_interval + + self._status = [ + _LLMStatus(available=True, recovering_task=None) + for _ in self._llm_instances + ] + + def chat( + self, + *, + chat_ctx: ChatContext, + conn_options: APIConnectOptions = DEFAULT_FALLBACK_API_CONNECT_OPTIONS, + fnc_ctx: FunctionContext | None = None, + temperature: float | None = None, + n: int | None = 1, + parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] + | None = None, + ) -> "LLMStream": + return FallbackLLMStream( + llm=self, + conn_options=conn_options, + chat_ctx=chat_ctx, + fnc_ctx=fnc_ctx, + temperature=temperature, + n=n, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, + ) + + +class FallbackLLMStream(LLMStream): + def __init__( + self, + *, + llm: FallbackAdapter, + conn_options: APIConnectOptions, + chat_ctx: ChatContext, + fnc_ctx: FunctionContext | None, + temperature: float | None, + n: int | None, + parallel_tool_calls: bool | None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] + | None = None, + ) -> None: + super().__init__( + llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options + ) + self._fallback_adapter = llm + self._temperature = temperature + self._n = n + self._parallel_tool_calls = parallel_tool_calls + self._tool_choice = tool_choice + + async def _try_generate( + self, *, llm: LLM, recovering: bool = False + ) -> AsyncIterable[ChatChunk]: + try: + async with llm.chat( + chat_ctx=self._chat_ctx, + fnc_ctx=self._fnc_ctx, + temperature=self._temperature, + n=self._n, + 
parallel_tool_calls=self._parallel_tool_calls, + tool_choice=self._tool_choice, + conn_options=dataclasses.replace( + self._conn_options, + max_retry=self._fallback_adapter._max_retry_per_llm, + timeout=self._fallback_adapter._attempt_timeout, + retry_interval=self._fallback_adapter._retry_interval, + ), + ) as stream: + async for chunk in stream: + yield chunk + + except asyncio.TimeoutError: + if recovering: + logger.warning(f"{llm.label} recovery timed out") + raise + + logger.warning( + f"{llm.label} timed out, switching to next LLM", + ) + + raise + except APIError as e: + if recovering: + logger.warning( + f"{llm.label} recovery failed", + exc_info=e, + ) + raise + + logger.warning( + f"{llm.label} failed, switching to next LLM", + exc_info=e, + ) + raise + except Exception: + if recovering: + logger.exception( + f"{llm.label} recovery unexpected error", + ) + raise + + logger.exception( + f"{llm.label} unexpected error, switching to next LLM", + ) + raise + + def _try_recovery(self, llm: LLM) -> None: + llm_status = self._fallback_adapter._status[ + self._fallback_adapter._llm_instances.index(llm) + ] + if llm_status.recovering_task is None or llm_status.recovering_task.done(): + + async def _recover_llm_task(llm: LLM) -> None: + try: + async for _ in self._try_generate(llm=llm, recovering=True): + pass + + llm_status.available = True + logger.info(f"llm.FallbackAdapter, {llm.label} recovered") + self._fallback_adapter.emit( + "llm_availability_changed", + AvailabilityChangedEvent(llm=llm, available=True), + ) + except Exception: + return + + llm_status.recovering_task = asyncio.create_task(_recover_llm_task(llm)) + + async def _run(self) -> None: + start_time = time.time() + + all_failed = all( + not llm_status.available for llm_status in self._fallback_adapter._status + ) + if all_failed: + logger.error("all LLMs are unavailable, retrying...") + + for i, llm in enumerate(self._fallback_adapter._llm_instances): + llm_status = self._fallback_adapter._status[i] + if llm_status.available or all_failed: + chunk_sent = False + try: + async for chunk in self._try_generate( + llm=llm, recovering=False + ): + chunk_sent = True + self._event_ch.send_nowait(chunk) + + return + except Exception: # exceptions already logged inside _try_generate + if llm_status.available: + llm_status.available = False + self._fallback_adapter.emit( + "llm_availability_changed", + AvailabilityChangedEvent(llm=llm, available=False), + ) + + if chunk_sent: + raise + + self._try_recovery(llm) + + raise APIConnectionError( + "all LLMs failed (%s) after %s seconds" + % ( + [llm.label for llm in self._fallback_adapter._llm_instances], + time.time() - start_time, + ) + ) diff --git a/livekit-agents/livekit/agents/llm/function_context.py b/livekit-agents/livekit/agents/llm/function_context.py index 9564c3a1c..59604fc8d 100644 --- a/livekit-agents/livekit/agents/llm/function_context.py +++ b/livekit-agents/livekit/agents/llm/function_context.py @@ -18,9 +18,11 @@ import enum import functools import inspect +import json +import types import typing from dataclasses import dataclass -from typing import Any, Callable, Tuple +from typing import Any, Callable, Optional, Tuple from ..log import logger @@ -103,7 +105,7 @@ class CalledFunction: def ai_callable( *, name: str | None = None, - description: str | _UseDocMarker | None = None, + description: str | _UseDocMarker = USE_DOCSTRING, auto_retry: bool = False, ) -> Callable: def deco(f): @@ -125,7 +127,7 @@ def ai_callable( self, *, name: str | None
= None, - description: str | _UseDocMarker | None = None, + description: str | _UseDocMarker = USE_DOCSTRING, auto_retry: bool = True, ) -> Callable: def deco(f): @@ -168,15 +170,13 @@ def _register_ai_function(self, fnc: Callable) -> None: ) desc = type_info.description if type_info else "" - choices = type_info.choices if type_info else None + choices = type_info.choices if type_info else () - is_optional, optional_inner = _is_optional_type(inner_th) - if is_optional: - # when the type is optional, only the inner type is relevant - # the argument info for default would be None - inner_th = optional_inner - - if issubclass(inner_th, enum.Enum) and not choices: + if ( + isinstance(inner_th, type) + and issubclass(inner_th, enum.Enum) + and not choices + ): # the enum must be a str or int (and at least one value) # this is verified by is_type_supported choices = tuple([item.value for item in inner_th]) @@ -223,7 +223,8 @@ def _extract_types(annotation: type) -> tuple[type, TypeInfo | None]: is_optional, optional_inner = _is_optional_type(annotation) if is_optional: - return _extract_types(optional_inner) + inner_type, info = _extract_types(optional_inner) + return Optional[inner_type], info # type: ignore return annotation, None @@ -242,19 +243,17 @@ def _extract_types(annotation: type) -> tuple[type, TypeInfo | None]: def _set_metadata( f: Callable, name: str | None = None, - desc: str | _UseDocMarker | None = None, + desc: str | _UseDocMarker = USE_DOCSTRING, auto_retry: bool = False, ) -> None: - if desc is None: - desc = "" - if isinstance(desc, _UseDocMarker): - desc = inspect.getdoc(f) - if desc is None: + docstring = inspect.getdoc(f) + if docstring is None: raise ValueError( f"missing docstring for function {f.__name__}, " "use explicit description or provide docstring" ) + desc = docstring metadata = _AIFncMetadata( name=name or f.__name__, description=desc, auto_retry=auto_retry @@ -291,17 +290,108 @@ def is_type_supported(t: type) -> bool: def _is_optional_type(typ) -> Tuple[bool, Any]: """return is_optional, inner_type""" origin = typing.get_origin(typ) + if origin is None or origin is list: + return False, typ - if origin in {typing.Union, getattr(__builtins__, "UnionType", typing.Union)}: + if origin in {typing.Union, getattr(types, "UnionType", typing.Union)}: args = typing.get_args(typ) is_optional = type(None) in args + non_none_args = [a for a in args if a is not type(None)] + if is_optional and len(non_none_args) == 1: + # Exactly one non-None type + None means optional + return True, non_none_args[0] - inner_arg = None - for arg in args: - if arg is not type(None): - inner_arg = arg - break + return False, None - return is_optional, inner_arg - return False, None +def _create_ai_function_info( + fnc_ctx: FunctionContext, + tool_call_id: str, + fnc_name: str, + raw_arguments: str, # JSON string +) -> FunctionCallInfo: + if fnc_name not in fnc_ctx.ai_functions: + raise ValueError(f"AI function {fnc_name} not found") + + parsed_arguments: dict[str, Any] = {} + try: + if raw_arguments: # ignore empty string + parsed_arguments = json.loads(raw_arguments) + except json.JSONDecodeError: + raise ValueError( + f"AI function {fnc_name} received invalid JSON arguments - {raw_arguments}" + ) + + fnc_info = fnc_ctx.ai_functions[fnc_name] + + # Ensure all necessary arguments are present and of the correct type. 
+ sanitized_arguments: dict[str, Any] = {} + for arg_info in fnc_info.arguments.values(): + if arg_info.name not in parsed_arguments: + if arg_info.default is inspect.Parameter.empty: + raise ValueError( + f"AI function {fnc_name} missing required argument {arg_info.name}" + ) + continue + + arg_value = parsed_arguments[arg_info.name] + is_optional, inner_th = _is_optional_type(arg_info.type) + + if typing.get_origin(inner_th) is not None: + if not isinstance(arg_value, list): + raise ValueError( + f"AI function {fnc_name} argument {arg_info.name} should be a list" + ) + + inner_type = typing.get_args(inner_th)[0] + sanitized_value = [ + _sanitize_primitive( + value=v, + expected_type=inner_type, + choices=arg_info.choices, + ) + for v in arg_value + ] + else: + sanitized_value = _sanitize_primitive( + value=arg_value, + expected_type=inner_th, + choices=arg_info.choices, + ) + + sanitized_arguments[arg_info.name] = sanitized_value + + return FunctionCallInfo( + tool_call_id=tool_call_id, + raw_arguments=raw_arguments, + function_info=fnc_info, + arguments=sanitized_arguments, + ) + + +def _sanitize_primitive( + *, value: Any, expected_type: type, choices: tuple | None +) -> Any: + if expected_type is str: + if not isinstance(value, str): + raise ValueError(f"expected str, got {type(value)}") + elif expected_type in (int, float): + if not isinstance(value, (int, float)): + raise ValueError(f"expected number, got {type(value)}") + + if expected_type is int: + if value % 1 != 0: + raise ValueError("expected int, got float") + + value = int(value) + elif expected_type is float: + value = float(value) + + elif expected_type is bool: + if not isinstance(value, bool): + raise ValueError(f"expected bool, got {type(value)}") + + if choices and value not in choices: + raise ValueError(f"invalid value {value}, not in {choices}") + + return value diff --git a/livekit-agents/livekit/agents/llm/llm.py b/livekit-agents/livekit/agents/llm/llm.py index fa2db5a8b..099e3139c 100644 --- a/livekit-agents/livekit/agents/llm/llm.py +++ b/livekit-agents/livekit/agents/llm/llm.py @@ -4,13 +4,24 @@ import time from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, AsyncIterable, AsyncIterator, Literal +from types import TracebackType +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Generic, + Literal, + TypeVar, + Union, +) from livekit import rtc +from livekit.agents._exceptions import APIConnectionError, APIError from .. import utils from ..log import logger from ..metrics import LLMMetrics +from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions from ..utils import aio from . 
import function_context from .chat_context import ChatContext, ChatRole @@ -48,27 +59,60 @@ class ChatChunk: usage: CompletionUsage | None = None -class LLM(ABC, rtc.EventEmitter[Literal["metrics_collected"]]): +@dataclass +class ToolChoice: + type: Literal["function"] + name: str + + +TEvent = TypeVar("TEvent") + + +class LLM( + ABC, + rtc.EventEmitter[Union[Literal["metrics_collected"], TEvent]], + Generic[TEvent], +): def __init__(self) -> None: super().__init__() self._capabilities = LLMCapabilities() self._label = f"{type(self).__module__}.{type(self).__name__}" + @property + def label(self) -> str: + return self._label + @abstractmethod def chat( self, *, chat_ctx: ChatContext, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, fnc_ctx: function_context.FunctionContext | None = None, temperature: float | None = None, n: int | None = None, parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] + | None = None, ) -> "LLMStream": ... @property def capabilities(self) -> LLMCapabilities: return self._capabilities + async def aclose(self) -> None: ... + + async def __aenter__(self) -> LLM: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.aclose() + class LLMStream(ABC): def __init__( @@ -77,10 +121,12 @@ def __init__( *, chat_ctx: ChatContext, fnc_ctx: function_context.FunctionContext | None, + conn_options: APIConnectOptions, ) -> None: self._llm = llm self._chat_ctx = chat_ctx self._fnc_ctx = fnc_ctx + self._conn_options = conn_options self._event_ch = aio.Chan[ChatChunk]() self._event_aiter, monitor_aiter = aio.itertools.tee(self._event_ch, 2) @@ -95,7 +141,30 @@ def __init__( self._function_tasks = set[asyncio.Task[Any]]() @abstractmethod - async def _main_task(self) -> None: ... + async def _run(self) -> None: ... 
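From the caller's side, the `conn_options` parameter added to `chat()` above, plus the async-context-manager support introduced in this file, might be used roughly as in the sketch below. This is an illustration only: it assumes `APIConnectOptions` is importable as `livekit.agents.types.APIConnectOptions` (matching the relative import above) and relies on the pre-existing `ChatChunk.choices[*].delta.content` fields that this hunk does not show.

```python
from livekit.agents.llm import LLM, ChatContext
from livekit.agents.types import APIConnectOptions


async def collect_answer(llm: LLM, prompt: str) -> str:
    """Gather one completion as plain text, with bounded retries handled by the stream."""
    chat_ctx = ChatContext().append(role="user", text=prompt)
    opts = APIConnectOptions(max_retry=2, retry_interval=1.0, timeout=30.0)

    answer = ""
    # LLMStream is now an async context manager, so aclose() runs on exit
    async with llm.chat(chat_ctx=chat_ctx, conn_options=opts) as stream:
        async for chunk in stream:
            for choice in chunk.choices:
                answer += choice.delta.content or ""
    return answer
```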
+ + async def _main_task(self) -> None: + for i in range(self._conn_options.max_retry + 1): + try: + return await self._run() + except APIError as e: + if self._conn_options.max_retry == 0 or not e.retryable: + raise + elif i == self._conn_options.max_retry: + raise APIConnectionError( + f"failed to generate LLM completion after {self._conn_options.max_retry + 1} attempts", + ) from e + else: + logger.warning( + f"failed to generate LLM completion, retrying in {self._conn_options.retry_interval}s", + exc_info=e, + extra={ + "llm": self._llm._label, + "attempt": i + 1, + }, + ) + + await asyncio.sleep(self._conn_options.retry_interval) @utils.log_exceptions(logger=logger) async def _metrics_monitor_task( @@ -174,3 +243,14 @@ async def __anext__(self) -> ChatChunk: def __aiter__(self) -> AsyncIterator[ChatChunk]: return self + + async def __aenter__(self) -> LLMStream: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.aclose() diff --git a/livekit-agents/livekit/agents/metrics/__init__.py b/livekit-agents/livekit/agents/metrics/__init__.py index 5a12ef1c9..a61f430e8 100644 --- a/livekit-agents/livekit/agents/metrics/__init__.py +++ b/livekit-agents/livekit/agents/metrics/__init__.py @@ -12,7 +12,6 @@ TTSMetrics, VADMetrics, ) -from .periodic_collector import PeriodicCollector from .usage_collector import UsageCollector, UsageSummary from .utils import log_metrics @@ -31,6 +30,5 @@ "TTSMetrics", "UsageSummary", "UsageCollector", - "PeriodicCollector", "log_metrics", ] diff --git a/livekit-agents/livekit/agents/metrics/base.py b/livekit-agents/livekit/agents/metrics/base.py index 78d09e4f2..d524b02b8 100644 --- a/livekit-agents/livekit/agents/metrics/base.py +++ b/livekit-agents/livekit/agents/metrics/base.py @@ -108,11 +108,17 @@ class MultimodalLLMError(Error): @dataclass class MultimodalLLMMetrics(LLMMetrics): + @dataclass + class CachedTokenDetails: + text_tokens: int + audio_tokens: int + @dataclass class InputTokenDetails: cached_tokens: int text_tokens: int audio_tokens: int + cached_tokens_details: MultimodalLLMMetrics.CachedTokenDetails @dataclass class OutputTokenDetails: diff --git a/livekit-agents/livekit/agents/multimodal/__init__.py b/livekit-agents/livekit/agents/multimodal/__init__.py index d165c082a..f741e168a 100644 --- a/livekit-agents/livekit/agents/multimodal/__init__.py +++ b/livekit-agents/livekit/agents/multimodal/__init__.py @@ -1,3 +1,13 @@ -from .multimodal_agent import AgentTranscriptionOptions, MultimodalAgent +from .multimodal_agent import ( + AgentTranscriptionOptions, + MultimodalAgent, + _RealtimeAPI, + _RealtimeAPISession, +) -__all__ = ["MultimodalAgent", "AgentTranscriptionOptions"] +__all__ = [ + "MultimodalAgent", + "AgentTranscriptionOptions", + "_RealtimeAPI", + "_RealtimeAPISession", +] diff --git a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py index d694100a0..f02bb2e64 100644 --- a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py +++ b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py @@ -2,7 +2,17 @@ import asyncio from dataclasses import dataclass -from typing import Callable, Literal, Protocol +from typing import ( + Any, + AsyncIterable, + Callable, + Literal, + Optional, + Protocol, + TypeVar, + Union, + overload, +) import aiohttp from livekit import rtc @@ -28,6 +38,76 @@ ] +class _InputTranscriptionProto(Protocol): + 
item_id: str + """id of the item""" + transcript: str + """transcript of the input audio""" + + +class _ContentProto(Protocol): + response_id: str + item_id: str + output_index: int + content_index: int + text: str + audio: list[rtc.AudioFrame] + text_stream: AsyncIterable[str] + audio_stream: AsyncIterable[rtc.AudioFrame] + content_type: Literal["text", "audio"] + + +class _CapabilitiesProto(Protocol): + supports_truncate: bool + + +class _RealtimeAPI(Protocol): + """Realtime API protocol""" + + @property + def capabilities(self) -> _CapabilitiesProto: ... + def session( + self, + *, + chat_ctx: llm.ChatContext | None = None, + fnc_ctx: llm.FunctionContext | None = None, + ) -> _RealtimeAPISession: + """ + Create a new realtime session with the given chat and function contexts. + """ + pass + + +T = TypeVar("T", bound=Callable[..., Any]) + + +class _RealtimeAPISession(Protocol): + async def set_chat_ctx(self, ctx: llm.ChatContext) -> None: ... + @overload + def on(self, event: str, callback: None = None) -> Callable[[T], T]: ... + @overload + def on(self, event: str, callback: T) -> T: ... + def on( + self, event: str, callback: Optional[T] = None + ) -> Union[T, Callable[[T], T]]: ... + + def _push_audio(self, frame: rtc.AudioFrame) -> None: ... + @property + def fnc_ctx(self) -> llm.FunctionContext | None: ... + @fnc_ctx.setter + def fnc_ctx(self, value: llm.FunctionContext | None) -> None: ... + def chat_ctx_copy(self) -> llm.ChatContext: ... + def _recover_from_text_response(self, item_id: str) -> None: ... + def _update_conversation_item_content( + self, + item_id: str, + content: llm.ChatContent | list[llm.ChatContent] | None = None, + ) -> None: ... + def _truncate_conversation_item( + self, item_id: str, content_index: int, audio_end_ms: int + ) -> None: ... + + @dataclass(frozen=True) class AgentTranscriptionOptions: user_transcription: bool = True @@ -50,9 +130,6 @@ class AgentTranscriptionOptions: representing the hyphenated parts of the word.""" -class S2SModel(Protocol): ... - - @dataclass(frozen=True) class _ImplOptions: transcription: AgentTranscriptionOptions @@ -62,20 +139,33 @@ class MultimodalAgent(utils.EventEmitter[EventTypes]): def __init__( self, *, - model: S2SModel, + model: _RealtimeAPI, vad: vad.VAD | None = None, chat_ctx: llm.ChatContext | None = None, fnc_ctx: llm.FunctionContext | None = None, transcription: AgentTranscriptionOptions = AgentTranscriptionOptions(), + max_text_response_retries: int = 5, loop: asyncio.AbstractEventLoop | None = None, ): + """Create a new MultimodalAgent. + + Args: + model: RealtimeAPI instance. + vad: Voice Activity Detection (VAD) instance. + chat_ctx: Chat context for the assistant. + fnc_ctx: Function context for the assistant. + transcription: Options for assistant transcription. + max_text_response_retries: Maximum number of retries to recover + from text responses to audio mode. OpenAI's realtime API has a + chance to return text responses instead of audio if the chat + context includes text system or assistant messages. The agent will + attempt to recover to audio mode by deleting the text response + and appending an empty audio message to the conversation. + loop: Event loop to use. Default to asyncio.get_event_loop(). 
+ """ super().__init__() self._loop = loop or asyncio.get_event_loop() - from livekit.plugins.openai import realtime - - assert isinstance(model, realtime.RealtimeModel) - self._model = model self._vad = vad self._chat_ctx = chat_ctx @@ -99,6 +189,9 @@ def __init__( self._update_state_task: asyncio.Task | None = None self._http_session: aiohttp.ClientSession | None = None + self._text_response_retries = 0 + self._max_text_response_retries = max_text_response_retries + @property def vad(self) -> vad.VAD | None: return self._vad @@ -157,16 +250,8 @@ async def _init_and_start(): # Schedule the initialization and start task asyncio.create_task(_init_and_start()) - from livekit.plugins.openai import realtime - @self._session.on("response_content_added") - def _on_content_added(message: realtime.RealtimeContent): - if message.content_type == "text": - logger.warning( - "The realtime API returned a text content part, which is not supported" - ) - return - + def _on_content_added(message: _ContentProto): tr_fwd = transcription.TTSSegmentsForwarder( room=self._room, participant=self._room.local_participant, @@ -184,6 +269,31 @@ def _on_content_added(message: realtime.RealtimeContent): audio_stream=message.audio_stream, ) + @self._session.on("response_content_done") + def _response_content_done(message: _ContentProto): + if message.content_type == "text": + if self._text_response_retries >= self._max_text_response_retries: + raise RuntimeError( + f"The OpenAI Realtime API returned a text response " + f"after {self._max_text_response_retries} retries. " + f"Please try to reduce the number of text system or " + f"assistant messages in the chat context." + ) + + self._text_response_retries += 1 + logger.warning( + "The OpenAI Realtime API returned a text response instead of audio. 
" + "Attempting to recover to audio mode...", + extra={ + "item_id": message.item_id, + "text": message.text, + "retries": self._text_response_retries, + }, + ) + self._session._recover_from_text_response(message.item_id) + else: + self._text_response_retries = 0 + @self._session.on("input_speech_committed") def _input_speech_committed(): self._stt_forwarder.update( @@ -194,9 +304,7 @@ def _input_speech_committed(): ) @self._session.on("input_speech_transcription_completed") - def _input_speech_transcription_completed( - ev: realtime.InputTranscriptionCompleted, - ): + def _input_speech_transcription_completed(ev: _InputTranscriptionProto): self._stt_forwarder.update( stt.SpeechEvent( type=stt.SpeechEventType.FINAL_TRANSCRIPT, @@ -206,7 +314,8 @@ def _input_speech_transcription_completed( user_msg = ChatMessage.create( text=ev.transcript, role="user", id=ev.item_id ) - self._session._update_converstation_item_content( + + self._session._update_conversation_item_content( ev.item_id, user_msg.content ) @@ -223,11 +332,14 @@ def _input_speech_started(): if self._playing_handle is not None and not self._playing_handle.done(): self._playing_handle.interrupt() - self._session.conversation.item.truncate( - item_id=self._playing_handle.item_id, - content_index=self._playing_handle.content_index, - audio_end_ms=int(self._playing_handle.audio_samples / 24000 * 1000), - ) + if self._model.capabilities.supports_truncate: + self._session._truncate_conversation_item( + item_id=self._playing_handle.item_id, + content_index=self._playing_handle.content_index, + audio_end_ms=int( + self._playing_handle.audio_samples / 24000 * 1000 + ), + ) @self._session.on("input_speech_stopped") def _input_speech_stopped(): @@ -288,9 +400,10 @@ def _on_playout_stopped(interrupted: bool) -> None: role="assistant", id=self._playing_handle.item_id, ) - self._session._update_converstation_item_content( - self._playing_handle.item_id, msg.content - ) + if self._model.capabilities.supports_truncate: + self._session._update_conversation_item_content( + self._playing_handle.item_id, msg.content + ) if interrupted: self.emit("agent_speech_interrupted", msg) @@ -324,7 +437,7 @@ def _on_playout_stopped(interrupted: bool) -> None: ) async for frame in self._input_audio_ch: for f in bstream.write(frame.data.tobytes()): - self._session.input_audio_buffer.append(f) + self._session._push_audio(f) def _on_participant_connected(self, participant: rtc.RemoteParticipant): if self._linked_participant is None: diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 5ec3b4456..7379261b3 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -12,6 +12,7 @@ Callable, Literal, Optional, + Protocol, Union, ) @@ -70,6 +71,7 @@ def __init__(self, assistant: "VoicePipelineAgent", llm_stream: LLMStream) -> No self._assistant = assistant self._metadata = dict[str, Any]() self._llm_stream = llm_stream + self._extra_chat_messages: list[ChatMessage] = [] @staticmethod def get_current() -> "AgentCallContext": @@ -79,6 +81,10 @@ def get_current() -> "AgentCallContext": def agent(self) -> "VoicePipelineAgent": return self._assistant + @property + def chat_ctx(self) -> ChatContext: + return self._llm_stream.chat_ctx + def store_metadata(self, key: str, value: Any) -> None: self._metadata[key] = value @@ -88,6 +94,14 @@ def get_metadata(self, key: str, default: Any = None) -> Any: def llm_stream(self) 
-> LLMStream: return self._llm_stream + def add_extra_chat_message(self, message: ChatMessage) -> None: + """Append chat message to the end of function outputs for the answer LLM call""" + self._extra_chat_messages.append(message) + + @property + def extra_chat_messages(self) -> list[ChatMessage]: + return self._extra_chat_messages + def _default_before_llm_cb( agent: VoicePipelineAgent, chat_ctx: ChatContext @@ -118,6 +132,7 @@ class _ImplOptions: int_speech_duration: float int_min_words: int min_endpointing_delay: float + max_endpointing_delay: float max_nested_fnc_calls: int preemptive_synthesis: bool before_llm_cb: BeforeLLMCallback @@ -148,6 +163,14 @@ class AgentTranscriptionOptions: representing the hyphenated parts of the word.""" +class _TurnDetector(Protocol): + # When endpoint probability is below this threshold we think the user is not finished speaking + # so we will use a long delay + def unlikely_threshold(self) -> float: ... + def supports_language(self, language: str | None) -> bool: ... + async def predict_end_of_turn(self, chat_ctx: ChatContext) -> float: ... + + class VoicePipelineAgent(utils.EventEmitter[EventTypes]): """ A pipeline agent (VAD + STT + LLM + TTS) implementation. @@ -163,12 +186,14 @@ def __init__( stt: stt.STT, llm: LLM, tts: tts.TTS, + turn_detector: _TurnDetector | None = None, chat_ctx: ChatContext | None = None, fnc_ctx: FunctionContext | None = None, allow_interruptions: bool = True, interrupt_speech_duration: float = 0.5, interrupt_min_words: int = 0, min_endpointing_delay: float = 0.5, + max_endpointing_delay: float = 6.0, max_nested_fnc_calls: int = 1, preemptive_synthesis: bool = False, transcription: AgentTranscriptionOptions = AgentTranscriptionOptions(), @@ -226,6 +251,7 @@ def __init__( int_speech_duration=interrupt_speech_duration, int_min_words=interrupt_min_words, min_endpointing_delay=min_endpointing_delay, + max_endpointing_delay=max_endpointing_delay, max_nested_fnc_calls=max_nested_fnc_calls, preemptive_synthesis=preemptive_synthesis, transcription=transcription, @@ -253,6 +279,7 @@ def __init__( ) self._stt, self._vad, self._llm, self._tts = stt, vad, llm, tts + self._turn_detector = turn_detector self._chat_ctx = chat_ctx or ChatContext() self._fnc_ctx = fnc_ctx self._started, self._closed = False, False @@ -271,8 +298,10 @@ def __init__( self._deferred_validation = _DeferredReplyValidation( self._validate_reply_if_possible, - self._opts.min_endpointing_delay, - loop=self._loop, + min_endpointing_delay=self._opts.min_endpointing_delay, + max_endpointing_delay=self._opts.max_endpointing_delay, + turn_detector=self._turn_detector, + agent=self, ) self._speech_q: list[SpeechHandle] = [] @@ -407,7 +436,7 @@ async def say( *, allow_interruptions: bool = True, add_to_chat_ctx: bool = True, - ) -> None: + ) -> SpeechHandle: """ Play a speech source through the voice assistant. @@ -416,15 +445,77 @@ async def say( It can be a string, an LLMStream, or an asynchronous iterable of strings. allow_interruptions: Whether to allow interruptions during the speech playback. add_to_chat_ctx: Whether to add the speech to the chat context. + + Returns: + The speech handle for the speech that was played, can be used to + wait for the speech to finish. 
""" await self._track_published_fut + call_ctx = None + fnc_source: str | AsyncIterable[str] | None = None + if add_to_chat_ctx: + try: + call_ctx = AgentCallContext.get_current() + except LookupError: + # no active call context, ignore + pass + else: + if isinstance(source, LLMStream): + logger.warning( + "LLMStream will be ignored for function call chat context" + ) + elif isinstance(source, AsyncIterable): + source, fnc_source = utils.aio.itertools.tee(source, 2) # type: ignore + else: + fnc_source = source + new_handle = SpeechHandle.create_assistant_speech( allow_interruptions=allow_interruptions, add_to_chat_ctx=add_to_chat_ctx ) synthesis_handle = self._synthesize_agent_speech(new_handle.id, source) new_handle.initialize(source=source, synthesis_handle=synthesis_handle) - self._add_speech_for_playout(new_handle) + + if self._playing_speech and not self._playing_speech.nested_speech_done: + self._playing_speech.add_nested_speech(new_handle) + else: + self._add_speech_for_playout(new_handle) + + # add the speech to the function call context if needed + if call_ctx is not None and fnc_source is not None: + if isinstance(fnc_source, AsyncIterable): + text = "" + async for chunk in fnc_source: + text += chunk + else: + text = fnc_source + + call_ctx.add_extra_chat_message( + ChatMessage.create(text=text, role="assistant") + ) + logger.debug( + "added speech to function call chat context", + extra={"text": text}, + ) + + return new_handle + + def interrupt(self, interrupt_all: bool = True) -> None: + """Interrupt the current speech + + Args: + interrupt_all: Whether to interrupt all pending speech + """ + if interrupt_all: + # interrupt all pending speech + if self._pending_agent_reply is not None: + self._pending_agent_reply.cancel(cancel_nested=True) + for speech in self._speech_q: + speech.cancel(cancel_nested=True) + + # interrupt the playing speech + if self._playing_speech is not None: + self._playing_speech.cancel() def _update_state(self, state: AgentState, delay: float = 0.0): """Set the current state of the agent""" @@ -520,7 +611,9 @@ def _on_final_transcript(self, ev: SpeechEvent) -> None: if self._playing_speech is None or self._playing_speech.allow_interruptions: self._synthesize_agent_reply() - self._deferred_validation.on_human_final_transcript(new_transcript) + self._deferred_validation.on_human_final_transcript( + new_transcript, ev.alternatives[0].language + ) words = self._opts.transcription.word_tokenizer.tokenize(text=new_transcript) if len(words) >= 3: @@ -642,6 +735,11 @@ async def _synthesize_answer_task( not playing_speech.user_question or playing_speech.user_committed ) and not playing_speech.speech_committed: # the speech is playing but not committed yet, add it to the chat context for this new reply synthesis + # First add the previous function call message if any + if playing_speech.extra_tools_messages: + copied_ctx.messages.extend(playing_speech.extra_tools_messages) + + # Then add the previous assistant message copied_ctx.messages.append( ChatMessage.create( text=playing_speech.synthesis_handle.tts_forwarder.played_text, @@ -649,6 +747,9 @@ async def _synthesize_answer_task( ) ) + # we want to add this question even if it's empty. during false positive interruptions, + # adding an empty user message gives the LLM context so it could continue from where + # it had been interrupted. 
copied_ctx.messages.append( ChatMessage.create(text=handle.user_question, role="user") ) @@ -745,11 +846,64 @@ def _commit_user_question_if_needed() -> None: speech_handle.source.function_calls ) - extra_tools_messages = [] # additional messages from the functions to add to the context if needed + message_id_committed: str | None = None + if ( + collected_text + and speech_handle.add_to_chat_ctx + and (not user_question or speech_handle.user_committed) + ): + if speech_handle.extra_tools_messages: + if speech_handle.fnc_text_message_id is not None: + # there is a message alongside the function calls + msgs = self._chat_ctx.messages + if msgs and msgs[-1].id == speech_handle.fnc_text_message_id: + # replace it with the tool call message if it's the last in the ctx + msgs.pop() + elif speech_handle.extra_tools_messages[0].tool_calls: + # remove the content of the tool call message + speech_handle.extra_tools_messages[0].content = "" + self._chat_ctx.messages.extend(speech_handle.extra_tools_messages) + + if interrupted: + collected_text += "..." + + msg = ChatMessage.create(text=collected_text, role="assistant") + self._chat_ctx.messages.append(msg) + message_id_committed = msg.id + speech_handle.mark_speech_committed() + + if interrupted: + self.emit("agent_speech_interrupted", msg) + else: + self.emit("agent_speech_committed", msg) + + logger.debug( + "committed agent speech", + extra={ + "agent_transcript": collected_text, + "interrupted": interrupted, + "speech_id": speech_handle.id, + }, + ) + + async def _execute_function_calls() -> None: + nonlocal interrupted, collected_text + + # if the answer is using tools, execute the functions and automatically generate + # a response to the user question from the returned values + if not is_using_tools or interrupted: + return + + if speech_handle.fnc_nested_depth >= self._opts.max_nested_fnc_calls: + logger.warning( + "max function calls nested depth reached", + extra={ + "speech_id": speech_handle.id, + "fnc_nested_depth": speech_handle.fnc_nested_depth, + }, + ) + return - # if the answer is using tools, execute the functions and automatically generate - # a response to the user question from the returned values - if is_using_tools and not interrupted: assert isinstance(speech_handle.source, LLMStream) assert ( not user_question or speech_handle.user_committed @@ -757,118 +911,145 @@ def _commit_user_question_if_needed() -> None: llm_stream = speech_handle.source - if collected_text: - msg = ChatMessage.create(text=collected_text, role="assistant") - self._chat_ctx.messages.append(msg) - - speech_handle.mark_speech_committed() - self.emit("agent_speech_committed", msg) - # execute functions call_ctx = AgentCallContext(self, llm_stream) tk = _CallContextVar.set(call_ctx) new_function_calls = llm_stream.function_calls - for i in range(self._opts.max_nested_fnc_calls): - self.emit("function_calls_collected", new_function_calls) + self.emit("function_calls_collected", new_function_calls) - called_fncs = [] - for fnc in new_function_calls: - called_fnc = fnc.execute() - called_fncs.append(called_fnc) - logger.debug( - "executing ai function", + called_fncs = [] + for fnc in new_function_calls: + called_fnc = fnc.execute() + called_fncs.append(called_fnc) + logger.debug( + "executing ai function", + extra={ + "function": fnc.function_info.name, + "speech_id": speech_handle.id, + }, + ) + try: + await called_fnc.task + except Exception as e: + logger.exception( + "error executing ai function", extra={ "function": fnc.function_info.name, 
"speech_id": speech_handle.id, }, + exc_info=e, ) - try: - await called_fnc.task - except Exception as e: - logger.exception( - "error executing ai function", - extra={ - "function": fnc.function_info.name, - "speech_id": speech_handle.id, - }, - exc_info=e, - ) - - tool_calls_info = [] - tool_calls_results = [] - - for called_fnc in called_fncs: - # ignore the function calls that returns None - if called_fnc.result is None and called_fnc.exception is None: - continue - tool_calls_info.append(called_fnc.call_info) - tool_calls_results.append( - ChatMessage.create_tool_from_called_function(called_fnc) - ) + tool_calls_info = [] + tool_calls_results = [] - if not tool_calls_info: - break + for called_fnc in called_fncs: + # ignore the function calls that returns None + if called_fnc.result is None and called_fnc.exception is None: + continue - # generate an answer from the tool calls - extra_tools_messages.append( - ChatMessage.create_tool_calls(tool_calls_info, text=collected_text) + tool_calls_info.append(called_fnc.call_info) + tool_calls_results.append( + ChatMessage.create_tool_from_called_function(called_fnc) ) - extra_tools_messages.extend(tool_calls_results) - chat_ctx = speech_handle.source.chat_ctx.copy() - chat_ctx.messages.extend(extra_tools_messages) - - answer_llm_stream = self._llm.chat( - chat_ctx=chat_ctx, fnc_ctx=self.fnc_ctx - ) - answer_synthesis = self._synthesize_agent_speech( - speech_handle.id, answer_llm_stream - ) - # replace the synthesis handle with the new one to allow interruption - speech_handle.synthesis_handle = answer_synthesis - play_handle = answer_synthesis.play() - await play_handle.join() + if not tool_calls_info: + return - collected_text = answer_synthesis.tts_forwarder.played_text - interrupted = answer_synthesis.interrupted - new_function_calls = answer_llm_stream.function_calls + # create a nested speech handle + extra_tools_messages = [ + ChatMessage.create_tool_calls(tool_calls_info, text=collected_text) + ] + extra_tools_messages.extend(tool_calls_results) + + new_speech_handle = SpeechHandle.create_tool_speech( + allow_interruptions=speech_handle.allow_interruptions, + add_to_chat_ctx=speech_handle.add_to_chat_ctx, + extra_tools_messages=extra_tools_messages, + fnc_nested_depth=speech_handle.fnc_nested_depth + 1, + fnc_text_message_id=message_id_committed, + ) - self.emit("function_calls_finished", called_fncs) + # synthesize the tool speech with the chat ctx from llm_stream + chat_ctx = call_ctx.chat_ctx.copy() + chat_ctx.messages.extend(extra_tools_messages) + chat_ctx.messages.extend(call_ctx.extra_chat_messages) + fnc_ctx = self.fnc_ctx + if ( + fnc_ctx + and new_speech_handle.fnc_nested_depth + >= self._opts.max_nested_fnc_calls + ): + if len(fnc_ctx.ai_functions) > 1: + logger.info( + "max function calls nested depth reached, dropping function context. 
increase max_nested_fnc_calls to enable additional nesting.", + extra={ + "speech_id": speech_handle.id, + "fnc_nested_depth": speech_handle.fnc_nested_depth, + }, + ) + fnc_ctx = None + answer_llm_stream = self._llm.chat( + chat_ctx=chat_ctx, + fnc_ctx=fnc_ctx, + ) - if not new_function_calls: - break + synthesis_handle = self._synthesize_agent_speech( + new_speech_handle.id, answer_llm_stream + ) + new_speech_handle.initialize( + source=answer_llm_stream, synthesis_handle=synthesis_handle + ) + speech_handle.add_nested_speech(new_speech_handle) + self.emit("function_calls_finished", called_fncs) _CallContextVar.reset(tk) - if speech_handle.add_to_chat_ctx and ( - not user_question or speech_handle.user_committed - ): - self._chat_ctx.messages.extend(extra_tools_messages) + if not is_using_tools: + speech_handle._set_done() + return - if interrupted: - collected_text += "..." + fnc_task = asyncio.create_task(_execute_function_calls()) + while not speech_handle.nested_speech_done: + nesting_changed = asyncio.create_task( + speech_handle.nested_speech_changed.wait() + ) + nesting_done_fut: asyncio.Future = speech_handle._nested_speech_done_fut + await asyncio.wait( + [nesting_changed, fnc_task, nesting_done_fut], + return_when=asyncio.FIRST_COMPLETED, + ) + if not nesting_changed.done(): + nesting_changed.cancel() - msg = ChatMessage.create(text=collected_text, role="assistant") - self._chat_ctx.messages.append(msg) + while speech_handle.nested_speech_handles: + speech = speech_handle.nested_speech_handles[0] + if speech_handle.nested_speech_done: + # in case tool speech is added after nested speech done + speech.cancel(cancel_nested=True) + speech_handle.nested_speech_handles.pop(0) + continue - speech_handle.mark_speech_committed() + self._playing_speech = speech + await self._play_speech(speech) + speech_handle.nested_speech_handles.pop(0) + self._playing_speech = speech_handle - if interrupted: - self.emit("agent_speech_interrupted", msg) - else: - self.emit("agent_speech_committed", msg) + speech_handle.nested_speech_changed.clear() + # break if the function calls task is done + if fnc_task.done(): + speech_handle.mark_nested_speech_done() + if not fnc_task.done(): logger.debug( - "committed agent speech", - extra={ - "agent_transcript": collected_text, - "interrupted": interrupted, - "speech_id": speech_handle.id, - }, + "cancelling function calls task", extra={"speech_id": speech_handle.id} ) + fnc_task.cancel() + + # mark the speech as done + speech_handle._set_done() def _synthesize_agent_speech( self, @@ -926,7 +1107,7 @@ async def _llm_stream_to_str_generator( def _validate_reply_if_possible(self) -> None: """Check if the new agent speech should be played""" - if self._playing_speech is not None: + if self._playing_speech and not self._playing_speech.interrupted: should_ignore_input = False if not self._playing_speech.allow_interruptions: should_ignore_input = True @@ -940,19 +1121,24 @@ def _validate_reply_if_possible(self) -> None: "interrupt threshold is not met", extra={"speech_id": self._playing_speech.id}, ) + if should_ignore_input: self._transcribed_text = "" return if self._pending_agent_reply is None: - if self._opts.preemptive_synthesis or not self._transcribed_text: + if self._opts.preemptive_synthesis: return + # as long as we don't have a pending reply, we need to synthesize it + # in order to keep the conversation flowing. + # transcript could be empty at this moment, if the user interrupted the agent + # but did not generate any transcribed text. 
self._synthesize_agent_reply() assert self._pending_agent_reply is not None - # in some bad timing, we could end up with two pushed agent replies inside the speech queue. + # due to timing, we could end up with two pushed agent replies inside the speech queue. # so make sure we directly interrupt every reply when validating a new one for speech in self._speech_q: if not speech.is_reply: @@ -963,7 +1149,10 @@ def _validate_reply_if_possible(self) -> None: logger.debug( "validated agent reply", - extra={"speech_id": self._pending_agent_reply.id}, + extra={ + "speech_id": self._pending_agent_reply.id, + "text": self._transcribed_text, + }, ) if self._last_speech_time is not None: @@ -992,7 +1181,7 @@ def _interrupt_if_possible(self) -> None: def _should_interrupt(self) -> bool: if self._playing_speech is None: - return True + return False if ( not self._playing_speech.allow_interruptions @@ -1020,62 +1209,91 @@ class _DeferredReplyValidation: PUNCTUATION = ".!?" PUNCTUATION_REDUCE_FACTOR = 0.75 - LATE_TRANSCRIPT_TOLERANCE = 1.5 # late compared to end of speech + FINAL_TRANSCRIPT_TIMEOUT = 5 def __init__( self, validate_fnc: Callable[[], None], min_endpointing_delay: float, - loop: asyncio.AbstractEventLoop | None = None, + max_endpointing_delay: float, + turn_detector: _TurnDetector | None, + agent: VoicePipelineAgent, ) -> None: + self._turn_detector = turn_detector self._validate_fnc = validate_fnc self._validating_task: asyncio.Task | None = None self._last_final_transcript: str = "" + self._last_language: str | None = None + self._last_recv_start_of_speech_time: float = 0.0 self._last_recv_end_of_speech_time: float = 0.0 + self._last_recv_transcript_time: float = 0.0 self._speaking = False + self._agent = agent self._end_of_speech_delay = min_endpointing_delay - self._final_transcript_delay = min_endpointing_delay + 1.0 + self._max_endpointing_delay = max_endpointing_delay @property def validating(self) -> bool: return self._validating_task is not None and not self._validating_task.done() - def on_human_final_transcript(self, transcript: str) -> None: - self._last_final_transcript = transcript.strip() # type: ignore + def _compute_delay(self) -> float | None: + """Computes the amount of time to wait before validating the agent reply. 
+ This function should be called after the agent has received final transcript, or after VAD + """ + # never interrupt the user while they are speaking if self._speaking: - return + return None - has_recent_end_of_speech = ( - time.time() - self._last_recv_end_of_speech_time - < self.LATE_TRANSCRIPT_TOLERANCE - ) - delay = ( - self._end_of_speech_delay - if has_recent_end_of_speech - else self._final_transcript_delay - ) - delay = delay * ( - self.PUNCTUATION_REDUCE_FACTOR if self._end_with_punctuation() else 1.0 - ) + # if STT doesn't give us the final transcript after end of speech, we'll still validate the reply + # to prevent the agent from getting "stuck" + # in this case, the agent will not have final transcript, so it'll trigger the user input with empty + if not self._last_final_transcript: + return self.FINAL_TRANSCRIPT_TIMEOUT + + delay = self._end_of_speech_delay + if self._end_with_punctuation(): + delay = delay * self.PUNCTUATION_REDUCE_FACTOR + + # the delay should be computed from end of earlier timestamp, that's the true end of user speech + end_of_speech_time = self._last_recv_end_of_speech_time + if ( + self._last_recv_transcript_time > 0 + and self._last_recv_transcript_time > self._last_recv_start_of_speech_time + and self._last_recv_transcript_time < end_of_speech_time + ): + end_of_speech_time = self._last_recv_transcript_time + + elapsed_time = time.perf_counter() - end_of_speech_time + if elapsed_time < delay: + delay -= elapsed_time + else: + delay = 0 + return delay - self._run(delay) + def on_human_final_transcript(self, transcript: str, language: str | None) -> None: + self._last_final_transcript += " " + transcript.strip() # type: ignore + self._last_language = language + self._last_recv_transcript_time = time.perf_counter() + + delay = self._compute_delay() + if delay is not None: + self._run(delay) def on_human_start_of_speech(self, ev: VADEvent) -> None: self._speaking = True + self._last_recv_start_of_speech_time = time.perf_counter() if self.validating: assert self._validating_task is not None self._validating_task.cancel() def on_human_end_of_speech(self, ev: VADEvent) -> None: self._speaking = False - self._last_recv_end_of_speech_time = time.time() + self._last_recv_end_of_speech_time = time.perf_counter() - if self._last_final_transcript: - delay = self._end_of_speech_delay * ( - self.PUNCTUATION_REDUCE_FACTOR if self._end_with_punctuation() else 1.0 - ) + delay = self._compute_delay() + if delay is not None: self._run(delay) async def aclose(self) -> None: @@ -1091,15 +1309,34 @@ def _end_with_punctuation(self) -> bool: def _reset_states(self) -> None: self._last_final_transcript = "" self._last_recv_end_of_speech_time = 0.0 + self._last_recv_transcript_time = 0.0 def _run(self, delay: float) -> None: @utils.log_exceptions(logger=logger) - async def _run_task(delay: float) -> None: + async def _run_task(chat_ctx: ChatContext, delay: float) -> None: + use_turn_detector = self._last_final_transcript and not self._speaking + if ( + use_turn_detector + and self._turn_detector is not None + and self._turn_detector.supports_language(self._last_language) + ): + start_time = time.perf_counter() + eot_prob = await self._turn_detector.predict_end_of_turn(chat_ctx) + unlikely_threshold = self._turn_detector.unlikely_threshold() + elasped = time.perf_counter() - start_time + if eot_prob < unlikely_threshold: + delay = self._max_endpointing_delay + delay = max(0, delay - elasped) await asyncio.sleep(delay) + self._reset_states() self._validate_fnc() if 
self._validating_task is not None: self._validating_task.cancel() - self._validating_task = asyncio.create_task(_run_task(delay)) + detect_ctx = self._agent._chat_ctx.copy() + detect_ctx.messages.append( + ChatMessage.create(text=self._agent._transcribed_text, role="user") + ) + self._validating_task = asyncio.create_task(_run_task(detect_ctx, delay)) diff --git a/livekit-agents/livekit/agents/pipeline/speech_handle.py b/livekit-agents/livekit/agents/pipeline/speech_handle.py index a0f0c7d93..cd1f39dec 100644 --- a/livekit-agents/livekit/agents/pipeline/speech_handle.py +++ b/livekit-agents/livekit/agents/pipeline/speech_handle.py @@ -4,7 +4,7 @@ from typing import AsyncIterable from .. import utils -from ..llm import LLMStream +from ..llm import ChatMessage, LLMStream from .agent_output import SynthesisHandle @@ -17,6 +17,9 @@ def __init__( add_to_chat_ctx: bool, is_reply: bool, user_question: str, + fnc_nested_depth: int = 0, + extra_tools_messages: list[ChatMessage] | None = None, + fnc_text_message_id: str | None = None, ) -> None: self._id = id self._allow_interruptions = allow_interruptions @@ -27,7 +30,8 @@ def __init__( self._user_question = user_question self._user_committed = False - self._init_fut: asyncio.Future[None] = asyncio.Future() + self._init_fut = asyncio.Future[None]() + self._done_fut = asyncio.Future[None]() self._initialized = False self._speech_committed = False # speech committed (interrupted or not) @@ -35,6 +39,15 @@ def __init__( self._source: str | LLMStream | AsyncIterable[str] | None = None self._synthesis_handle: SynthesisHandle | None = None + # nested speech handle and function calls + self._fnc_nested_depth = fnc_nested_depth + self._fnc_extra_tools_messages: list[ChatMessage] | None = extra_tools_messages + self._fnc_text_message_id: str | None = fnc_text_message_id + + self._nested_speech_handles: list[SpeechHandle] = [] + self._nested_speech_changed = asyncio.Event() + self._nested_speech_done_fut = asyncio.Future[None]() + @staticmethod def create_assistant_reply( *, @@ -64,6 +77,26 @@ def create_assistant_speech( user_question="", ) + @staticmethod + def create_tool_speech( + *, + allow_interruptions: bool, + add_to_chat_ctx: bool, + fnc_nested_depth: int, + extra_tools_messages: list[ChatMessage], + fnc_text_message_id: str | None = None, + ) -> SpeechHandle: + return SpeechHandle( + id=utils.shortuuid(), + allow_interruptions=allow_interruptions, + add_to_chat_ctx=add_to_chat_ctx, + is_reply=False, + user_question="", + fnc_nested_depth=fnc_nested_depth, + extra_tools_messages=extra_tools_messages, + fnc_text_message_id=fnc_text_message_id, + ) + async def wait_for_initialization(self) -> None: await asyncio.shield(self._init_fut) @@ -146,13 +179,57 @@ def interrupted(self) -> bool: self._synthesis_handle is not None and self._synthesis_handle.interrupted ) + def join(self) -> asyncio.Future: + return self._done_fut + + def _set_done(self) -> None: + self._done_fut.set_result(None) + def interrupt(self) -> None: if not self.allow_interruptions: raise RuntimeError("interruptions are not allowed") self.cancel() - def cancel(self) -> None: + def cancel(self, cancel_nested: bool = False) -> None: self._init_fut.cancel() if self._synthesis_handle is not None: self._synthesis_handle.interrupt() + + if cancel_nested: + for speech in self._nested_speech_handles: + speech.cancel(cancel_nested=True) + self.mark_nested_speech_done() + + @property + def fnc_nested_depth(self) -> int: + return self._fnc_nested_depth + + @property + def 
extra_tools_messages(self) -> list[ChatMessage] | None: + return self._fnc_extra_tools_messages + + @property + def fnc_text_message_id(self) -> str | None: + return self._fnc_text_message_id + + def add_nested_speech(self, speech_handle: SpeechHandle) -> None: + self._nested_speech_handles.append(speech_handle) + self._nested_speech_changed.set() + + @property + def nested_speech_handles(self) -> list[SpeechHandle]: + return self._nested_speech_handles + + @property + def nested_speech_changed(self) -> asyncio.Event: + return self._nested_speech_changed + + @property + def nested_speech_done(self) -> bool: + return self._nested_speech_done_fut.done() + + def mark_nested_speech_done(self) -> None: + if self._nested_speech_done_fut.done(): + return + self._nested_speech_done_fut.set_result(None) diff --git a/livekit-agents/livekit/agents/plugin.py b/livekit-agents/livekit/agents/plugin.py index 3554fc337..5aca08a93 100644 --- a/livekit-agents/livekit/agents/plugin.py +++ b/livekit-agents/livekit/agents/plugin.py @@ -13,7 +13,6 @@ class Plugin(ABC): registered_plugins: List["Plugin"] = [] emitter: utils.EventEmitter[EventTypes] = utils.EventEmitter() - lock = threading.Lock() # TODO(theomonnom): make logger mandatory once all plugins have been updated def __init__( diff --git a/livekit-agents/livekit/agents/stt/__init__.py b/livekit-agents/livekit/agents/stt/__init__.py index 3b1fb146c..fc8f99044 100644 --- a/livekit-agents/livekit/agents/stt/__init__.py +++ b/livekit-agents/livekit/agents/stt/__init__.py @@ -1,7 +1,9 @@ +from .fallback_adapter import AvailabilityChangedEvent, FallbackAdapter from .stream_adapter import StreamAdapter, StreamAdapterWrapper from .stt import ( STT, RecognitionUsage, + RecognizeStream, SpeechData, SpeechEvent, SpeechEventType, @@ -13,10 +15,13 @@ "SpeechEventType", "SpeechEvent", "SpeechData", + "RecognizeStream", "SpeechStream", "STT", "STTCapabilities", "StreamAdapter", "StreamAdapterWrapper", "RecognitionUsage", + "FallbackAdapter", + "AvailabilityChangedEvent", ] diff --git a/livekit-agents/livekit/agents/stt/fallback_adapter.py b/livekit-agents/livekit/agents/stt/fallback_adapter.py new file mode 100644 index 000000000..ac11a76db --- /dev/null +++ b/livekit-agents/livekit/agents/stt/fallback_adapter.py @@ -0,0 +1,416 @@ +from __future__ import annotations + +import asyncio +import contextlib +import dataclasses +import time +from dataclasses import dataclass +from typing import Literal + +from livekit import rtc +from livekit.agents.utils.audio import AudioBuffer + +from .. 
import utils +from .._exceptions import APIConnectionError, APIError +from ..log import logger +from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions +from ..utils import aio +from .stt import STT, RecognizeStream, SpeechEvent, SpeechEventType, STTCapabilities + +# don't retry when using the fallback adapter +DEFAULT_FALLBACK_API_CONNECT_OPTIONS = APIConnectOptions( + max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout +) + + +@dataclass +class AvailabilityChangedEvent: + stt: STT + available: bool + + +@dataclass +class _STTStatus: + available: bool + recovering_synthesize_task: asyncio.Task | None + recovering_stream_task: asyncio.Task | None + + +class FallbackAdapter( + STT[Literal["stt_availability_changed"]], +): + def __init__( + self, + stt: list[STT], + *, + attempt_timeout: float = 10.0, + max_retry_per_stt: int = 1, + retry_interval: float = 5, + ) -> None: + if len(stt) < 1: + raise ValueError("At least one STT instance must be provided.") + + super().__init__( + capabilities=STTCapabilities( + streaming=all(t.capabilities.streaming for t in stt), + interim_results=all(t.capabilities.interim_results for t in stt), + ) + ) + + self._stt_instances = stt + self._attempt_timeout = attempt_timeout + self._max_retry_per_stt = max_retry_per_stt + self._retry_interval = retry_interval + + self._status: list[_STTStatus] = [ + _STTStatus( + available=True, + recovering_synthesize_task=None, + recovering_stream_task=None, + ) + for _ in self._stt_instances + ] + + async def _try_recognize( + self, + *, + stt: STT, + buffer: utils.AudioBuffer, + language: str | None = None, + conn_options: APIConnectOptions, + recovering: bool = False, + ) -> SpeechEvent: + try: + return await stt.recognize( + buffer, + language=language, + conn_options=dataclasses.replace( + conn_options, + max_retry=self._max_retry_per_stt, + timeout=self._attempt_timeout, + retry_interval=self._retry_interval, + ), + ) + except asyncio.TimeoutError: + if recovering: + logger.warning( + f"{stt.label} recovery timed out", extra={"streamed": False} + ) + raise + + logger.warning( + f"{stt.label} timed out, switching to next STT", + extra={"streamed": False}, + ) + + raise + except APIError as e: + if recovering: + logger.warning( + f"{stt.label} recovery failed", + exc_info=e, + extra={"streamed": False}, + ) + raise + + logger.warning( + f"{stt.label} failed, switching to next STT", + exc_info=e, + extra={"streamed": False}, + ) + raise + except Exception: + if recovering: + logger.exception( + f"{stt.label} recovery unexpected error", extra={"streamed": False} + ) + raise + + logger.exception( + f"{stt.label} unexpected error, switching to next STT", + extra={"streamed": False}, + ) + raise + + def _try_recovery( + self, + *, + stt: STT, + buffer: utils.AudioBuffer, + language: str | None, + conn_options: APIConnectOptions, + ) -> None: + stt_status = self._status[self._stt_instances.index(stt)] + if ( + stt_status.recovering_synthesize_task is None + or stt_status.recovering_synthesize_task.done() + ): + + async def _recover_stt_task(stt: STT) -> None: + try: + await self._try_recognize( + stt=stt, + buffer=buffer, + language=language, + conn_options=conn_options, + recovering=True, + ) + + stt_status.available = True + logger.info(f"{stt.label} recovered") + self.emit( + "stt_availability_changed", + AvailabilityChangedEvent(stt=stt, available=True), + ) + except Exception: + return + + stt_status.recovering_synthesize_task = asyncio.create_task( + _recover_stt_task(stt) + ) + + async def 
_recognize_impl( + self, + buffer: utils.AudioBuffer, + *, + language: str | None, + conn_options: APIConnectOptions, + ): + start_time = time.time() + + all_failed = all(not stt_status.available for stt_status in self._status) + if all_failed: + logger.error("all STTs are unavailable, retrying..") + + for i, stt in enumerate(self._stt_instances): + stt_status = self._status[i] + if stt_status.available or all_failed: + try: + return await self._try_recognize( + stt=stt, + buffer=buffer, + language=language, + conn_options=conn_options, + recovering=False, + ) + except Exception: # exceptions already logged inside _try_recognize + if stt_status.available: + stt_status.available = False + self.emit( + "stt_availability_changed", + AvailabilityChangedEvent(stt=stt, available=False), + ) + + self._try_recovery( + stt=stt, buffer=buffer, language=language, conn_options=conn_options + ) + + raise APIConnectionError( + "all STTs failed (%s) after %s seconds" + % ( + [stt.label for stt in self._stt_instances], + time.time() - start_time, + ) + ) + + async def recognize( + self, + buffer: AudioBuffer, + *, + language: str | None = None, + conn_options: APIConnectOptions = DEFAULT_FALLBACK_API_CONNECT_OPTIONS, + ) -> SpeechEvent: + return await super().recognize( + buffer, language=language, conn_options=conn_options + ) + + def stream( + self, + *, + language: str | None = None, + conn_options: APIConnectOptions = DEFAULT_FALLBACK_API_CONNECT_OPTIONS, + ) -> RecognizeStream: + return FallbackRecognizeStream( + stt=self, language=language, conn_options=conn_options + ) + + async def aclose(self) -> None: + for stt_status in self._status: + if stt_status.recovering_synthesize_task is not None: + await aio.gracefully_cancel(stt_status.recovering_synthesize_task) + + if stt_status.recovering_stream_task is not None: + await aio.gracefully_cancel(stt_status.recovering_stream_task) + + +class FallbackRecognizeStream(RecognizeStream): + def __init__( + self, + *, + stt: FallbackAdapter, + language: str | None, + conn_options: APIConnectOptions, + ): + super().__init__(stt=stt, conn_options=conn_options, sample_rate=None) + self._language = language + self._fallback_adapter = stt + self._recovering_streams: list[RecognizeStream] = [] + + async def _run(self) -> None: + start_time = time.time() + + all_failed = all( + not stt_status.available for stt_status in self._fallback_adapter._status + ) + if all_failed: + logger.error("all STTs are unavailable, retrying..") + + main_stream: RecognizeStream | None = None + forward_input_task: asyncio.Task | None = None + + async def _forward_input_task() -> None: + with contextlib.suppress(RuntimeError): # stream might be closed + async for data in self._input_ch: + for stream in self._recovering_streams: + if isinstance(data, rtc.AudioFrame): + stream.push_frame(data) + elif isinstance(data, self._FlushSentinel): + stream.flush() + + if main_stream is not None: + if isinstance(data, rtc.AudioFrame): + main_stream.push_frame(data) + elif isinstance(data, self._FlushSentinel): + main_stream.flush() + + if main_stream is not None: + main_stream.end_input() + + for i, stt in enumerate(self._fallback_adapter._stt_instances): + stt_status = self._fallback_adapter._status[i] + if stt_status.available or all_failed: + try: + main_stream = stt.stream( + language=self._language, + conn_options=dataclasses.replace( + self._conn_options, + max_retry=self._fallback_adapter._max_retry_per_stt, + timeout=self._fallback_adapter._attempt_timeout, + 
retry_interval=self._fallback_adapter._retry_interval, + ), + ) + + if forward_input_task is None or forward_input_task.done(): + forward_input_task = asyncio.create_task(_forward_input_task()) + + try: + async with main_stream: + async for ev in main_stream: + self._event_ch.send_nowait(ev) + + except asyncio.TimeoutError: + logger.warning( + f"{stt.label} timed out, switching to next STT", + extra={"streamed": True}, + ) + raise + except APIError as e: + logger.warning( + f"{stt.label} failed, switching to next STT", + exc_info=e, + extra={"streamed": True}, + ) + raise + except Exception: + logger.exception( + f"{stt.label} unexpected error, switching to next STT", + extra={"streamed": True}, + ) + raise + + return + except Exception: + if stt_status.available: + stt_status.available = False + self._stt.emit( + "stt_availability_changed", + AvailabilityChangedEvent(stt=stt, available=False), + ) + + self._try_recovery(stt) + + if forward_input_task is not None: + await aio.gracefully_cancel(forward_input_task) + + await asyncio.gather(*[stream.aclose() for stream in self._recovering_streams]) + + raise APIConnectionError( + "all STTs failed (%s) after %s seconds" + % ( + [stt.label for stt in self._fallback_adapter._stt_instances], + time.time() - start_time, + ) + ) + + def _try_recovery(self, stt: STT) -> None: + stt_status = self._fallback_adapter._status[ + self._fallback_adapter._stt_instances.index(stt) + ] + if ( + stt_status.recovering_stream_task is None + or stt_status.recovering_stream_task.done() + ): + stream = stt.stream( + language=self._language, + conn_options=dataclasses.replace( + self._conn_options, + max_retry=0, + timeout=self._fallback_adapter._attempt_timeout, + ), + ) + self._recovering_streams.append(stream) + + async def _recover_stt_task() -> None: + try: + nb_transcript = 0 + async with stream: + async for ev in stream: + if ev.type in SpeechEventType.FINAL_TRANSCRIPT: + if not ev.alternatives or not ev.alternatives[0].text: + continue + + nb_transcript += 1 + break + + if nb_transcript == 0: + return + + stt_status.available = True + logger.info(f"tts.FallbackAdapter, {stt.label} recovered") + self._fallback_adapter.emit( + "stt_availability_changed", + AvailabilityChangedEvent(stt=stt, available=True), + ) + + except asyncio.TimeoutError: + logger.warning( + f"{stream._stt.label} recovery timed out", + extra={"streamed": True}, + ) + except APIError as e: + logger.warning( + f"{stream._stt.label} recovery failed", + exc_info=e, + extra={"streamed": True}, + ) + except Exception: + logger.exception( + f"{stream._stt.label} recovery unexpected error", + extra={"streamed": True}, + ) + raise + + stt_status.recovering_stream_task = task = asyncio.create_task( + _recover_stt_task() + ) + task.add_done_callback(lambda _: self._recovering_streams.remove(stream)) diff --git a/livekit-agents/livekit/agents/stt/stream_adapter.py b/livekit-agents/livekit/agents/stt/stream_adapter.py index 39745d640..0e69d65c5 100644 --- a/livekit-agents/livekit/agents/stt/stream_adapter.py +++ b/livekit-agents/livekit/agents/stt/stream_adapter.py @@ -4,9 +4,9 @@ from typing import AsyncIterable from .. 
import utils -from ..log import logger +from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions from ..vad import VAD, VADEventType -from .stt import STT, SpeechEvent, SpeechEventType, SpeechStream, STTCapabilities +from .stt import STT, RecognizeStream, SpeechEvent, SpeechEventType, STTCapabilities class StreamAdapter(STT): @@ -26,21 +26,42 @@ def wrapped_stt(self) -> STT: return self._stt async def _recognize_impl( - self, buffer: utils.AudioBuffer, *, language: str | None = None + self, + buffer: utils.AudioBuffer, + *, + language: str | None, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, ): - return await self._stt.recognize(buffer=buffer, language=language) + return await self._stt.recognize( + buffer=buffer, language=language, conn_options=conn_options + ) - def stream(self, *, language: str | None = None) -> SpeechStream: + def stream( + self, + *, + language: str | None = None, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + ) -> RecognizeStream: return StreamAdapterWrapper( - self, vad=self._vad, wrapped_stt=self._stt, language=language + self, + vad=self._vad, + wrapped_stt=self._stt, + language=language, + conn_options=conn_options, ) -class StreamAdapterWrapper(SpeechStream): +class StreamAdapterWrapper(RecognizeStream): def __init__( - self, stt: STT, *, vad: VAD, wrapped_stt: STT, language: str | None + self, + stt: STT, + *, + vad: VAD, + wrapped_stt: STT, + language: str | None, + conn_options: APIConnectOptions, ) -> None: - super().__init__(stt) + super().__init__(stt=stt, conn_options=conn_options) self._vad = vad self._wrapped_stt = wrapped_stt self._vad_stream = self._vad.stream() @@ -51,8 +72,7 @@ async def _metrics_monitor_task( ) -> None: pass # do nothing - @utils.log_exceptions(logger=logger) - async def _main_task(self) -> None: + async def _run(self) -> None: async def _forward_input(): """forward input to vad""" async for input in self._input_ch: @@ -79,7 +99,9 @@ async def _recognize(): merged_frames = utils.merge_frames(event.frames) t_event = await self._wrapped_stt.recognize( - buffer=merged_frames, language=self._language + buffer=merged_frames, + language=self._language, + conn_options=self._conn_options, ) if len(t_event.alternatives) == 0: diff --git a/livekit-agents/livekit/agents/stt/stt.py b/livekit-agents/livekit/agents/stt/stt.py index 399089d1c..e2f79f93c 100644 --- a/livekit-agents/livekit/agents/stt/stt.py +++ b/livekit-agents/livekit/agents/stt/stt.py @@ -6,11 +6,14 @@ from dataclasses import dataclass, field from enum import Enum, unique from types import TracebackType -from typing import AsyncIterable, AsyncIterator, List, Literal, Union +from typing import AsyncIterable, AsyncIterator, Generic, List, Literal, TypeVar, Union from livekit import rtc +from .._exceptions import APIConnectionError, APIError +from ..log import logger from ..metrics import STTMetrics +from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions from ..utils import AudioBuffer, aio from ..utils.audio import calculate_audio_duration @@ -59,40 +62,90 @@ class STTCapabilities: interim_results: bool -class STT(ABC, rtc.EventEmitter[Literal["metrics_collected"]]): +TEvent = TypeVar("TEvent") + + +class STT( + ABC, + rtc.EventEmitter[Union[Literal["metrics_collected"], TEvent]], + Generic[TEvent], +): def __init__(self, *, capabilities: STTCapabilities) -> None: super().__init__() self._capabilities = capabilities self._label = f"{type(self).__module__}.{type(self).__name__}" + @property + def label(self) -> str: + 
return self._label + @property def capabilities(self) -> STTCapabilities: return self._capabilities @abstractmethod async def _recognize_impl( - self, buffer: AudioBuffer, *, language: str | None = None + self, + buffer: AudioBuffer, + *, + language: str | None, + conn_options: APIConnectOptions, ) -> SpeechEvent: ... async def recognize( - self, buffer: AudioBuffer, *, language: str | None = None + self, + buffer: AudioBuffer, + *, + language: str | None = None, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, ) -> SpeechEvent: - start_time = time.perf_counter() - event = await self._recognize_impl(buffer, language=language) - duration = time.perf_counter() - start_time - stt_metrics = STTMetrics( - request_id=event.request_id, - timestamp=time.time(), - duration=duration, - label=self._label, - audio_duration=calculate_audio_duration(buffer), - streamed=False, - error=None, - ) - self.emit("metrics_collected", stt_metrics) - return event - - def stream(self, *, language: str | None = None) -> "SpeechStream": + for i in range(conn_options.max_retry + 1): + try: + start_time = time.perf_counter() + event = await self._recognize_impl( + buffer, language=language, conn_options=conn_options + ) + duration = time.perf_counter() - start_time + stt_metrics = STTMetrics( + request_id=event.request_id, + timestamp=time.time(), + duration=duration, + label=self._label, + audio_duration=calculate_audio_duration(buffer), + streamed=False, + error=None, + ) + self.emit("metrics_collected", stt_metrics) + return event + + except APIError as e: + if conn_options.max_retry == 0: + raise + elif i == conn_options.max_retry: + raise APIConnectionError( + f"failed to recognize speech after {conn_options.max_retry + 1} attempts", + ) from e + else: + logger.warning( + f"failed to recognize speech, retrying in {conn_options.retry_interval}s", + exc_info=e, + extra={ + "tts": self._label, + "attempt": i + 1, + "streamed": False, + }, + ) + + await asyncio.sleep(conn_options.retry_interval) + + raise RuntimeError("unreachable") + + def stream( + self, + *, + language: str | None = None, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + ) -> "RecognizeStream": raise NotImplementedError( "streaming is not supported by this STT, please use a different STT or use a StreamAdapter" ) @@ -113,13 +166,19 @@ async def __aexit__( await self.aclose() -class SpeechStream(ABC): +class RecognizeStream(ABC): class _FlushSentinel: """Sentinel to mark when it was flushed""" pass - def __init__(self, stt: STT, *, sample_rate: int | None = None): + def __init__( + self, + *, + stt: STT, + conn_options: APIConnectOptions, + sample_rate: int | None = None, + ): """ Args: sample_rate : int or None, optional @@ -129,7 +188,10 @@ def __init__(self, stt: STT, *, sample_rate: int | None = None): If not provided (None), the input will retain its original sample rate. """ self._stt = stt - self._input_ch = aio.Chan[Union[rtc.AudioFrame, SpeechStream._FlushSentinel]]() + self._conn_options = conn_options + self._input_ch = aio.Chan[ + Union[rtc.AudioFrame, RecognizeStream._FlushSentinel] + ]() self._event_ch = aio.Chan[SpeechEvent]() self._event_aiter, monitor_aiter = aio.itertools.tee(self._event_ch, 2) @@ -145,7 +207,31 @@ def __init__(self, stt: STT, *, sample_rate: int | None = None): self._resampler: rtc.AudioResampler | None = None @abstractmethod - async def _main_task(self) -> None: ... + async def _run(self) -> None: ... 
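[Editor's note] A hedged sketch of the retry behaviour introduced above: `STT.recognize()` now retries on `APIError` according to `APIConnectOptions` before raising `APIConnectionError`. The `APIConnectOptions` fields are taken from this diff; the Deepgram plugin choice and the way the audio buffer is obtained are assumptions for illustration.

```python
from livekit.agents.types import APIConnectOptions
from livekit.plugins import deepgram  # illustrative plugin choice


async def transcribe_once(buffer) -> str:
    stt_impl = deepgram.STT()
    # retries up to max_retry times on APIError, sleeping retry_interval
    # between attempts, then raises APIConnectionError
    event = await stt_impl.recognize(
        buffer,
        language="en",
        conn_options=APIConnectOptions(max_retry=2, retry_interval=1.0, timeout=15.0),
    )
    return event.alternatives[0].text
```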
+ + async def _main_task(self) -> None: + for i in range(self._conn_options.max_retry + 1): + try: + return await self._run() + except APIError as e: + if self._conn_options.max_retry == 0: + raise + elif i == self._conn_options.max_retry: + raise APIConnectionError( + f"failed to recognize speech after {self._conn_options.max_retry + 1} attempts", + ) from e + else: + logger.warning( + f"failed to recognize speech, retrying in {self._conn_options.retry_interval}s", + exc_info=e, + extra={ + "tts": self._stt._label, + "attempt": i + 1, + "streamed": True, + }, + ) + + await asyncio.sleep(self._conn_options.retry_interval) async def _metrics_monitor_task( self, event_aiter: AsyncIterable[SpeechEvent] @@ -209,7 +295,7 @@ def flush(self) -> None: self._input_ch.send_nowait(self._FlushSentinel()) def end_input(self) -> None: - """Mark the end of input, no more text will be pushed""" + """Mark the end of input, no more audio will be pushed""" self.flush() self._input_ch.close() @@ -244,3 +330,17 @@ def _check_input_not_ended(self) -> None: if self._input_ch.closed: cls = type(self) raise RuntimeError(f"{cls.__module__}.{cls.__name__} input ended") + + async def __aenter__(self) -> RecognizeStream: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.aclose() + + +SpeechStream = RecognizeStream # deprecated alias diff --git a/livekit-agents/livekit/agents/transcription/_utils.py b/livekit-agents/livekit/agents/transcription/_utils.py index dc839f2e6..4e24960dd 100644 --- a/livekit-agents/livekit/agents/transcription/_utils.py +++ b/livekit-agents/livekit/agents/transcription/_utils.py @@ -1,9 +1,9 @@ from __future__ import annotations -import uuid - from livekit import rtc +from ..utils import shortuuid + def find_micro_track_id(room: rtc.Room, identity: str) -> str: p: rtc.RemoteParticipant | rtc.LocalParticipant | None = ( @@ -29,4 +29,4 @@ def find_micro_track_id(room: rtc.Room, identity: str) -> str: def segment_uuid() -> str: - return "SG_" + str(uuid.uuid4().hex)[:12] + return shortuuid("SG_") diff --git a/livekit-agents/livekit/agents/tts/fallback_adapter.py b/livekit-agents/livekit/agents/tts/fallback_adapter.py index 0bcbf7df8..d990d5934 100644 --- a/livekit-agents/livekit/agents/tts/fallback_adapter.py +++ b/livekit-agents/livekit/agents/tts/fallback_adapter.py @@ -55,7 +55,7 @@ def __init__( *, attempt_timeout: float = 10.0, max_retry_per_tts: int = 1, # only retry once by default - retry_interval: float = 0.5, + retry_interval: float = 5, no_fallback_after_audio_duration: float | None = 3.0, sample_rate: int | None = None, ) -> None: @@ -67,6 +67,8 @@ def __init__( attempt_timeout (float, optional): Timeout for each synthesis attempt in seconds. Defaults to 10.0. max_retry_per_tts (int, optional): Maximum number of retries per TTS instance. Defaults to 1. no_fallback_after_audio_duration (float | None, optional): Disables fallback after this duration of audio is synthesized. Defaults to 3.0. + This is used to prevent unnaturally resaying the same text when the first TTS + instance fails. sample_rate (int | None, optional): Desired sample rate for the synthesized audio. If None, uses the maximum sample rate among the TTS instances. 
Raises: @@ -75,7 +77,7 @@ def __init__( """ if len(tts) < 1: - raise ValueError("At least one TTS instance must be provided.") + raise ValueError("at least one TTS instance must be provided.") if len(set(t.num_channels for t in tts)) != 1: raise ValueError("all TTS must have the same number of channels") @@ -93,7 +95,7 @@ def __init__( num_channels=num_channels, ) - self._wrapped_tts = tts + self._tts_instances = tts self._attempt_timeout = attempt_timeout self._max_retry_per_tts = max_retry_per_tts self._retry_interval = retry_interval @@ -144,31 +146,30 @@ async def aclose(self) -> None: class FallbackChunkedStream(ChunkedStream): def __init__( - self, *, tts: TTS, input_text: str, conn_options: APIConnectOptions + self, *, tts: FallbackAdapter, input_text: str, conn_options: APIConnectOptions ) -> None: super().__init__(tts=tts, input_text=input_text, conn_options=conn_options) + self._fallback_adapter = tts async def _try_synthesize( self, *, tts: TTS, recovering: bool = False ) -> AsyncGenerator[SynthesizedAudio, None]: - assert isinstance(self._tts, FallbackAdapter) - try: audio_duration = 0.0 async with tts.synthesize( self._input_text, conn_options=dataclasses.replace( self._conn_options, - max_retry=self._tts._max_retry_per_tts, - timeout=self._tts._attempt_timeout, - retry_interval=self._tts._retry_interval, + max_retry=self._fallback_adapter._max_retry_per_tts, + timeout=self._fallback_adapter._attempt_timeout, + retry_interval=self._fallback_adapter._retry_interval, ), ) as stream: while True: try: audio = await asyncio.wait_for( stream.__anext__(), - self._tts._attempt_timeout + self._fallback_adapter._attempt_timeout if audio_duration == 0.0 else None, ) @@ -225,7 +226,7 @@ async def _try_synthesize( def _try_recovery(self, tts: TTS) -> None: assert isinstance(self._tts, FallbackAdapter) - tts_status = self._tts._status[self._tts._wrapped_tts.index(tts)] + tts_status = self._tts._status[self._tts._tts_instances.index(tts)] if tts_status.recovering_task is None or tts_status.recovering_task.done(): async def _recover_tts_task(tts: TTS) -> None: @@ -253,7 +254,7 @@ async def _run(self) -> None: if all_failed: logger.error("all TTSs are unavailable, retrying..") - for i, tts in enumerate(self._tts._wrapped_tts): + for i, tts in enumerate(self._tts._tts_instances): tts_status = self._tts._status[i] if tts_status.available or all_failed: audio_duration = 0.0 @@ -312,7 +313,7 @@ async def _run(self) -> None: raise APIConnectionError( "all TTSs failed (%s) after %s seconds" % ( - [tts.label for tts in self._tts._wrapped_tts], + [tts.label for tts in self._tts._tts_instances], time.time() - start_time, ) ) @@ -322,58 +323,55 @@ class FallbackSynthesizeStream(SynthesizeStream): def __init__( self, *, - tts: TTS, + tts: FallbackAdapter, conn_options: APIConnectOptions, ): super().__init__(tts=tts, conn_options=conn_options) + self._fallback_adapter = tts self._total_segments: list[list[str]] = [] - self._fallback_pending_texts: list[list[str]] = [] - self._fallback_text: list[str] = [] + self._pending_segments_chunks: list[list[str]] = [] + self._current_segment_text: list[str] = [] async def _try_synthesize( self, + *, tts: TTS, input_ch: aio.ChanReceiver[str | SynthesizeStream._FlushSentinel], + conn_options: APIConnectOptions, recovering: bool = False, ) -> AsyncGenerator[SynthesizedAudio, None]: - assert isinstance(self._tts, FallbackAdapter) - - stream = tts.stream( - conn_options=dataclasses.replace( - self._conn_options, - max_retry=self._tts._max_retry_per_tts, - 
timeout=self._tts._attempt_timeout, - retry_interval=self._tts._retry_interval, - ) - ) - + stream = tts.stream(conn_options=conn_options) input_sent_fut = asyncio.Future() # type: ignore + @utils.log_exceptions(logger=logger) async def _input_task() -> None: try: + segment = "" async for data in input_ch: if isinstance(data, str): - if data: - with contextlib.suppress(asyncio.InvalidStateError): - input_sent_fut.set_result(None) - + segment += data stream.push_text(data) elif isinstance(data, self._FlushSentinel): + # start the timeout on flush + if segment: + segment = "" + with contextlib.suppress(asyncio.InvalidStateError): + input_sent_fut.set_result(True) + stream.flush() finally: with contextlib.suppress(RuntimeError): stream.end_input() with contextlib.suppress(asyncio.InvalidStateError): - input_sent_fut.set_result(None) + input_sent_fut.set_result(False) input_task = asyncio.create_task(_input_task()) - next_audio_task: asyncio.Future | None = None + next_audio_task: asyncio.Future[SynthesizedAudio] | None = None try: audio_duration = 0.0 - async with stream: while True: if next_audio_task is None or next_audio_task.done(): @@ -392,14 +390,25 @@ async def _input_task() -> None: audio = next_audio_task.result() else: audio = await asyncio.wait_for( - next_audio_task, self._tts._attempt_timeout + next_audio_task, self._fallback_adapter._attempt_timeout ) audio_duration += audio.frame.duration + if audio.is_final: + input_sent_fut = asyncio.Future() + audio_duration = 0.0 + yield audio except StopAsyncIteration: break + if ( + audio_duration == 0.0 + and input_sent_fut.done() + and input_sent_fut.result() + ): + raise APIConnectionError("no audio received") + except asyncio.TimeoutError: if recovering: logger.warning( @@ -445,11 +454,11 @@ async def _input_task() -> None: await utils.aio.gracefully_cancel(input_task) async def _run(self) -> None: - assert isinstance(self._tts, FallbackAdapter) - start_time = time.time() - all_failed = all(not tts_status.available for tts_status in self._tts._status) + all_failed = all( + not tts_status.available for tts_status in self._fallback_adapter._status + ) if all_failed: logger.error("all TTSs are unavailable, retrying..") @@ -463,12 +472,14 @@ async def _forward_input_task(): new_input_ch.send_nowait(data) if isinstance(data, str) and data: - self._fallback_text.append(data) + self._current_segment_text.append(data) - elif isinstance(data, self._FlushSentinel) and self._fallback_text: - self._total_segments.append(self._fallback_text) - self._fallback_pending_texts.append(self._fallback_text) - self._fallback_text = [] + elif ( + isinstance(data, self._FlushSentinel) and self._current_segment_text + ): + self._total_segments.append(self._current_segment_text) + self._pending_segments_chunks.append(self._current_segment_text) + self._current_segment_text = [] if new_input_ch: new_input_ch.close() @@ -476,8 +487,8 @@ async def _forward_input_task(): input_task = asyncio.create_task(_forward_input_task()) try: - for i, tts in enumerate(self._tts._wrapped_tts): - tts_status = self._tts._status[i] + for i, tts in enumerate(self._fallback_adapter._tts_instances): + tts_status = self._fallback_adapter._status[i] if tts_status.available or all_failed: audio_duration = 0.0 try: @@ -485,23 +496,31 @@ async def _forward_input_task(): Union[str, SynthesizeStream._FlushSentinel] ]() - for text in self._fallback_pending_texts: - for t in text: - new_input_ch.send_nowait(t) + for text in self._pending_segments_chunks: + for chunk in text: + 
new_input_ch.send_nowait(chunk) new_input_ch.send_nowait(self._FlushSentinel()) - for t in self._fallback_text: - new_input_ch.send_nowait(t) + for chunk in self._current_segment_text: + new_input_ch.send_nowait(chunk) - if self._input_ch.closed: + if input_task.done(): new_input_ch.close() last_segment_id: str | None = None resampler = tts_status.resampler async for synthesized_audio in self._try_synthesize( - tts=tts, input_ch=new_input_ch, recovering=False + tts=tts, + input_ch=new_input_ch, + conn_options=dataclasses.replace( + self._conn_options, + max_retry=self._fallback_adapter._max_retry_per_tts, + timeout=self._fallback_adapter._attempt_timeout, + retry_interval=self._fallback_adapter._retry_interval, + ), + recovering=False, ): audio_duration += synthesized_audio.frame.duration @@ -531,15 +550,14 @@ async def _forward_input_task(): last_segment_id is not None and synthesized_audio.segment_id != last_segment_id ) - ) and self._fallback_pending_texts: + ) and self._pending_segments_chunks: audio_duration = 0.0 + self._pending_segments_chunks.pop(0) last_segment_id = synthesized_audio.segment_id return - except ( - Exception - ): # exceptions already logged inside _try_synthesize + except Exception: if tts_status.available: tts_status.available = False self._tts.emit( @@ -547,43 +565,46 @@ async def _forward_input_task(): AvailabilityChangedEvent(tts=tts, available=False), ) - if self._tts._no_fallback_after_audio_duration is not None: + if ( + self._fallback_adapter._no_fallback_after_audio_duration + is not None + ): if ( audio_duration - >= self._tts._no_fallback_after_audio_duration - and self._fallback_pending_texts + >= self._fallback_adapter._no_fallback_after_audio_duration + and self._pending_segments_chunks ): logger.warning( f"{tts.label} already synthesized {audio_duration}s of audio, ignoring the current segment for the tts fallback" ) return - retry_segments: list[list[str]] = [self._fallback_text.copy()] - if self._total_segments: - retry_segments.insert(0, self._total_segments[-1]) - - self._try_recovery(tts, retry_segments) + self._try_recovery(tts) raise APIConnectionError( "all TTSs failed (%s) after %s seconds" % ( - [tts.label for tts in self._tts._wrapped_tts], + [tts.label for tts in self._fallback_adapter._tts_instances], time.time() - start_time, ) ) finally: await utils.aio.gracefully_cancel(input_task) - def _try_recovery(self, tts: TTS, segments: list[list[str]]) -> None: + def _try_recovery(self, tts: TTS) -> None: assert isinstance(self._tts, FallbackAdapter) - tts_status = self._tts._status[self._tts._wrapped_tts.index(tts)] + retry_segments = [self._current_segment_text.copy()] + if self._total_segments: + retry_segments.insert(0, self._total_segments[-1]) + + tts_status = self._tts._status[self._tts._tts_instances.index(tts)] if tts_status.recovering_task is None or tts_status.recovering_task.done(): async def _recover_tts_task(tts: TTS) -> None: try: input_ch = aio.Chan[Union[str, SynthesizeStream._FlushSentinel]]() - for segment in segments: + for segment in retry_segments: for t in segment: input_ch.send_nowait(t) @@ -592,7 +613,15 @@ async def _recover_tts_task(tts: TTS) -> None: input_ch.close() async for _ in self._try_synthesize( - tts=tts, input_ch=input_ch, recovering=True + tts=tts, + input_ch=input_ch, + recovering=True, + conn_options=dataclasses.replace( + self._conn_options, + max_retry=0, + timeout=self._fallback_adapter._attempt_timeout, + retry_interval=self._fallback_adapter._retry_interval, + ), ): pass diff --git 
a/livekit-agents/livekit/agents/tts/stream_adapter.py b/livekit-agents/livekit/agents/tts/stream_adapter.py index 5d5e84aca..fbb25df5d 100644 --- a/livekit-agents/livekit/agents/tts/stream_adapter.py +++ b/livekit-agents/livekit/agents/tts/stream_adapter.py @@ -77,18 +77,26 @@ async def _metrics_monitor_task( async def _run(self) -> None: async def _forward_input(): """forward input to vad""" - async for input in self._input_ch: - if isinstance(input, self._FlushSentinel): + async for data in self._input_ch: + if isinstance(data, self._FlushSentinel): self._sent_stream.flush() continue - self._sent_stream.push_text(input) + self._sent_stream.push_text(data) self._sent_stream.end_input() async def _synthesize(): async for ev in self._sent_stream: + last_audio: SynthesizedAudio | None = None async for audio in self._wrapped_tts.synthesize(ev.token): - self._event_ch.send_nowait(audio) + if last_audio is not None: + self._event_ch.send_nowait(last_audio) + + last_audio = audio + + if last_audio is not None: + last_audio.is_final = True + self._event_ch.send_nowait(last_audio) tasks = [ asyncio.create_task(_forward_input()), diff --git a/livekit-agents/livekit/agents/utils/codecs/__init__.py b/livekit-agents/livekit/agents/utils/codecs/__init__.py index 35f19332a..ad2f77b91 100644 --- a/livekit-agents/livekit/agents/utils/codecs/__init__.py +++ b/livekit-agents/livekit/agents/utils/codecs/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .decoder import AudioStreamDecoder, StreamBuffer from .mp3 import Mp3StreamDecoder -__all__ = ["Mp3StreamDecoder"] +__all__ = ["Mp3StreamDecoder", "AudioStreamDecoder", "StreamBuffer"] diff --git a/livekit-agents/livekit/agents/utils/codecs/decoder.py b/livekit-agents/livekit/agents/utils/codecs/decoder.py new file mode 100644 index 000000000..01367c055 --- /dev/null +++ b/livekit-agents/livekit/agents/utils/codecs/decoder.py @@ -0,0 +1,159 @@ +# Copyright 2024 LiveKit, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import io +from typing import AsyncIterator + +from livekit.agents.utils import aio + +try: + # preload to ensure faster startup + import av # noqa +except ImportError: + pass +import threading + +from livekit import rtc + + +class StreamBuffer: + """ + A thread-safe buffer that behaves like an IO stream. + Allows writing from one thread and reading from another. 
+ """ + + def __init__(self): + self._buffer = io.BytesIO() + self._lock = threading.Lock() + self._data_available = threading.Condition(self._lock) + self._eof = False # EOF flag to signal no more writes + + def write(self, data: bytes): + """Write data to the buffer from a writer thread.""" + with self._data_available: # Lock and notify readers + self._buffer.seek(0, io.SEEK_END) # Move to the end + self._buffer.write(data) + self._data_available.notify_all() # Notify waiting readers + + def read(self, size: int = -1) -> bytes: + """Read data from the buffer in a reader thread.""" + + if self._buffer.closed: + return b"" + + with self._data_available: + while True: + self._buffer.seek(0) # Rewind for reading + data = self._buffer.read(size) + + # If data is available, return it + if data: + # Shrink the buffer to remove already-read data + remaining = self._buffer.read() + self._buffer = io.BytesIO(remaining) + return data + + # If EOF is signaled and no data remains, return EOF + if self._eof: + return b"" + + # Wait for more data + self._data_available.wait() + + def end_input(self): + """Signal that no more data will be written.""" + with self._data_available: + self._eof = True + self._data_available.notify_all() + + def close(self): + self._buffer.close() + + +class AudioStreamDecoder: + """A class that can be used to decode audio stream into PCM AudioFrames. + + Decoders are stateful, and it should not be reused across multiple streams. Each decoder + is designed to decode a single stream. + """ + + def __init__(self): + try: + import av # noqa + except ImportError: + raise ImportError( + "You haven't included the 'codecs' optional dependencies. Please install the 'codecs' extra by running `pip install livekit-agents[codecs]`" + ) + + self._output_ch = aio.Chan[rtc.AudioFrame]() + self._closed = False + self._started = False + self._output_finished = False + self._input_buf = StreamBuffer() + self._loop = asyncio.get_event_loop() + + def push(self, chunk: bytes): + self._input_buf.write(chunk) + if not self._started: + self._started = True + self._loop.run_in_executor(None, self._decode_loop) + + def end_input(self): + self._input_buf.end_input() + + def _decode_loop(self): + container = av.open(self._input_buf) + audio_stream = next(s for s in container.streams if s.type == "audio") + resampler = av.AudioResampler( + # convert to signed 16-bit little endian + format="s16", + layout="mono", + rate=audio_stream.rate, + ) + try: + # TODO: handle error where audio stream isn't found + if not audio_stream: + return + for frame in container.decode(audio_stream): + if self._closed: + return + for resampled_frame in resampler.resample(frame): + nchannels = len(resampled_frame.layout.channels) + data = resampled_frame.to_ndarray().tobytes() + self._output_ch.send_nowait( + rtc.AudioFrame( + data=data, + num_channels=nchannels, + sample_rate=resampled_frame.sample_rate, + samples_per_channel=resampled_frame.samples / nchannels, + ) + ) + finally: + self._output_finished = True + + def __aiter__(self) -> AsyncIterator[rtc.AudioFrame]: + return self + + async def __anext__(self) -> rtc.AudioFrame: + if self._output_finished and self._output_ch.empty(): + raise StopAsyncIteration + return await self._output_ch.__anext__() + + async def aclose(self): + if self._closed: + return + self._closed = True + self._input_buf.close() + self._output_ch.close() diff --git a/livekit-agents/livekit/agents/utils/codecs/mp3.py b/livekit-agents/livekit/agents/utils/codecs/mp3.py index 8a6c520c2..2f2321028 
100644 --- a/livekit-agents/livekit/agents/utils/codecs/mp3.py +++ b/livekit-agents/livekit/agents/utils/codecs/mp3.py @@ -14,9 +14,12 @@ import ctypes import logging -from importlib import import_module from typing import List +try: + import av # noqa +except ImportError: + pass from livekit import rtc @@ -28,15 +31,28 @@ class Mp3StreamDecoder: def __init__(self): try: - globals()["av"] = import_module("av") + import av except ImportError: raise ImportError( "You haven't included the 'codecs' optional dependencies. Please install the 'codecs' extra by running `pip install livekit-agents[codecs]`" ) - self._codec = av.CodecContext.create("mp3", "r") # noqa def decode_chunk(self, chunk: bytes) -> List[rtc.AudioFrame]: + # Skip ID3v2 header if present + if chunk.startswith(b"ID3"): + # ID3v2 header is 10 bytes long + # The size is encoded in the next 4 bytes (bytes 6-9) + # Each byte only uses 7 bits (most significant bit is always 0) + if len(chunk) >= 10: + size = ( + ((chunk[6] & 0x7F) << 21) + | ((chunk[7] & 0x7F) << 14) + | ((chunk[8] & 0x7F) << 7) + | (chunk[9] & 0x7F) + ) + chunk = chunk[10 + size :] + packets = self._codec.parse(chunk) result: List[rtc.AudioFrame] = [] for packet in packets: diff --git a/livekit-agents/livekit/agents/utils/images/image.py b/livekit-agents/livekit/agents/utils/images/image.py index 15755284d..dd9aac739 100644 --- a/livekit-agents/livekit/agents/utils/images/image.py +++ b/livekit-agents/livekit/agents/utils/images/image.py @@ -25,15 +25,42 @@ @dataclass class EncodeOptions: + """Options for encoding rtc.VideoFrame to portable image formats.""" + format: Literal["JPEG", "PNG"] = "JPEG" + """The format to encode the image.""" + resize_options: Optional["ResizeOptions"] = None + """Options for resizing the image.""" + + quality: Optional[int] = 75 + """Image compression quality, 0-100. Only applies to JPEG.""" @dataclass class ResizeOptions: + """Options for resizing rtc.VideoFrame as part of encoding to a portable image format.""" + width: int + """The desired resize width (in pixels)""" + height: int - strategy: Literal["center_aspect_fit", "center_aspect_cover", "skew"] + """The desired height to resize the image to.""" + + strategy: Literal[ + "center_aspect_fit", + "center_aspect_cover", + "scale_aspect_fit", + "scale_aspect_fill", + "skew", + ] + """The strategy to use when resizing the image: + - center_aspect_fit: Fit the image into the provided dimensions, with letterboxing + - center_aspect_cover: Fill the provided dimensions, with cropping + - scale_aspect_fit: Fit the image into the provided dimensions, preserving its original aspect ratio + - scale_aspect_fill: Fill the provided dimensions, preserving its original aspect ratio (image will be larger than the provided dimensions) + - skew: Precisely resize the image to the provided dimensions + """ def import_pil(): @@ -46,12 +73,19 @@ def import_pil(): ) -def encode(frame: rtc.VideoFrame, options: EncodeOptions): +def encode(frame: rtc.VideoFrame, options: EncodeOptions) -> bytes: + """Encode a rtc.VideoFrame to a portable image format (JPEG or PNG). + + See EncodeOptions for more details.
+ """ import_pil() img = _image_from_frame(frame) resized = _resize_image(img, options) buffer = io.BytesIO() - resized.save(buffer, options.format) + kwargs = {} + if options.format == "JPEG" and options.quality is not None: + kwargs["quality"] = options.quality + resized.save(buffer, options.format, **kwargs) buffer.seek(0) return buffer.read() @@ -83,10 +117,11 @@ def _resize_image(image: Any, options: EncodeOptions): # If the new image is wider than the original if resize_opts.width / resize_opts.height > image.width / image.height: - new_width = resize_opts.width - new_height = int(image.height * (resize_opts.width / image.width)) + new_height = resize_opts.height + new_width = int(image.width * (resize_opts.height / image.height)) resized = image.resize((new_width, new_height)) + Image.Image.paste( result, resized, @@ -118,5 +153,27 @@ def _resize_image(image: Any, options: EncodeOptions): ), ) return result + elif resize_opts.strategy == "scale_aspect_fill": + # Start with assuming width is the limiting dimension + new_width = resize_opts.width + new_height = int(image.height * (resize_opts.width / image.width)) + + # If height is under the limit, scale based on height instead + if new_height < resize_opts.height: + new_height = resize_opts.height + new_width = int(image.width * (resize_opts.height / image.height)) + + return image.resize((new_width, new_height)) + elif resize_opts.strategy == "scale_aspect_fit": + # Start with assuming width is the limiting dimension + new_width = resize_opts.width + new_height = int(image.height * (resize_opts.width / image.width)) + + # If height would exceed the limit, scale based on height instead + if new_height > resize_opts.height: + new_height = resize_opts.height + new_width = int(image.width * (resize_opts.height / image.height)) + + return image.resize((new_width, new_height)) raise ValueError(f"Unknown resize strategy: {resize_opts.strategy}") diff --git a/livekit-agents/livekit/agents/version.py b/livekit-agents/livekit/agents/version.py index 3debd106b..0696f486e 100644 --- a/livekit-agents/livekit/agents/version.py +++ b/livekit-agents/livekit/agents/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.11.3" +__version__ = "0.12.6" diff --git a/livekit-agents/livekit/agents/worker.py b/livekit-agents/livekit/agents/worker.py index a9a6c39b3..54ad75470 100644 --- a/livekit-agents/livekit/agents/worker.py +++ b/livekit-agents/livekit/agents/worker.py @@ -43,6 +43,7 @@ from . import http_server, ipc, utils from ._exceptions import AssignmentTimeoutError +from .inference_runner import _InferenceRunner from .job import ( JobAcceptArguments, JobContext, @@ -158,6 +159,15 @@ class WorkerOptions: Defaults to 0.75 on "production" mode, and is disabled in "development" mode. """ + + job_memory_warn_mb: float = 300 + """Memory warning threshold in MB. If the job process exceeds this limit, a warning will be logged.""" + job_memory_limit_mb: float = 0 + """Maximum memory usage for a job in MB, the job process will be killed if it exceeds this limit. + Defaults to 0 (disabled). 
+ """ + + """Number of idle processes to keep warm.""" num_idle_processes: int | _WorkerEnvOption[int] = _WorkerEnvOption( dev_default=0, prod_default=3 ) @@ -234,6 +244,15 @@ def __init__( "api_secret is required, or add LIVEKIT_API_SECRET in your environment" ) + if ( + opts.job_memory_limit_mb > 0 + and opts.job_executor_type != JobExecutorType.PROCESS + ): + logger.warning( + "max_job_memory_usage is only supported for process-based job executors, " + "ignoring max_job_memory_usage" + ) + self._opts = opts self._loop = loop or asyncio.get_event_loop() @@ -248,6 +267,26 @@ def __init__( # using spawn context for all platforms. We may have further optimizations for # Linux with forkserver, but for now, this is the safest option mp_ctx = mp.get_context("spawn") + + self._inference_executor: ( + ipc.inference_proc_executor.InferenceProcExecutor | None + ) = None + if len(_InferenceRunner.registered_runners) > 0: + self._inference_executor = ( + ipc.inference_proc_executor.InferenceProcExecutor( + runners=_InferenceRunner.registered_runners, + initialize_timeout=30, + close_timeout=5, + memory_warn_mb=2000, + memory_limit_mb=0, # no limit + ping_interval=5, + ping_timeout=60, + high_ping_threshold=2.5, + mp_ctx=mp_ctx, + loop=self._loop, + ) + ) + self._proc_pool = ipc.proc_pool.ProcPool( initialize_process_fnc=opts.prewarm_fnc, job_entrypoint_fnc=opts.entrypoint_fnc, @@ -256,13 +295,13 @@ def __init__( ), loop=self._loop, job_executor_type=opts.job_executor_type, + inference_executor=self._inference_executor, mp_ctx=mp_ctx, initialize_timeout=opts.initialize_process_timeout, close_timeout=opts.shutdown_process_timeout, + memory_warn_mb=opts.job_memory_warn_mb, + memory_limit_mb=opts.job_memory_limit_mb, ) - self._proc_pool.on("process_started", self._on_process_started) - self._proc_pool.on("process_closed", self._on_process_closed) - self._proc_pool.on("process_job_launched", self._on_process_job_launched) self._previous_status = agent.WorkerStatus.WS_AVAILABLE @@ -285,7 +324,22 @@ async def run(self): extra={"version": __version__, "rtc-version": rtc.__version__}, ) + if self._inference_executor is not None: + logger.info("starting inference executor") + await self._inference_executor.start() + await self._inference_executor.initialize() + self._closed = False + + def _update_job_status(proc: ipc.job_executor.JobExecutor) -> None: + t = self._loop.create_task(self._update_job_status(proc)) + self._tasks.add(t) + t.add_done_callback(self._tasks.discard) + + self._proc_pool.on("process_started", _update_job_status) + self._proc_pool.on("process_closed", _update_job_status) + self._proc_pool.on("process_job_launched", _update_job_status) + self._proc_pool.start() self._api = api.LiveKitAPI( self._opts.ws_url, self._opts.api_key, self._opts.api_secret @@ -372,6 +426,10 @@ async def aclose(self) -> None: self._main_task.cancel() await self._proc_pool.aclose() + + if self._inference_executor is not None: + await self._inference_executor.aclose() + await self._http_session.close() await self._http_server.aclose() await self._api.aclose() @@ -563,6 +621,7 @@ async def _reload_jobs(self, jobs: list[RunningJobInfo]) -> None: job=aj.job, url=url, token=jwt.encode(decoded, self._opts.api_secret, algorithm="HS256"), + worker_id=aj.worker_id, ) await self._proc_pool.launch_job(running_info) @@ -634,6 +693,7 @@ async def _on_accept(args: JobAcceptArguments) -> None: job=msg.job, url=job_assign.url or self._opts.ws_url, token=job_assign.token, + worker_id=self._id, ) await 
self._proc_pool.launch_job(running_info) @@ -690,15 +750,6 @@ async def _handle_termination(self, msg: agent.JobTermination): return await proc.aclose() - def _on_process_closed(self, proc: ipc.job_executor.JobExecutor) -> None: - self._update_job_status_sync(proc) - - def _on_process_started(self, proc: ipc.job_executor.JobExecutor) -> None: - self._update_job_status_sync(proc) - - def _on_process_job_launched(self, proc: ipc.job_executor.JobExecutor) -> None: - self._update_job_status_sync(proc) - async def _update_worker_status(self): job_cnt = len(self.active_jobs) if self._draining: @@ -756,28 +807,19 @@ def load_fnc(): with contextlib.suppress(utils.aio.ChanClosed): await self._queue_msg(msg) - def _update_job_status_sync(self, proc: ipc.job_executor.JobExecutor) -> None: - t = self._loop.create_task(self._update_job_status(proc)) - self._tasks.add(t) - t.add_done_callback(self._tasks.discard) - async def _update_job_status(self, proc: ipc.job_executor.JobExecutor) -> None: job_info = proc.running_job - if not job_info: + if job_info is None: return + status: agent.JobStatus = agent.JobStatus.JS_RUNNING - if proc.run_status == ipc.job_executor.RunStatus.FINISHED_FAILED: + if proc.status == ipc.job_executor.JobStatus.FAILED: status = agent.JobStatus.JS_FAILED - elif proc.run_status == ipc.job_executor.RunStatus.FINISHED_CLEAN: + elif proc.status == ipc.job_executor.JobStatus.SUCCESS: status = agent.JobStatus.JS_SUCCESS - elif proc.run_status == ipc.job_executor.RunStatus.STARTING: - status = agent.JobStatus.JS_PENDING - - error: str | None = None - if proc.exception: - error = str(proc.exception) - update = agent.UpdateJobStatus( - job_id=job_info.job.id, status=status, error=error - ) + elif proc.status == ipc.job_executor.JobStatus.RUNNING: + status = agent.JobStatus.JS_RUNNING + + update = agent.UpdateJobStatus(job_id=job_info.job.id, status=status, error="") msg = agent.WorkerMessage(update_job=update) await self._queue_msg(msg) diff --git a/livekit-agents/package.json b/livekit-agents/package.json index c869cf3d8..c321ac852 100644 --- a/livekit-agents/package.json +++ b/livekit-agents/package.json @@ -1,5 +1,5 @@ { "name": "livekit-agents", "private": true, - "version": "0.11.3" + "version": "0.12.6" } diff --git a/livekit-agents/setup.py b/livekit-agents/setup.py index 54d6d54cb..9ff541808 100644 --- a/livekit-agents/setup.py +++ b/livekit-agents/setup.py @@ -48,7 +48,7 @@ python_requires=">=3.9.0", install_requires=[ "click~=8.1", - "livekit>=0.17.6", + "livekit>=0.18.1", "livekit-api~=0.8", "livekit-protocol~=0.7", "protobuf>=3", @@ -66,8 +66,8 @@ ':sys_platform!="win32"': [ "aiodns~=3.2" ], # use default aiohttp resolver on windows - "codecs": ["av>=11.0.0"], - "images": ["pillow~=10.3.0"], + "codecs": ["av>=12.0.0", "numpy>=1.26.0"], + "images": ["pillow>=10.3.0"], }, package_data={"livekit.agents": ["py.typed"]}, project_urls={ diff --git a/livekit-plugins/install_local.sh b/livekit-plugins/install_local.sh new file mode 100755 index 000000000..3e6a1cee4 --- /dev/null +++ b/livekit-plugins/install_local.sh @@ -0,0 +1,22 @@ +#!/bin/bash +set -e + +# Get the directory where the script is located +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +pip install \ + "${SCRIPT_DIR}/livekit-plugins-anthropic" \ + "${SCRIPT_DIR}/livekit-plugins-assemblyai" \ + "${SCRIPT_DIR}/livekit-plugins-azure" \ + "${SCRIPT_DIR}/livekit-plugins-cartesia" \ + "${SCRIPT_DIR}/livekit-plugins-deepgram" \ + "${SCRIPT_DIR}/livekit-plugins-elevenlabs" \ + 
"${SCRIPT_DIR}/livekit-plugins-fal" \ + "${SCRIPT_DIR}/livekit-plugins-google" \ + "${SCRIPT_DIR}/livekit-plugins-llama-index" \ + "${SCRIPT_DIR}/livekit-plugins-nltk" \ + "${SCRIPT_DIR}/livekit-plugins-openai" \ + "${SCRIPT_DIR}/livekit-plugins-rag" \ + "${SCRIPT_DIR}/livekit-plugins-playai" \ + "${SCRIPT_DIR}/livekit-plugins-silero" \ + "${SCRIPT_DIR}/livekit-plugins-turn-detector" diff --git a/livekit-plugins/install_plugins_editable.sh b/livekit-plugins/install_plugins_editable.sh index 0072e5a17..9a5d9960b 100755 --- a/livekit-plugins/install_plugins_editable.sh +++ b/livekit-plugins/install_plugins_editable.sh @@ -16,6 +16,8 @@ pip install -e ./livekit-plugins-minimal --config-settings editable_mode=strict pip install -e ./livekit-plugins-nltk --config-settings editable_mode=strict pip install -e ./livekit-plugins-openai --config-settings editable_mode=strict pip install -e ./livekit-plugins-rag --config-settings editable_mode=strict +pip install -e ./livekit-plugins-llama-index --config-settings editable_mode=strict +pip install -e ./livekit-plugins-turn-detector --config-settings editable_mode=strict pip install -e ./livekit-plugins-silero --config-settings editable_mode=strict pip install -e ./livekit-plugins-browser --config-settings editable_mode=strict -pip install -e ./livekit-plugins-llama-index --config-settings editable_mode=strict + diff --git a/livekit-plugins/livekit-plugins-anthropic/CHANGELOG.md b/livekit-plugins/livekit-plugins-anthropic/CHANGELOG.md index a6d8931d2..3b75922f3 100644 --- a/livekit-plugins/livekit-plugins-anthropic/CHANGELOG.md +++ b/livekit-plugins/livekit-plugins-anthropic/CHANGELOG.md @@ -1,5 +1,47 @@ # livekit-plugins-anthropic +## 0.2.9 + +### Patch Changes + +- improved handling of LLM errors, do not retry if already began - [#1298](https://github.com/livekit/agents/pull/1298) ([@davidzhao](https://github.com/davidzhao)) + +## 0.2.8 + +### Patch Changes + +- Moved create_ai_function_info to function_context.py for better reusability and reduce repetation - [#1260](https://github.com/livekit/agents/pull/1260) ([@jayeshp19](https://github.com/jayeshp19)) + +- Add support for OpenAI's "detail" parameter to ChatImage - [#1213](https://github.com/livekit/agents/pull/1213) ([@bcherry](https://github.com/bcherry)) + + Add support for data URLs on ChatImage in the Anthropic plugin. + +- fix: correctly parse function argument types - [#1221](https://github.com/livekit/agents/pull/1221) ([@jayeshp19](https://github.com/jayeshp19)) + +- Fix center_aspect_fit bug, add scale_aspect_fit and scale_aspect_fill resizing options. - [#1222](https://github.com/livekit/agents/pull/1222) ([@bcherry](https://github.com/bcherry)) + + Make scale_aspect_fit the new default resizing option for video frames. 
+ +## 0.2.7 + +### Patch Changes + +- fix: return structured output from func calls - [#1187](https://github.com/livekit/agents/pull/1187) ([@jayeshp19](https://github.com/jayeshp19)) + +## 0.2.6 + +### Patch Changes + +- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom)) + +## 0.2.5 + +### Patch Changes + +- support for custom tool use in LLMs - [#1102](https://github.com/livekit/agents/pull/1102) ([@jayeshp19](https://github.com/jayeshp19)) + +- feat: llm retry & llm.FallbackAdapter - [#1132](https://github.com/livekit/agents/pull/1132) ([@theomonnom](https://github.com/theomonnom)) + ## 0.2.4 ### Patch Changes diff --git a/livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/llm.py b/livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/llm.py index 1ae45141b..3af490211 100644 --- a/livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/llm.py +++ b/livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/llm.py @@ -19,7 +19,16 @@ import json import os from dataclasses import dataclass -from typing import Any, Awaitable, List, Tuple, get_args, get_origin +from typing import ( + Any, + Awaitable, + List, + Literal, + Union, + cast, + get_args, + get_origin, +) import httpx from livekit import rtc @@ -30,6 +39,12 @@ llm, utils, ) +from livekit.agents.llm import ToolChoice +from livekit.agents.llm.function_context import ( + _create_ai_function_info, + _is_optional_type, +) +from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions import anthropic @@ -44,6 +59,8 @@ class LLMOptions: model: str | ChatModels user: str | None temperature: float | None + parallel_tool_calls: bool | None + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] | None class LLM(llm.LLM): @@ -56,6 +73,8 @@ def __init__( user: str | None = None, client: anthropic.AsyncClient | None = None, temperature: float | None = None, + parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto", ) -> None: """ Create a new instance of Anthropic LLM. 
@@ -70,7 +89,13 @@ def __init__( if api_key is None: raise ValueError("Anthropic API key is required") - self._opts = LLMOptions(model=model, user=user, temperature=temperature) + self._opts = LLMOptions( + model=model, + user=user, + temperature=temperature, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, + ) self._client = client or anthropic.AsyncClient( api_key=api_key, base_url=base_url, @@ -89,13 +114,20 @@ def chat( self, *, chat_ctx: llm.ChatContext, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, fnc_ctx: llm.FunctionContext | None = None, temperature: float | None = None, n: int | None = 1, parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] + | None = None, ) -> "LLMStream": if temperature is None: temperature = self._opts.temperature + if parallel_tool_calls is None: + parallel_tool_calls = self._opts.parallel_tool_calls + if tool_choice is None: + tool_choice = self._opts.tool_choice opts: dict[str, Any] = dict() if fnc_ctx and len(fnc_ctx.ai_functions) > 0: @@ -104,9 +136,20 @@ def chat( fncs_desc.append(_build_function_description(fnc)) opts["tools"] = fncs_desc - - if fnc_ctx and parallel_tool_calls is not None: - opts["parallel_tool_calls"] = parallel_tool_calls + if tool_choice is not None: + anthropic_tool_choice: dict[str, Any] = {"type": "auto"} + if isinstance(tool_choice, ToolChoice): + if tool_choice.type == "function": + anthropic_tool_choice = { + "type": "tool", + "name": tool_choice.name, + } + elif isinstance(tool_choice, str): + if tool_choice == "required": + anthropic_tool_choice = {"type": "any"} + if parallel_tool_calls is not None and parallel_tool_calls is False: + anthropic_tool_choice["disable_parallel_tool_use"] = True + opts["tool_choice"] = anthropic_tool_choice latest_system_message = _latest_system_message(chat_ctx) anthropic_ctx = _build_anthropic_context(chat_ctx.messages, id(self)) @@ -124,7 +167,11 @@ def chat( ) return LLMStream( - self, anthropic_stream=stream, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx + self, + anthropic_stream=stream, + chat_ctx=chat_ctx, + fnc_ctx=fnc_ctx, + conn_options=conn_options, ) @@ -138,8 +185,11 @@ def __init__( ], chat_ctx: llm.ChatContext, fnc_ctx: llm.FunctionContext | None, + conn_options: APIConnectOptions, ) -> None: - super().__init__(llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx) + super().__init__( + llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options + ) self._awaitable_anthropic_stream = anthropic_stream self._anthropic_stream: ( anthropic.AsyncStream[anthropic.types.RawMessageStreamEvent] | None @@ -155,7 +205,8 @@ def __init__( self._input_tokens = 0 self._output_tokens = 0 - async def _main_task(self) -> None: + async def _run(self) -> None: + retryable = True try: if not self._anthropic_stream: self._anthropic_stream = await self._awaitable_anthropic_stream @@ -165,6 +216,7 @@ async def _main_task(self) -> None: chat_chunk = self._parse_event(event) if chat_chunk is not None: self._event_ch.send_nowait(chat_chunk) + retryable = False self._event_ch.send_nowait( llm.ChatChunk( @@ -177,7 +229,7 @@ async def _main_task(self) -> None: ) ) except anthropic.APITimeoutError: - raise APITimeoutError() + raise APITimeoutError(retryable=retryable) except anthropic.APIStatusError as e: raise APIStatusError( e.message, @@ -186,7 +238,7 @@ async def _main_task(self) -> None: body=e.body, ) except Exception as e: - raise APIConnectionError() from e + raise APIConnectionError(retryable=retryable) from e def 
_parse_event( self, event: anthropic.types.RawMessageStreamEvent @@ -356,8 +408,10 @@ def _build_anthropic_message( return a_msg elif msg.role == "tool": + if isinstance(msg.content, dict): + msg.content = json.dumps(msg.content) if not isinstance(msg.content, str): - logger.warning("tool message content is not a string") + logger.warning("tool message content is not a string or dict") return None if not msg.tool_call_id: return None @@ -379,11 +433,36 @@ def _build_anthropic_message( def _build_anthropic_image_content( image: llm.ChatImage, cache_key: Any ) -> anthropic.types.ImageBlockParam: - if isinstance(image.image, str): # image url - logger.warning( - "image url not supported by anthropic, skipping image '%s'", image.image - ) - elif isinstance(image.image, rtc.VideoFrame): # VideoFrame + if isinstance(image.image, str): # image is a URL + if not image.image.startswith("data:"): + raise ValueError("LiveKit Anthropic Plugin: Image URLs must be data URLs") + + try: + header, b64_data = image.image.split(",", 1) + media_type = header.split(";")[0].split(":")[1] + + supported_types = {"image/jpeg", "image/png", "image/webp", "image/gif"} + if media_type not in supported_types: + raise ValueError( + f"LiveKit Anthropic Plugin: Unsupported media type {media_type}. Must be jpeg, png, webp, or gif" + ) + + return { + "type": "image", + "source": { + "type": "base64", + "data": b64_data, + "media_type": cast( + Literal["image/jpeg", "image/png", "image/gif", "image/webp"], + media_type, + ), + }, + } + except (ValueError, IndexError) as e: + raise ValueError( + f"LiveKit Anthropic Plugin: Invalid image data URL {str(e)}" + ) + elif isinstance(image.image, rtc.VideoFrame): # image is a VideoFrame if cache_key not in image._cache: # inside our internal implementation, we allow to put extra metadata to # each ChatImage (avoid to reencode each time we do a chatcompletion request) @@ -392,7 +471,7 @@ def _build_anthropic_image_content( opts.resize_options = utils.images.ResizeOptions( width=image.inference_width, height=image.inference_height, - strategy="center_aspect_fit", + strategy="scale_aspect_fit", ) encoded_data = utils.images.encode(image.image, opts) @@ -407,65 +486,8 @@ def _build_anthropic_image_content( }, } - raise ValueError(f"unknown image type {type(image.image)}") - - -def _create_ai_function_info( - fnc_ctx: llm.function_context.FunctionContext, - tool_call_id: str, - fnc_name: str, - raw_arguments: str, # JSON string -) -> llm.function_context.FunctionCallInfo: - if fnc_name not in fnc_ctx.ai_functions: - raise ValueError(f"AI function {fnc_name} not found") - - parsed_arguments: dict[str, Any] = {} - try: - if raw_arguments: # ignore empty string - parsed_arguments = json.loads(raw_arguments) - except json.JSONDecodeError: - raise ValueError( - f"AI function {fnc_name} received invalid JSON arguments - {raw_arguments}" - ) - - fnc_info = fnc_ctx.ai_functions[fnc_name] - - # Ensure all necessary arguments are present and of the correct type. 
- sanitized_arguments: dict[str, Any] = {} - for arg_info in fnc_info.arguments.values(): - if arg_info.name not in parsed_arguments: - if arg_info.default is inspect.Parameter.empty: - raise ValueError( - f"AI function {fnc_name} missing required argument {arg_info.name}" - ) - continue - - arg_value = parsed_arguments[arg_info.name] - if get_origin(arg_info.type) is not None: - if not isinstance(arg_value, list): - raise ValueError( - f"AI function {fnc_name} argument {arg_info.name} should be a list" - ) - - inner_type = get_args(arg_info.type)[0] - sanitized_value = [ - _sanitize_primitive( - value=v, expected_type=inner_type, choices=arg_info.choices - ) - for v in arg_value - ] - else: - sanitized_value = _sanitize_primitive( - value=arg_value, expected_type=arg_info.type, choices=arg_info.choices - ) - - sanitized_arguments[arg_info.name] = sanitized_value - - return llm.function_context.FunctionCallInfo( - tool_call_id=tool_call_id, - raw_arguments=raw_arguments, - function_info=fnc_info, - arguments=sanitized_arguments, + raise ValueError( + "LiveKit Anthropic Plugin: ChatImage must be an rtc.VideoFrame or a data URL" ) @@ -492,8 +514,10 @@ def type2str(t: type) -> str: if arg_info.description: p["description"] = arg_info.description - if get_origin(arg_info.type) is list: - inner_type = get_args(arg_info.type)[0] + is_optional, inner_th = _is_optional_type(arg_info.type) + + if get_origin(inner_th) is list: + inner_type = get_args(inner_th)[0] p["type"] = "array" p["items"] = {} p["items"]["type"] = type2str(inner_type) @@ -501,7 +525,7 @@ def type2str(t: type) -> str: if arg_info.choices: p["items"]["enum"] = arg_info.choices else: - p["type"] = type2str(arg_info.type) + p["type"] = type2str(inner_th) if arg_info.choices: p["enum"] = arg_info.choices @@ -517,31 +541,3 @@ def type2str(t: type) -> str: "description": fnc_info.description, "input_schema": input_schema, } - - -def _sanitize_primitive( - *, value: Any, expected_type: type, choices: Tuple[Any] | None -) -> Any: - if expected_type is str: - if not isinstance(value, str): - raise ValueError(f"expected str, got {type(value)}") - elif expected_type in (int, float): - if not isinstance(value, (int, float)): - raise ValueError(f"expected number, got {type(value)}") - - if expected_type is int: - if value % 1 != 0: - raise ValueError("expected int, got float") - - value = int(value) - elif expected_type is float: - value = float(value) - - elif expected_type is bool: - if not isinstance(value, bool): - raise ValueError(f"expected bool, got {type(value)}") - - if choices and value not in choices: - raise ValueError(f"invalid value {value}, not in {choices}") - - return value diff --git a/livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/version.py b/livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/version.py index f7f2274ac..bd4a8d004 100644 --- a/livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/version.py +++ b/livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "0.2.4" +__version__ = "0.2.9" diff --git a/livekit-plugins/livekit-plugins-anthropic/package.json b/livekit-plugins/livekit-plugins-anthropic/package.json index 0a9be8b4c..eb8866886 100644 --- a/livekit-plugins/livekit-plugins-anthropic/package.json +++ b/livekit-plugins/livekit-plugins-anthropic/package.json @@ -1,5 +1,5 @@ { "name": "livekit-plugins-anthropic", "private": true, - "version": "0.2.4" + "version": "0.2.9" } diff --git a/livekit-plugins/livekit-plugins-anthropic/setup.py b/livekit-plugins/livekit-plugins-anthropic/setup.py index 5a21aeb5c..4d9c3a1ba 100644 --- a/livekit-plugins/livekit-plugins-anthropic/setup.py +++ b/livekit-plugins/livekit-plugins-anthropic/setup.py @@ -49,7 +49,7 @@ license="Apache-2.0", packages=setuptools.find_namespace_packages(include=["livekit.*"]), python_requires=">=3.9.0", - install_requires=["livekit-agents>=0.11", "anthropic>=0.34"], + install_requires=["livekit-agents>=0.12.3", "anthropic>=0.34"], package_data={"livekit.plugins.anthropic": ["py.typed"]}, project_urls={ "Documentation": "https://docs.livekit.io", diff --git a/livekit-plugins/livekit-plugins-assemblyai/CHANGELOG.md b/livekit-plugins/livekit-plugins-assemblyai/CHANGELOG.md index 4e2653349..71d63e941 100644 --- a/livekit-plugins/livekit-plugins-assemblyai/CHANGELOG.md +++ b/livekit-plugins/livekit-plugins-assemblyai/CHANGELOG.md @@ -1,5 +1,25 @@ # livekit-plugins-assemblyai +## 0.2.2 + +### Patch Changes + +- fix: Ensure STT exceptions are being propagated - [#1291](https://github.com/livekit/agents/pull/1291) ([@davidzhao](https://github.com/davidzhao)) + +- assemblyai: encode boost words - [#1284](https://github.com/livekit/agents/pull/1284) ([@jmugicagonz](https://github.com/jmugicagonz)) + +## 0.2.1 + +### Patch Changes + +- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom)) + +## 0.1.1 + +### Patch Changes + +- feat: stt retry & stt.FallbackAdapter - [#1114](https://github.com/livekit/agents/pull/1114) ([@theomonnom](https://github.com/theomonnom)) + ## 0.1.0 ### Minor Changes diff --git a/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/py.typed b/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py b/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py index 861c2774c..40c359fd8 100644 --- a/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py +++ b/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py @@ -19,12 +19,20 @@ import dataclasses import json import os +import weakref from dataclasses import dataclass from typing import List, Literal, Optional from urllib.parse import urlencode import aiohttp -from livekit.agents import stt, utils +from livekit.agents import ( + DEFAULT_API_CONNECT_OPTIONS, + APIConnectOptions, + APIStatusError, + stt, + utils, +) +from livekit.agents.stt import SpeechEvent from livekit.agents.utils import AudioBuffer from .log import logger @@ -40,14 +48,14 @@ @dataclass class STTOptions: - sample_rate: Optional[int] = None + sample_rate: int + buffer_size_seconds: float word_boost: Optional[List[str]] = None encoding: Optional[Literal["pcm_s16le", "pcm_mulaw"]] = None disable_partial_transcripts: bool = False enable_extra_session_information: bool = False end_utterance_silence_threshold: Optional[int] = None # Buffer 
to collect frames to send to AssemblyAI - buffer_size_seconds: Optional[float] = None def __post_init__(self): if self.encoding not in (None, "pcm_s16le", "pcm_mulaw"): @@ -59,9 +67,9 @@ def __init__( self, *, api_key: Optional[str] = None, - sample_rate: Optional[int] = 16000, + sample_rate: int = 16000, word_boost: Optional[List[str]] = None, - encoding: Optional[str] = "pcm_s16le", + encoding: Optional[Literal["pcm_s16le", "pcm_mulaw"]] = "pcm_s16le", disable_partial_transcripts: bool = False, enable_extra_session_information: bool = False, end_utterance_silence_threshold: Optional[int] = 500, @@ -93,6 +101,7 @@ def __init__( end_utterance_silence_threshold=end_utterance_silence_threshold, ) self._session = http_session + self._streams = weakref.WeakSet[SpeechStream]() @property def session(self) -> aiohttp.ClientSession: @@ -102,15 +111,10 @@ def session(self) -> aiohttp.ClientSession: async def _recognize_impl( self, - *, buffer: AudioBuffer, - ) -> stt.SpeechEvent: - raise NotImplementedError("Not implemented") - - async def recognize( - self, *, - buffer: AudioBuffer, + language: str | None, + conn_options: APIConnectOptions, ) -> stt.SpeechEvent: raise NotImplementedError("Not implemented") @@ -118,14 +122,49 @@ def stream( self, *, language: Optional[str] = None, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, ) -> "SpeechStream": config = dataclasses.replace(self._opts) - return SpeechStream( - stt_=self, + stream = SpeechStream( + stt=self, + conn_options=conn_options, opts=config, api_key=self._api_key, http_session=self.session, ) + self._streams.add(stream) + return stream + + def update_options( + self, + *, + disable_partial_transcripts: Optional[bool] = None, + word_boost: Optional[List[str]] = None, + end_utterance_silence_threshold: Optional[int] = None, + enable_extra_session_information: Optional[bool] = None, + buffer_size_seconds: Optional[float] = None, + ): + if disable_partial_transcripts is not None: + self._opts.disable_partial_transcripts = disable_partial_transcripts + if word_boost is not None: + self._opts.word_boost = word_boost + if end_utterance_silence_threshold is not None: + self._opts.end_utterance_silence_threshold = end_utterance_silence_threshold + if enable_extra_session_information is not None: + self._opts.enable_extra_session_information = ( + enable_extra_session_information + ) + if buffer_size_seconds is not None: + self._opts.buffer_size_seconds = buffer_size_seconds + + for stream in self._streams: + stream.update_options( + disable_partial_transcripts=disable_partial_transcripts, + word_boost=word_boost, + end_utterance_silence_threshold=end_utterance_silence_threshold, + enable_extra_session_information=enable_extra_session_information, + buffer_size_seconds=buffer_size_seconds, + ) class SpeechStream(stt.SpeechStream): @@ -134,88 +173,59 @@ class SpeechStream(stt.SpeechStream): def __init__( self, - stt_: STT, + *, + stt: STT, opts: STTOptions, + conn_options: APIConnectOptions, api_key: str, http_session: aiohttp.ClientSession, - num_channels: int = 1, - max_retry: int = 32, ) -> None: - super().__init__(stt=stt_, sample_rate=opts.sample_rate) + super().__init__( + stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate + ) self._opts = opts - self._num_channels = num_channels self._api_key = api_key self._session = http_session - self._max_retry = max_retry - self._speech_duration = 0 - - if self._num_channels != 1: - raise ValueError( - f"AssemblyAI only supports mono audio, but a `num_channels` of 
{self._num_channels} was provided" - ) + self._speech_duration: float = 0 # keep a list of final transcripts to combine them inside the END_OF_SPEECH event - self._final_events: List[stt.SpeechEvent] = [] + self._final_events: List[SpeechEvent] = [] + self._reconnect_event = asyncio.Event() - @utils.log_exceptions(logger=logger) - async def _main_task(self) -> None: - await self._run(self._max_retry) + def update_options( + self, + *, + disable_partial_transcripts: Optional[bool] = None, + word_boost: Optional[List[str]] = None, + end_utterance_silence_threshold: Optional[int] = None, + enable_extra_session_information: Optional[bool] = None, + buffer_size_seconds: Optional[float] = None, + ): + if disable_partial_transcripts is not None: + self._opts.disable_partial_transcripts = disable_partial_transcripts + if word_boost is not None: + self._opts.word_boost = word_boost + if end_utterance_silence_threshold is not None: + self._opts.end_utterance_silence_threshold = end_utterance_silence_threshold + if enable_extra_session_information is not None: + self._opts.enable_extra_session_information = ( + enable_extra_session_information + ) + if buffer_size_seconds is not None: + self._opts.buffer_size_seconds = buffer_size_seconds - @utils.log_exceptions(logger=logger) - async def _run(self, max_retry: int) -> None: + self._reconnect_event.set() + + async def _run(self) -> None: """ Run a single websocket connection to AssemblyAI and make sure to reconnect when something went wrong. """ - retry_count = 0 - while self._input_ch.qsize() or not self._input_ch.closed: - try: - live_config = { - "sample_rate": self._opts.sample_rate, - "word_boost": self._opts.word_boost, - "encoding": self._opts.encoding, - "disable_partial_transcripts": self._opts.disable_partial_transcripts, - "enable_extra_session_information": self._opts.enable_extra_session_information, - } - - headers = { - "Authorization": self._api_key, - "Content-Type": "application/json", - } - - ws_url = "wss://api.assemblyai.com/v2/realtime/ws" - filtered_config = { - k: v for k, v in live_config.items() if v is not None - } - url = f"{ws_url}?{urlencode(filtered_config).lower()}" - ws = await self._session.ws_connect(url, headers=headers) - retry_count = 0 # connected successfully, reset the retry_count - - await self._run_ws(ws) - except Exception: - # Something went wrong, retry the connection - if retry_count >= max_retry: - logger.error( - f"failed to connect to AssemblyAI after {max_retry} tries" - ) - break - - retry_delay = min(retry_count * 2, 10) # max 10s - retry_count += 1 # increment after calculating the delay, the first retry should happen directly - logger.info( - f"AssemblyAI connection failed, retrying in {retry_delay}s", - ) - await asyncio.sleep(retry_delay) - - async def _run_ws(self, ws: aiohttp.ClientWebSocketResponse) -> None: - """ - This method can throw ws errors, these are handled inside the _run method - """ closing_ws = False - async def send_task(): + async def send_task(ws: aiohttp.ClientWebSocketResponse): nonlocal closing_ws if self._opts.end_utterance_silence_threshold: @@ -232,7 +242,7 @@ async def send_task(): ) audio_bstream = utils.audio.AudioByteStream( sample_rate=self._opts.sample_rate, - num_channels=self._num_channels, + num_channels=1, samples_per_channel=samples_per_buffer, ) @@ -252,7 +262,7 @@ async def send_task(): closing_ws = True await ws.send_str(SpeechStream._CLOSE_MSG) - async def recv_task(): + async def recv_task(ws: aiohttp.ClientWebSocketResponse): nonlocal closing_ws while 
True: try: @@ -270,7 +280,7 @@ async def recv_task(): if closing_ws: # close is expected, see SpeechStream.aclose return - raise Exception( + raise APIStatusError( "AssemblyAI connection closed unexpectedly", ) # this will trigger a reconnection, see the _run loop @@ -285,15 +295,57 @@ async def recv_task(): except Exception: logger.exception("failed to process AssemblyAI message") - tasks = [ - asyncio.create_task(send_task()), - asyncio.create_task(recv_task()), - ] + ws: aiohttp.ClientWebSocketResponse | None = None + + while True: + try: + ws = await self._connect_ws() + tasks = [ + asyncio.create_task(send_task(ws)), + asyncio.create_task(recv_task(ws)), + ] + wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait()) + + try: + done, _ = await asyncio.wait( + [asyncio.gather(*tasks), wait_reconnect_task], + return_when=asyncio.FIRST_COMPLETED, + ) # type: ignore + for task in done: + if task != wait_reconnect_task: + task.result() + + if wait_reconnect_task not in done: + break - try: - await asyncio.gather(*tasks) - finally: - await utils.aio.gracefully_cancel(*tasks) + self._reconnect_event.clear() + finally: + await utils.aio.gracefully_cancel(*tasks, wait_reconnect_task) + finally: + if ws is not None: + await ws.close() + + async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse: + live_config = { + "sample_rate": self._opts.sample_rate, + "word_boost": json.dumps(self._opts.word_boost) + if self._opts.word_boost is not None + else None, + "encoding": self._opts.encoding, + "disable_partial_transcripts": self._opts.disable_partial_transcripts, + "enable_extra_session_information": self._opts.enable_extra_session_information, + } + + headers = { + "Authorization": self._api_key, + "Content-Type": "application/json", + } + + ws_url = "wss://api.assemblyai.com/v2/realtime/ws" + filtered_config = {k: v for k, v in live_config.items() if v is not None} + url = f"{ws_url}?{urlencode(filtered_config).lower()}" + ws = await self._session.ws_connect(url, headers=headers) + return ws def _process_stream_event(self, data: dict, closing_ws: bool) -> None: # see this page: diff --git a/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/version.py b/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/version.py index 78ce264e0..2985d9da1 100644 --- a/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/version.py +++ b/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "0.1.0" +__version__ = "0.2.2" diff --git a/livekit-plugins/livekit-plugins-assemblyai/package.json b/livekit-plugins/livekit-plugins-assemblyai/package.json index 5b509ec53..8b0962663 100644 --- a/livekit-plugins/livekit-plugins-assemblyai/package.json +++ b/livekit-plugins/livekit-plugins-assemblyai/package.json @@ -1,5 +1,5 @@ { "name": "livekit-plugins-assemblyai", "private": true, - "version": "0.1.0" -} \ No newline at end of file + "version": "0.2.2" +} diff --git a/livekit-plugins/livekit-plugins-assemblyai/setup.py b/livekit-plugins/livekit-plugins-assemblyai/setup.py index 003b8876d..edd7e5494 100644 --- a/livekit-plugins/livekit-plugins-assemblyai/setup.py +++ b/livekit-plugins/livekit-plugins-assemblyai/setup.py @@ -48,7 +48,7 @@ packages=setuptools.find_namespace_packages(include=["livekit.*"]), python_requires=">=3.9.0", install_requires=[ - "livekit-agents~=0.7", + "livekit-agents>=0.12.3", ], package_data={}, project_urls={ diff --git a/livekit-plugins/livekit-plugins-azure/CHANGELOG.md b/livekit-plugins/livekit-plugins-azure/CHANGELOG.md index 5597a247f..414181cbd 100644 --- a/livekit-plugins/livekit-plugins-azure/CHANGELOG.md +++ b/livekit-plugins/livekit-plugins-azure/CHANGELOG.md @@ -1,5 +1,39 @@ # livekit-plugins-azure +## 0.5.2 + +### Patch Changes + +- fix: Ensure STT exceptions are being propagated - [#1291](https://github.com/livekit/agents/pull/1291) ([@davidzhao](https://github.com/davidzhao)) + +## 0.5.1 + +### Patch Changes + +- fix azure stt language autodetection - [#1246](https://github.com/livekit/agents/pull/1246) ([@davidzhao](https://github.com/davidzhao)) + +## 0.5.0 + +### Minor Changes + +- Improvements to end of turn plugin, ensure STT language settings. - [#1195](https://github.com/livekit/agents/pull/1195) ([@davidzhao](https://github.com/davidzhao)) + +## 0.4.4 + +### Patch Changes + +- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom)) + +## 0.4.3 + +### Patch Changes + +- feat: stt retry & stt.FallbackAdapter - [#1114](https://github.com/livekit/agents/pull/1114) ([@theomonnom](https://github.com/theomonnom)) + +- azure: support auth entra token for TTS - [#1134](https://github.com/livekit/agents/pull/1134) ([@nfma](https://github.com/nfma)) + +- feat: tts retry & tts.FallbackAdapter - [#1074](https://github.com/livekit/agents/pull/1074) ([@theomonnom](https://github.com/theomonnom)) + ## 0.4.2 ### Patch Changes diff --git a/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py b/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py index 6452c7fbd..2bda776fd 100644 --- a/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py +++ b/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py @@ -13,16 +13,17 @@ from __future__ import annotations import asyncio +import contextlib import os +import weakref +from copy import deepcopy from dataclasses import dataclass from livekit import rtc -from livekit.agents import stt, utils +from livekit.agents import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions, stt, utils import azure.cognitiveservices.speech as speechsdk # type: ignore -from .log import logger - @dataclass class STTOptions: @@ -30,6 +31,8 @@ class STTOptions: speech_region: str | None # see https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-container-stt?tabs=container#use-the-container speech_host: str | None + # for using Microsoft Entra auth (see 
https://learn.microsoft.com/en-us/azure/ai-services/speech-service/how-to-configure-azure-ad-auth?tabs=portal&pivots=programming-language-python) + speech_auth_token: str | None sample_rate: int num_channels: int segmentation_silence_timeout_ms: int | None @@ -47,19 +50,25 @@ def __init__( speech_key: str | None = None, speech_region: str | None = None, speech_host: str | None = None, + speech_auth_token: str | None = None, sample_rate: int = 16000, num_channels: int = 1, segmentation_silence_timeout_ms: int | None = None, segmentation_max_time_ms: int | None = None, segmentation_strategy: str | None = None, - languages: list[str] = [], # when empty, auto-detect the language + # Azure handles multiple languages and can auto-detect the language used. It requires the candidate set to be set. + languages: list[str] = ["en-US"], + # for compatibility with other STT plugins + language: str | None = None, ): """ Create a new instance of Azure STT. - Either ``speech_host`` or ``speech_key`` and ``speech_region`` must be set, - either using arguments or by setting the ``AZURE_SPEECH_HOST``, ``AZURE_SPEECH_KEY`` + Either ``speech_host`` or ``speech_key`` and ``speech_region`` or + ``speech_auth_token`` and ``speech_region`` must be set using arguments. + Alternatively, set the ``AZURE_SPEECH_HOST``, ``AZURE_SPEECH_KEY`` and ``AZURE_SPEECH_REGION`` environmental variables, respectively. + ``speech_auth_token`` must be set using the arguments as it's an ephemeral token. """ super().__init__( @@ -69,15 +78,23 @@ def __init__( speech_key = speech_key or os.environ.get("AZURE_SPEECH_KEY") speech_region = speech_region or os.environ.get("AZURE_SPEECH_REGION") - if not speech_host and (not speech_key or not speech_region): + if not ( + speech_host + or (speech_key and speech_region) + or (speech_auth_token and speech_region) + ): raise ValueError( - "AZURE_SPEECH_HOST or AZURE_SPEECH_KEY and AZURE_SPEECH_REGION must be set" + "AZURE_SPEECH_HOST or AZURE_SPEECH_KEY and AZURE_SPEECH_REGION or speech_auth_token and AZURE_SPEECH_REGION must be set" ) + if language: + languages = [language] + self._config = STTOptions( speech_key=speech_key, speech_region=speech_region, speech_host=speech_host, + speech_auth_token=speech_auth_token, languages=languages, sample_rate=sample_rate, num_channels=num_channels, @@ -85,57 +102,127 @@ def __init__( segmentation_max_time_ms=segmentation_max_time_ms, segmentation_strategy=segmentation_strategy, ) + self._streams = weakref.WeakSet[SpeechStream]() async def _recognize_impl( - self, buffer: utils.AudioBuffer, *, language: str | None = None + self, + buffer: utils.AudioBuffer, + *, + language: str | None, + conn_options: APIConnectOptions, ) -> stt.SpeechEvent: raise NotImplementedError("Azure STT does not support single frame recognition") - def stream(self, *, language: str | None = None) -> "SpeechStream": - return SpeechStream(self, self._config) + def stream( + self, + *, + languages: list[str] | None = None, + language: str | None = None, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + ) -> "SpeechStream": + config = deepcopy(self._config) + if language and not languages: + languages = [language] + if languages: + config.languages = languages + stream = SpeechStream(stt=self, opts=config, conn_options=conn_options) + self._streams.add(stream) + return stream + + def update_options( + self, *, language: str | None = None, languages: list[str] | None = None + ): + if language and not languages: + languages = [language] + if languages is not None: + 
self._config.languages = languages + for stream in self._streams: + stream.update_options(languages=languages) class SpeechStream(stt.SpeechStream): - def __init__(self, stt: STT, opts: STTOptions) -> None: - super().__init__(stt, sample_rate=opts.sample_rate) + def __init__( + self, *, stt: STT, opts: STTOptions, conn_options: APIConnectOptions + ) -> None: + super().__init__( + stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate + ) self._opts = opts self._speaking = False - self._stream = speechsdk.audio.PushAudioInputStream( - stream_format=speechsdk.audio.AudioStreamFormat( - samples_per_second=self._opts.sample_rate, - bits_per_sample=16, - channels=self._opts.num_channels, - ) - ) - self._recognizer = _create_speech_recognizer( - config=self._opts, stream=self._stream - ) - self._recognizer.recognizing.connect(self._on_recognizing) - self._recognizer.recognized.connect(self._on_recognized) - self._recognizer.speech_start_detected.connect(self._on_speech_start) - self._recognizer.speech_end_detected.connect(self._on_speech_end) - self._recognizer.session_stopped.connect(self._on_session_stopped) - self._recognizer.start_continuous_recognition() - self._done_event = asyncio.Event() - self._loop = asyncio.get_running_loop() - - @utils.log_exceptions(logger=logger) - async def _main_task(self) -> None: - try: - async for input in self._input_ch: - if isinstance(input, rtc.AudioFrame): - self._stream.write(input.data.tobytes()) + self._session_stopped_event = asyncio.Event() + self._session_started_event = asyncio.Event() - self._stream.close() - await self._done_event.wait() - finally: - - def _cleanup(): - self._recognizer.stop_continuous_recognition() - del self._recognizer + self._loop = asyncio.get_running_loop() + self._reconnect_event = asyncio.Event() - await asyncio.to_thread(_cleanup) + def update_options( + self, *, language: str | None = None, languages: list[str] | None = None + ): + if language and not languages: + languages = [language] + if languages: + self._opts.languages = languages + self._reconnect_event.set() + + async def _run(self) -> None: + while True: + self._stream = speechsdk.audio.PushAudioInputStream( + stream_format=speechsdk.audio.AudioStreamFormat( + samples_per_second=self._opts.sample_rate, + bits_per_sample=16, + channels=self._opts.num_channels, + ) + ) + self._recognizer = _create_speech_recognizer( + config=self._opts, stream=self._stream + ) + self._recognizer.recognizing.connect(self._on_recognizing) + self._recognizer.recognized.connect(self._on_recognized) + self._recognizer.speech_start_detected.connect(self._on_speech_start) + self._recognizer.speech_end_detected.connect(self._on_speech_end) + self._recognizer.session_started.connect(self._on_session_started) + self._recognizer.session_stopped.connect(self._on_session_stopped) + self._recognizer.start_continuous_recognition() + + try: + await asyncio.wait_for( + self._session_started_event.wait(), self._conn_options.timeout + ) + + async def process_input(): + async for input in self._input_ch: + if isinstance(input, rtc.AudioFrame): + self._stream.write(input.data.tobytes()) + + process_input_task = asyncio.create_task(process_input()) + wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait()) + + try: + done, _ = await asyncio.wait( + [process_input_task, wait_reconnect_task], + return_when=asyncio.FIRST_COMPLETED, + ) + for task in done: + if task != wait_reconnect_task: + task.result() + finally: + await utils.aio.gracefully_cancel( + process_input_task, 
wait_reconnect_task + ) + + self._stream.close() + await self._session_stopped_event.wait() + finally: + + def _cleanup(): + self._recognizer.stop_continuous_recognition() + del self._recognizer + + await asyncio.to_thread(_cleanup) + if not self._reconnect_event.is_set(): + break + self._reconnect_event.clear() def _on_recognized(self, evt: speechsdk.SpeechRecognitionEventArgs): detected_lg = speechsdk.AutoDetectSourceLanguageResult(evt.result).language @@ -143,15 +230,20 @@ def _on_recognized(self, evt: speechsdk.SpeechRecognitionEventArgs): if not text: return + if not detected_lg and self._opts.languages: + detected_lg = self._opts.languages[0] + final_data = stt.SpeechData( language=detected_lg, confidence=1.0, text=evt.result.text ) - self._threadsafe_send( - stt.SpeechEvent( - type=stt.SpeechEventType.FINAL_TRANSCRIPT, alternatives=[final_data] + with contextlib.suppress(RuntimeError): + self._loop.call_soon_threadsafe( + self._event_ch.send_nowait, + stt.SpeechEvent( + type=stt.SpeechEventType.FINAL_TRANSCRIPT, alternatives=[final_data] + ), ) - ) def _on_recognizing(self, evt: speechsdk.SpeechRecognitionEventArgs): detected_lg = speechsdk.AutoDetectSourceLanguageResult(evt.result).language @@ -159,35 +251,55 @@ def _on_recognizing(self, evt: speechsdk.SpeechRecognitionEventArgs): if not text: return + if not detected_lg and self._opts.languages: + detected_lg = self._opts.languages[0] + interim_data = stt.SpeechData( language=detected_lg, confidence=0.0, text=evt.result.text ) - self._threadsafe_send( - stt.SpeechEvent( - type=stt.SpeechEventType.INTERIM_TRANSCRIPT, alternatives=[interim_data] + with contextlib.suppress(RuntimeError): + self._loop.call_soon_threadsafe( + self._event_ch.send_nowait, + stt.SpeechEvent( + type=stt.SpeechEventType.INTERIM_TRANSCRIPT, + alternatives=[interim_data], + ), ) - ) def _on_speech_start(self, evt: speechsdk.SpeechRecognitionEventArgs): if self._speaking: return self._speaking = True - self._threadsafe_send(stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)) + + with contextlib.suppress(RuntimeError): + self._loop.call_soon_threadsafe( + self._event_ch.send_nowait, + stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH), + ) def _on_speech_end(self, evt: speechsdk.SpeechRecognitionEventArgs): if not self._speaking: return self._speaking = False - self._threadsafe_send(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)) - def _on_session_stopped(self, evt: speechsdk.SpeechRecognitionEventArgs): - self._loop.call_soon_threadsafe(self._done_event.set) + with contextlib.suppress(RuntimeError): + self._loop.call_soon_threadsafe( + self._event_ch.send_nowait, + stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH), + ) + + def _on_session_started(self, evt: speechsdk.SpeechRecognitionEventArgs): + self._session_started_event.set() - def _threadsafe_send(self, evt: stt.SpeechEvent): - self._loop.call_soon_threadsafe(self._event_ch.send_nowait, evt) + with contextlib.suppress(RuntimeError): + self._loop.call_soon_threadsafe(self._session_started_event.set) + + def _on_session_stopped(self, evt: speechsdk.SpeechRecognitionEventArgs): + with contextlib.suppress(RuntimeError): + self._loop.call_soon_threadsafe(self._session_stopped_event.set) def _create_speech_recognizer( @@ -195,6 +307,10 @@ def _create_speech_recognizer( ) -> speechsdk.SpeechRecognizer: if config.speech_host: speech_config = speechsdk.SpeechConfig(host=config.speech_host) + if config.speech_auth_token: + speech_config = speechsdk.SpeechConfig( + 
auth_token=config.speech_auth_token, region=config.speech_region + ) else: speech_config = speechsdk.SpeechConfig( subscription=config.speech_key, region=config.speech_region @@ -217,7 +333,7 @@ def _create_speech_recognizer( ) auto_detect_source_language_config = None - if config.languages: + if config.languages and len(config.languages) >= 1: auto_detect_source_language_config = ( speechsdk.languageconfig.AutoDetectSourceLanguageConfig( languages=config.languages diff --git a/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/tts.py b/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/tts.py index 979274bd9..155d2c091 100644 --- a/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/tts.py +++ b/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/tts.py @@ -114,6 +114,8 @@ class _TTSOptions: voice: str | None = None # for using custom voices (see https://learn.microsoft.com/en-us/azure/ai-services/speech-service/how-to-speech-synthesis?tabs=browserjs%2Cterminal&pivots=programming-language-python#use-a-custom-endpoint) endpoint_id: str | None = None + # for using Microsoft Entra auth (see https://learn.microsoft.com/en-us/azure/ai-services/speech-service/how-to-configure-azure-ad-auth?tabs=portal&pivots=programming-language-python) + speech_auth_token: str | None = None # Useful to specify the language with multi-language voices language: str | None = None # See https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-synthesis-markup-voice#adjust-prosody @@ -131,13 +133,17 @@ def __init__( speech_key: str | None = None, speech_region: str | None = None, speech_host: str | None = None, + speech_auth_token: str | None = None, endpoint_id: str | None = None, ) -> None: """ Create a new instance of Azure TTS. - ``speech_key`` and ``speech_region`` must be set, either using arguments or by setting the - ``AZURE_SPEECH_KEY`` and ``AZURE_SPEECH_REGION`` environmental variables, respectively. + Either ``speech_host`` or ``speech_key`` and ``speech_region`` or + ``speech_auth_token`` and ``speech_region`` must be set using arguments. + Alternatively, set the ``AZURE_SPEECH_HOST``, ``AZURE_SPEECH_KEY`` + and ``AZURE_SPEECH_REGION`` environmental variables, respectively. + ``speech_auth_token`` must be set using the arguments as it's an ephemeral token. 
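With the additions above, both the Azure STT and TTS plugins can authenticate with a Microsoft Entra token plus a region instead of a subscription key, and the STT candidate languages can be narrowed at runtime. A hypothetical usage sketch; the token and region values are placeholders, and the parameter names follow the diff:

```python
# Illustrative only: Entra-token auth for the Azure plugins. The token is
# ephemeral, so the caller is responsible for obtaining and refreshing it
# (for example via azure-identity); the literal below is just a placeholder.
from livekit.plugins import azure

entra_token = "<entra-access-token>"  # placeholder

stt = azure.STT(
    speech_auth_token=entra_token,
    speech_region="westus2",        # placeholder region
    languages=["en-US", "es-ES"],   # candidate set for language auto-detection
)

tts = azure.TTS(
    speech_auth_token=entra_token,
    speech_region="westus2",
)

# narrow the STT languages later; open streams pick this up and reconnect
stt.update_options(languages=["en-US"])
```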
""" if sample_rate not in SUPPORTED_SAMPLE_RATE: @@ -157,9 +163,13 @@ def __init__( speech_key = speech_key or os.environ.get("AZURE_SPEECH_KEY") speech_region = speech_region or os.environ.get("AZURE_SPEECH_REGION") - if not speech_host and not (speech_key and speech_region): + if not ( + speech_host + or (speech_key and speech_region) + or (speech_auth_token and speech_region) + ): raise ValueError( - "AZURE_SPEECH_HOST or AZURE_SPEECH_KEY and AZURE_SPEECH_REGION must be set" + "AZURE_SPEECH_HOST or AZURE_SPEECH_KEY and AZURE_SPEECH_REGION or speech_auth_token and AZURE_SPEECH_REGION must be set" ) if prosody: @@ -169,6 +179,8 @@ def __init__( sample_rate=sample_rate, speech_key=speech_key, speech_region=speech_region, + speech_host=speech_host, + speech_auth_token=speech_auth_token, voice=voice, endpoint_id=endpoint_id, language=language, @@ -314,6 +326,12 @@ def _create_speech_synthesizer( ) -> speechsdk.SpeechSynthesizer: if config.speech_host: speech_config = speechsdk.SpeechConfig(host=config.speech_host) + if config.speech_auth_token: + speech_config = speechsdk.SpeechConfig( + auth_token=config.speech_auth_token, + region=config.speech_region, + speech_recognition_language=config.language or "en-US", + ) else: speech_config = speechsdk.SpeechConfig( subscription=config.speech_key, diff --git a/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/version.py b/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/version.py index 2b3617a7b..ec65e487a 100644 --- a/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/version.py +++ b/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.4.2" +__version__ = "0.5.2" diff --git a/livekit-plugins/livekit-plugins-azure/package.json b/livekit-plugins/livekit-plugins-azure/package.json index 0541e5a1c..45561032c 100644 --- a/livekit-plugins/livekit-plugins-azure/package.json +++ b/livekit-plugins/livekit-plugins-azure/package.json @@ -1,5 +1,5 @@ { "name": "livekit-plugins-azure", "private": true, - "version": "0.4.2" + "version": "0.5.2" } diff --git a/livekit-plugins/livekit-plugins-azure/setup.py b/livekit-plugins/livekit-plugins-azure/setup.py index 288de7187..e854fc492 100644 --- a/livekit-plugins/livekit-plugins-azure/setup.py +++ b/livekit-plugins/livekit-plugins-azure/setup.py @@ -46,7 +46,7 @@ packages=setuptools.find_namespace_packages(include=["livekit.*"]), python_requires=">=3.9.0", install_requires=[ - "livekit-agents>=0.11", + "livekit-agents>=0.12.3", "azure-cognitiveservices-speech>=1.41.0", ], package_data={}, diff --git a/livekit-plugins/livekit-plugins-browser/CHANGELOG.md b/livekit-plugins/livekit-plugins-browser/CHANGELOG.md index 24b32b191..498a259c3 100644 --- a/livekit-plugins/livekit-plugins-browser/CHANGELOG.md +++ b/livekit-plugins/livekit-plugins-browser/CHANGELOG.md @@ -1,5 +1,17 @@ # livekit-plugins-browser +## 0.0.5 + +### Patch Changes + +- fix: fix `imgui` setup - [#1226](https://github.com/livekit/agents/pull/1226) ([@mbukeRepo](https://github.com/mbukeRepo)) + +## 0.0.4 + +### Patch Changes + +- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom)) + ## 0.0.3 ### Patch Changes diff --git a/livekit-plugins/livekit-plugins-browser/livekit/plugins/browser/version.py b/livekit-plugins/livekit-plugins-browser/livekit/plugins/browser/version.py index 64214b2f9..0f8366140 100644 
--- a/livekit-plugins/livekit-plugins-browser/livekit/plugins/browser/version.py +++ b/livekit-plugins/livekit-plugins-browser/livekit/plugins/browser/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.3" +__version__ = "0.0.5" diff --git a/livekit-plugins/livekit-plugins-browser/package.json b/livekit-plugins/livekit-plugins-browser/package.json index 672cb89c0..f28e403c5 100644 --- a/livekit-plugins/livekit-plugins-browser/package.json +++ b/livekit-plugins/livekit-plugins-browser/package.json @@ -1,5 +1,5 @@ { "name": "livekit-plugins-browser", "private": true, - "version": "0.0.3" + "version": "0.0.5" } diff --git a/livekit-plugins/livekit-plugins-browser/setup.py b/livekit-plugins/livekit-plugins-browser/setup.py index 8eafd27d8..088259ebf 100644 --- a/livekit-plugins/livekit-plugins-browser/setup.py +++ b/livekit-plugins/livekit-plugins-browser/setup.py @@ -113,7 +113,7 @@ def build_extension(self, ext: CMakeExtension) -> None: cmdclass={"build_ext": CMakeBuild}, packages=setuptools.find_namespace_packages(include=["livekit.*"]), python_requires=">=3.9.0", - install_requires=["livekit-agents>=0.11"], + install_requires=["livekit-agents>=0.12.3"], package_data={ "livekit.plugins.browser": ["py.typed"], "livekit.plugins.browser.resources": ["**", "lkcef_app.app"], diff --git a/livekit-plugins/livekit-plugins-browser/src/CMakeLists.txt b/livekit-plugins/livekit-plugins-browser/src/CMakeLists.txt index 298ee3c37..f236519cb 100644 --- a/livekit-plugins/livekit-plugins-browser/src/CMakeLists.txt +++ b/livekit-plugins/livekit-plugins-browser/src/CMakeLists.txt @@ -11,8 +11,15 @@ set(GLFW_INSTALL OFF CACHE BOOL "" FORCE) FetchContent_Declare(glfw GIT_REPOSITORY https://github.com/glfw/glfw.git GIT_TAG 3.4) FetchContent_MakeAvailable(glfw) -FetchContent_Declare(imgui GIT_REPOSITORY https://github.com/ocornut/imgui GIT_TAG origin/docking) +FetchContent_Declare( + imgui + GIT_REPOSITORY https://github.com/ocornut/imgui + GIT_TAG origin/docking + GIT_SHALLOW TRUE +) FetchContent_GetProperties(imgui) +FetchContent_Populate(imgui) + FetchContent_MakeAvailable(imgui) file(GLOB IMGUI_SOURCES ${imgui_SOURCE_DIR}/*.cpp) add_library(imgui STATIC ${IMGUI_SOURCES} diff --git a/livekit-plugins/livekit-plugins-cartesia/CHANGELOG.md b/livekit-plugins/livekit-plugins-cartesia/CHANGELOG.md index 508aa80ac..c949ba7bb 100644 --- a/livekit-plugins/livekit-plugins-cartesia/CHANGELOG.md +++ b/livekit-plugins/livekit-plugins-cartesia/CHANGELOG.md @@ -1,5 +1,17 @@ # livekit-plugins-cartesia +## 0.4.5 + +### Patch Changes + +- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom)) + +## 0.4.4 + +### Patch Changes + +- feat: tts retry & tts.FallbackAdapter - [#1074](https://github.com/livekit/agents/pull/1074) ([@theomonnom](https://github.com/theomonnom)) + ## 0.4.3 ### Patch Changes diff --git a/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/tts.py b/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/tts.py index dd76473c7..eae3a0679 100644 --- a/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/tts.py +++ b/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/tts.py @@ -312,7 +312,10 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None: aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING, ): - raise Exception("Cartesia connection closed unexpectedly") + raise APIStatusError( + 
"Cartesia connection closed unexpectedly", + request_id=request_id, + ) if msg.type != aiohttp.WSMsgType.TEXT: logger.warning("unexpected Cartesia message type %s", msg.type) diff --git a/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/version.py b/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/version.py index 728ebaff3..6667b2426 100644 --- a/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/version.py +++ b/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.4.3" +__version__ = "0.4.5" diff --git a/livekit-plugins/livekit-plugins-cartesia/package.json b/livekit-plugins/livekit-plugins-cartesia/package.json index f292f6ff0..f87b43bd4 100644 --- a/livekit-plugins/livekit-plugins-cartesia/package.json +++ b/livekit-plugins/livekit-plugins-cartesia/package.json @@ -1,5 +1,5 @@ { "name": "livekit-plugins-cartesia", "private": true, - "version": "0.4.3" + "version": "0.4.5" } diff --git a/livekit-plugins/livekit-plugins-cartesia/setup.py b/livekit-plugins/livekit-plugins-cartesia/setup.py index e4ce007f9..8044f23c6 100644 --- a/livekit-plugins/livekit-plugins-cartesia/setup.py +++ b/livekit-plugins/livekit-plugins-cartesia/setup.py @@ -47,7 +47,7 @@ license="Apache-2.0", packages=setuptools.find_namespace_packages(include=["livekit.*"]), python_requires=">=3.9.0", - install_requires=["livekit-agents>=0.11"], + install_requires=["livekit-agents>=0.12.3"], project_urls={ "Documentation": "https://docs.livekit.io", "Website": "https://livekit.io/", diff --git a/livekit-plugins/livekit-plugins-clova/livekit/plugins/clova/stt.py b/livekit-plugins/livekit-plugins-clova/livekit/plugins/clova/stt.py index ef7367c77..a98222299 100644 --- a/livekit-plugins/livekit-plugins-clova/livekit/plugins/clova/stt.py +++ b/livekit-plugins/livekit-plugins-clova/livekit/plugins/clova/stt.py @@ -24,6 +24,7 @@ import aiohttp from livekit.agents import ( + APIConnectOptions, APIStatusError, APITimeoutError, stt, @@ -68,6 +69,11 @@ def __init__( ) self.threshold = threshold + def update_options(self, *, language: str | None = None) -> None: + self._language = ( + clova_languages_mapping.get(language, language) or self._language + ) + def _ensure_session(self) -> aiohttp.ClientSession: if not self._session: self._session = utils.http_context.http_session() @@ -80,9 +86,10 @@ def url_builder( async def _recognize_impl( self, - *, buffer: AudioBuffer, - language: Union[ClovaSttLanguages, str, None] = None, + *, + language: Union[ClovaSttLanguages, str, None], + conn_options: APIConnectOptions, ) -> stt.SpeechEvent: try: url = self.url_builder() @@ -109,7 +116,13 @@ async def _recognize_impl( ) start = time.time() async with self._ensure_session().post( - url, data=form_data, headers=headers + url, + data=form_data, + headers=headers, + timeout=aiohttp.ClientTimeout( + total=30, + sock_connect=conn_options.timeout, + ), ) as response: response_data = await response.json() end = time.time() diff --git a/livekit-plugins/livekit-plugins-clova/setup.py b/livekit-plugins/livekit-plugins-clova/setup.py index 254fd1cba..08abcf970 100644 --- a/livekit-plugins/livekit-plugins-clova/setup.py +++ b/livekit-plugins/livekit-plugins-clova/setup.py @@ -47,7 +47,7 @@ license="Apache-2.0", packages=setuptools.find_namespace_packages(include=["livekit.*"]), python_requires=">=3.9.0", - install_requires=["livekit-agents>=0.11", 
"pydub~=0.25.1"], + install_requires=["livekit-agents>=0.12.3", "pydub~=0.25.1"], project_urls={ "Documentation": "https://docs.livekit.io", "Website": "https://livekit.io/", diff --git a/livekit-plugins/livekit-plugins-deepgram/CHANGELOG.md b/livekit-plugins/livekit-plugins-deepgram/CHANGELOG.md index 19d0812a4..617d61f38 100644 --- a/livekit-plugins/livekit-plugins-deepgram/CHANGELOG.md +++ b/livekit-plugins/livekit-plugins-deepgram/CHANGELOG.md @@ -1,5 +1,39 @@ # livekit-plugins-deepgram +## 0.6.16 + +### Patch Changes + +- fix: Ensure STT exceptions are being propagated - [#1291](https://github.com/livekit/agents/pull/1291) ([@davidzhao](https://github.com/davidzhao)) + +## 0.6.15 + +### Patch Changes + +- added streaming audio decoder for compressed audio. - [#1236](https://github.com/livekit/agents/pull/1236) ([@davidzhao](https://github.com/davidzhao)) + +- Support Deepgram TTS - [#1201](https://github.com/livekit/agents/pull/1201) ([@jayeshp19](https://github.com/jayeshp19)) + +## 0.6.14 + +### Patch Changes + +- enable deepgram filler words by default to improve end of turn accuracy - [#1190](https://github.com/livekit/agents/pull/1190) ([@davidzhao](https://github.com/davidzhao)) + +## 0.6.13 + +### Patch Changes + +- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom)) + +## 0.6.12 + +### Patch Changes + +- feat: stt retry & stt.FallbackAdapter - [#1114](https://github.com/livekit/agents/pull/1114) ([@theomonnom](https://github.com/theomonnom)) + +- Added support for custom deepgram base url - [#1137](https://github.com/livekit/agents/pull/1137) ([@theomonnom](https://github.com/theomonnom)) + ## 0.6.11 ### Patch Changes diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/__init__.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/__init__.py index 6c93b5276..dcfbb04ad 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/__init__.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/__init__.py @@ -1,7 +1,8 @@ from .stt import STT, AudioEnergyFilter, SpeechStream +from .tts import TTS from .version import __version__ -__all__ = ["STT", "SpeechStream", "AudioEnergyFilter", "__version__"] +__all__ = ["STT", "SpeechStream", "AudioEnergyFilter", "__version__", "TTS"] from livekit.agents import Plugin diff --git a/livekit-agents/livekit/agents/metrics/periodic_collector.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/_utils.py similarity index 100% rename from livekit-agents/livekit/agents/metrics/periodic_collector.py rename to livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/_utils.py diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 34970d01c..d45966e4e 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -16,10 +16,9 @@ import asyncio import dataclasses -import io import json import os -import wave +import weakref from dataclasses import dataclass from enum import Enum from typing import List, Optional, Tuple @@ -29,20 +28,21 @@ import numpy as np from livekit import rtc from livekit.agents import ( + DEFAULT_API_CONNECT_OPTIONS, APIConnectionError, + APIConnectOptions, APIStatusError, APITimeoutError, - metrics, stt, utils, ) -from livekit.agents.utils import 
AudioBuffer, merge_frames +from livekit.agents.utils import AudioBuffer +from ._utils import PeriodicCollector from .log import logger from .models import DeepgramLanguages, DeepgramModels BASE_URL = "https://api.deepgram.com/v1/listen" -BASE_URL_WS = "wss://api.deepgram.com/v1/listen" # This is the magic number during testing that we use to determine if a frame is loud enough @@ -121,11 +121,13 @@ def __init__( sample_rate: int = 16000, no_delay: bool = True, endpointing_ms: int = 25, - filler_words: bool = False, + # enable filler words by default to improve turn detector accuracy + filler_words: bool = True, keywords: list[Tuple[str, float]] = [], profanity_filter: bool = False, api_key: str | None = None, http_session: aiohttp.ClientSession | None = None, + base_url: str = BASE_URL, energy_filter: AudioEnergyFilter | bool = False, ) -> None: """ @@ -140,26 +142,13 @@ def __init__( streaming=True, interim_results=interim_results ) ) + self._base_url = base_url api_key = api_key or os.environ.get("DEEPGRAM_API_KEY") if api_key is None: raise ValueError("Deepgram API key is required") - if language not in ("en-US", "en") and model in ( - "nova-2-meeting", - "nova-2-phonecall", - "nova-2-finance", - "nova-2-conversationalai", - "nova-2-voicemail", - "nova-2-video", - "nova-2-medical", - "nova-2-drivethru", - "nova-2-automotive", - ): - logger.warning( - f"{model} does not support language {language}, falling back to nova-2-general" - ) - model = "nova-2-general" + model = _validate_model(model, language) self._api_key = api_key @@ -180,6 +169,7 @@ def __init__( energy_filter=energy_filter, ) self._session = http_session + self._streams = weakref.WeakSet[SpeechStream]() def _ensure_session(self) -> aiohttp.ClientSession: if not self._session: @@ -188,7 +178,11 @@ def _ensure_session(self) -> aiohttp.ClientSession: return self._session async def _recognize_impl( - self, buffer: AudioBuffer, *, language: DeepgramLanguages | str | None = None + self, + buffer: AudioBuffer, + *, + language: DeepgramLanguages | str | None, + conn_options: APIConnectOptions, ) -> stt.SpeechEvent: config = self._sanitize_options(language=language) @@ -203,25 +197,19 @@ async def _recognize_impl( if config.language: recognize_config["language"] = config.language - buffer = merge_frames(buffer) - io_buffer = io.BytesIO() - with wave.open(io_buffer, "wb") as wav: - wav.setnchannels(buffer.num_channels) - wav.setsampwidth(2) # 16-bit - wav.setframerate(buffer.sample_rate) - wav.writeframes(buffer.data) - - data = io_buffer.getvalue() - try: async with self._ensure_session().post( - url=_to_deepgram_url(recognize_config), - data=data, + url=_to_deepgram_url(recognize_config, self._base_url, websocket=False), + data=rtc.combine_audio_frames(buffer).to_wav_bytes(), headers={ "Authorization": f"Token {self._api_key}", "Accept": "application/json", "Content-Type": "audio/wav", }, + timeout=aiohttp.ClientTimeout( + total=30, + sock_connect=conn_options.timeout, + ), ) as res: return prerecorded_transcription_to_speech_event( config.language, @@ -241,10 +229,75 @@ async def _recognize_impl( raise APIConnectionError() from e def stream( - self, *, language: DeepgramLanguages | str | None = None + self, + *, + language: DeepgramLanguages | str | None = None, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, ) -> "SpeechStream": config = self._sanitize_options(language=language) - return SpeechStream(self, config, self._api_key, self._ensure_session()) + stream = SpeechStream( + stt=self, + 
conn_options=conn_options, + opts=config, + api_key=self._api_key, + http_session=self._ensure_session(), + base_url=self._base_url, + ) + self._streams.add(stream) + return stream + + def update_options( + self, + *, + language: DeepgramLanguages | None = None, + model: DeepgramModels | None = None, + interim_results: bool | None = None, + punctuate: bool | None = None, + smart_format: bool | None = None, + sample_rate: int | None = None, + no_delay: bool | None = None, + endpointing_ms: int | None = None, + filler_words: bool | None = None, + keywords: list[Tuple[str, float]] | None = None, + profanity_filter: bool | None = None, + ): + if language is not None: + self._opts.language = language + if model is not None: + self._opts.model = _validate_model(model, language) + if interim_results is not None: + self._opts.interim_results = interim_results + if punctuate is not None: + self._opts.punctuate = punctuate + if smart_format is not None: + self._opts.smart_format = smart_format + if sample_rate is not None: + self._opts.sample_rate = sample_rate + if no_delay is not None: + self._opts.no_delay = no_delay + if endpointing_ms is not None: + self._opts.endpointing_ms = endpointing_ms + if filler_words is not None: + self._opts.filler_words = filler_words + if keywords is not None: + self._opts.keywords = keywords + if profanity_filter is not None: + self._opts.profanity_filter = profanity_filter + + for stream in self._streams: + stream.update_options( + language=language, + model=model, + interim_results=interim_results, + punctuate=punctuate, + smart_format=smart_format, + sample_rate=sample_rate, + no_delay=no_delay, + endpointing_ms=endpointing_ms, + filler_words=filler_words, + keywords=keywords, + profanity_filter=profanity_filter, + ) def _sanitize_options(self, *, language: str | None = None) -> STTOptions: config = dataclasses.replace(self._opts) @@ -263,13 +316,17 @@ class SpeechStream(stt.SpeechStream): def __init__( self, + *, stt: STT, opts: STTOptions, + conn_options: APIConnectOptions, api_key: str, http_session: aiohttp.ClientSession, - max_retry: int = 32, + base_url: str, ) -> None: - super().__init__(stt, sample_rate=opts.sample_rate) + super().__init__( + stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate + ) if opts.detect_language and opts.language is None: raise ValueError("language detection is not supported in streaming mode") @@ -277,9 +334,9 @@ def __init__( self._opts = opts self._api_key = api_key self._session = http_session + self._base_url = base_url self._speaking = False - self._max_retry = max_retry - self._audio_duration_collector = metrics.PeriodicCollector( + self._audio_duration_collector = PeriodicCollector( callback=self._on_audio_duration_report, duration=5.0, ) @@ -293,73 +350,52 @@ def __init__( self._pushed_audio_duration = 0.0 self._request_id = "" + self._reconnect_event = asyncio.Event() - @utils.log_exceptions(logger=logger) - async def _main_task(self) -> None: - await self._run(self._max_retry) - - async def _run(self, max_retry: int) -> None: - """ - Run a single websocket connection to Deepgram and make sure to reconnect - when something went wrong. 
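On the Deepgram STT side, the constructor now accepts a base_url, enables filler_words by default, and update_options() is forwarded to every live SpeechStream, which then reconnects with the new settings. A hedged usage sketch with placeholder values:

```python
# Illustrative sketch of the reworked Deepgram STT options; all values are
# placeholders, not recommendations.
from livekit.plugins import deepgram

stt = deepgram.STT(
    model="nova-2-general",
    language="en-US",
    filler_words=True,  # now the default, to help end-of-turn detection
    base_url="https://api.deepgram.com/v1/listen",  # or a self-hosted gateway
)

stream = stt.stream()

# change settings mid-session; the open stream reconnects via its
# internal reconnect event instead of being torn down by the caller
stt.update_options(language="es", model="nova-2-general")
```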
- """ - - retry_count = 0 - while self._input_ch.qsize() or not self._input_ch.closed: - try: - live_config = { - "model": self._opts.model, - "punctuate": self._opts.punctuate, - "smart_format": self._opts.smart_format, - "no_delay": self._opts.no_delay, - "interim_results": self._opts.interim_results, - "encoding": "linear16", - "vad_events": True, - "sample_rate": self._opts.sample_rate, - "channels": self._opts.num_channels, - "endpointing": False - if self._opts.endpointing_ms == 0 - else self._opts.endpointing_ms, - "filler_words": self._opts.filler_words, - "keywords": self._opts.keywords, - "profanity_filter": self._opts.profanity_filter, - } - - if self._opts.language: - live_config["language"] = self._opts.language - - headers = {"Authorization": f"Token {self._api_key}"} - ws = await self._session.ws_connect( - _to_deepgram_url(live_config, websocket=True), headers=headers - ) - retry_count = 0 # connected successfully, reset the retry_count - - await self._run_ws(ws) - except Exception as e: - if self._session.closed: - break - - if retry_count >= max_retry: - logger.exception( - f"failed to connect to deepgram after {max_retry} tries" - ) - break - - retry_delay = min(retry_count * 2, 10) # max 10s - retry_count += 1 # increment after calculating the delay, the first retry should happen directly - - logger.warning( - f"deepgram connection failed, retrying in {retry_delay}s", - exc_info=e, - ) - await asyncio.sleep(retry_delay) - - async def _run_ws(self, ws: aiohttp.ClientWebSocketResponse) -> None: - """This method could throw ws errors, these are handled inside the _run method""" - + def update_options( + self, + *, + language: DeepgramLanguages | None = None, + model: DeepgramModels | None = None, + interim_results: bool | None = None, + punctuate: bool | None = None, + smart_format: bool | None = None, + sample_rate: int | None = None, + no_delay: bool | None = None, + endpointing_ms: int | None = None, + filler_words: bool | None = None, + keywords: list[Tuple[str, float]] | None = None, + profanity_filter: bool | None = None, + ): + if language is not None: + self._opts.language = language + if model is not None: + self._opts.model = _validate_model(model, language) + if interim_results is not None: + self._opts.interim_results = interim_results + if punctuate is not None: + self._opts.punctuate = punctuate + if smart_format is not None: + self._opts.smart_format = smart_format + if sample_rate is not None: + self._opts.sample_rate = sample_rate + if no_delay is not None: + self._opts.no_delay = no_delay + if endpointing_ms is not None: + self._opts.endpointing_ms = endpointing_ms + if filler_words is not None: + self._opts.filler_words = filler_words + if keywords is not None: + self._opts.keywords = keywords + if profanity_filter is not None: + self._opts.profanity_filter = profanity_filter + + self._reconnect_event.set() + + async def _run(self) -> None: closing_ws = False - async def keepalive_task(): + async def keepalive_task(ws: aiohttp.ClientWebSocketResponse): # if we want to keep the connection alive even if no audio is sent, # Deepgram expects a keepalive message. 
# https://developers.deepgram.com/reference/listen-live#stream-keepalive @@ -370,7 +406,7 @@ async def keepalive_task(): except Exception: return - async def send_task(): + async def send_task(ws: aiohttp.ClientWebSocketResponse): nonlocal closing_ws # forward audio to deepgram in chunks of 50ms @@ -422,7 +458,7 @@ async def send_task(): closing_ws = True await ws.send_str(SpeechStream._CLOSE_MSG) - async def recv_task(): + async def recv_task(ws: aiohttp.ClientWebSocketResponse): nonlocal closing_ws while True: msg = await ws.receive() @@ -435,7 +471,9 @@ async def recv_task(): return # this will trigger a reconnection, see the _run loop - raise Exception("deepgram connection closed unexpectedly") + raise APIStatusError( + message="deepgram connection closed unexpectedly" + ) if msg.type != aiohttp.WSMsgType.TEXT: logger.warning("unexpected deepgram message type %s", msg.type) @@ -446,16 +484,68 @@ async def recv_task(): except Exception: logger.exception("failed to process deepgram message") - tasks = [ - asyncio.create_task(send_task()), - asyncio.create_task(recv_task()), - asyncio.create_task(keepalive_task()), - ] + ws: aiohttp.ClientWebSocketResponse | None = None - try: - await asyncio.gather(*tasks) - finally: - await utils.aio.gracefully_cancel(*tasks) + while True: + try: + ws = await self._connect_ws() + tasks = [ + asyncio.create_task(send_task(ws)), + asyncio.create_task(recv_task(ws)), + asyncio.create_task(keepalive_task(ws)), + ] + wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait()) + try: + done, _ = await asyncio.wait( + [asyncio.gather(*tasks), wait_reconnect_task], + return_when=asyncio.FIRST_COMPLETED, + ) # type: ignore + + # propagate exceptions from completed tasks + for task in done: + if task != wait_reconnect_task: + task.result() + + if wait_reconnect_task not in done: + break + + self._reconnect_event.clear() + finally: + await utils.aio.gracefully_cancel(*tasks, wait_reconnect_task) + finally: + if ws is not None: + await ws.close() + + async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse: + live_config = { + "model": self._opts.model, + "punctuate": self._opts.punctuate, + "smart_format": self._opts.smart_format, + "no_delay": self._opts.no_delay, + "interim_results": self._opts.interim_results, + "encoding": "linear16", + "vad_events": True, + "sample_rate": self._opts.sample_rate, + "channels": self._opts.num_channels, + "endpointing": False + if self._opts.endpointing_ms == 0 + else self._opts.endpointing_ms, + "filler_words": self._opts.filler_words, + "keywords": self._opts.keywords, + "profanity_filter": self._opts.profanity_filter, + } + + if self._opts.language: + live_config["language"] = self._opts.language + + ws = await asyncio.wait_for( + self._session.ws_connect( + _to_deepgram_url(live_config, base_url=self._base_url, websocket=True), + headers={"Authorization": f"Token {self._api_key}"}, + ), + self._conn_options.timeout, + ) + return ws def _check_energy_state(self, frame: rtc.AudioFrame) -> AudioEnergyFilter.State: if self._audio_energy_filter: @@ -584,7 +674,7 @@ def prerecorded_transcription_to_speech_event( ) -def _to_deepgram_url(opts: dict, *, websocket: bool = False) -> str: +def _to_deepgram_url(opts: dict, base_url: str, *, websocket: bool) -> str: if opts.get("keywords"): # convert keywords to a list of "keyword:intensifier" opts["keywords"] = [ @@ -593,5 +683,33 @@ def _to_deepgram_url(opts: dict, *, websocket: bool = False) -> str: # lowercase bools opts = {k: str(v).lower() if isinstance(v, bool) 
else v for k, v in opts.items()} - base_url = BASE_URL_WS if websocket else BASE_URL + + if websocket and base_url.startswith("http"): + base_url = base_url.replace("http", "ws", 1) + + elif not websocket and base_url.startswith("ws"): + base_url = base_url.replace("ws", "http", 1) + return f"{base_url}?{urlencode(opts, doseq=True)}" + + +def _validate_model( + model: DeepgramModels, language: DeepgramLanguages | str | None +) -> DeepgramModels: + en_only_models = { + "nova-2-meeting", + "nova-2-phonecall", + "nova-2-finance", + "nova-2-conversationalai", + "nova-2-voicemail", + "nova-2-video", + "nova-2-medical", + "nova-2-drivethru", + "nova-2-automotive", + } + if language not in ("en-US", "en") and model in en_only_models: + logger.warning( + f"{model} does not support language {language}, falling back to nova-2-general" + ) + return "nova-2-general" + return model diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/tts.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/tts.py new file mode 100644 index 000000000..401c26be7 --- /dev/null +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/tts.py @@ -0,0 +1,421 @@ +from __future__ import annotations + +import asyncio +import json +import os +import weakref +from dataclasses import dataclass +from urllib.parse import urlencode + +import aiohttp +from livekit import rtc +from livekit.agents import ( + DEFAULT_API_CONNECT_OPTIONS, + APIConnectionError, + APIConnectOptions, + APIStatusError, + APITimeoutError, + tokenize, + tts, + utils, +) + +from .log import logger + +BASE_URL = "https://api.deepgram.com/v1/speak" +NUM_CHANNELS = 1 + + +@dataclass +class _TTSOptions: + model: str + encoding: str + sample_rate: int + word_tokenizer: tokenize.WordTokenizer + + +class TTS(tts.TTS): + def __init__( + self, + *, + model: str = "aura-asteria-en", + encoding: str = "linear16", + sample_rate: int = 24000, + api_key: str | None = None, + base_url: str = BASE_URL, + word_tokenizer: tokenize.WordTokenizer = tokenize.basic.WordTokenizer( + ignore_punctuation=False + ), + http_session: aiohttp.ClientSession | None = None, + ) -> None: + """ + Create a new instance of Deepgram TTS. + + Args: + model (str): TTS model to use. Defaults to "aura-asteria-en". + encoding (str): Audio encoding to use. Defaults to "linear16". + sample_rate (int): Sample rate of audio. Defaults to 24000. + api_key (str): Deepgram API key. If not provided, will look for DEEPGRAM_API_KEY in environment. + base_url (str): Base URL for Deepgram TTS API. Defaults to "https://api.deepgram.com/v1/speak" + word_tokenizer (tokenize.WordTokenizer): Tokenizer for processing text. Defaults to basic WordTokenizer. + http_session (aiohttp.ClientSession): Optional aiohttp session to use for requests. + + """ + super().__init__( + capabilities=tts.TTSCapabilities(streaming=True), + sample_rate=sample_rate, + num_channels=NUM_CHANNELS, + ) + + api_key = api_key or os.environ.get("DEEPGRAM_API_KEY") + if not api_key: + raise ValueError( + "Deepgram API key required. Set DEEPGRAM_API_KEY or provide api_key." 
+ ) + + self._opts = _TTSOptions( + model=model, + encoding=encoding, + sample_rate=sample_rate, + word_tokenizer=word_tokenizer, + ) + self._session = http_session + self._api_key = api_key + self._base_url = base_url + self._streams = weakref.WeakSet[SynthesizeStream]() + + def _ensure_session(self) -> aiohttp.ClientSession: + if not self._session: + self._session = utils.http_context.http_session() + return self._session + + def update_options( + self, + *, + model: str | None = None, + sample_rate: int | None = None, + ) -> None: + """ + args: + model (str): TTS model to use. + sample_rate (int): Sample rate of audio. + """ + if model is not None: + self._opts.model = model + if sample_rate is not None: + self._opts.sample_rate = sample_rate + for stream in self._streams: + stream.update_options( + model=model, + sample_rate=sample_rate, + ) + + def synthesize( + self, + text: str, + *, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + ) -> "ChunkedStream": + return ChunkedStream( + tts=self, + input_text=text, + base_url=self._base_url, + api_key=self._api_key, + conn_options=conn_options, + opts=self._opts, + session=self._ensure_session(), + ) + + def stream( + self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS + ) -> "SynthesizeStream": + stream = SynthesizeStream( + tts=self, + conn_options=conn_options, + base_url=self._base_url, + api_key=self._api_key, + opts=self._opts, + session=self._ensure_session(), + ) + self._streams.add(stream) + return stream + + +class ChunkedStream(tts.ChunkedStream): + def __init__( + self, + *, + tts: TTS, + base_url: str, + api_key: str, + input_text: str, + opts: _TTSOptions, + conn_options: APIConnectOptions, + session: aiohttp.ClientSession, + ) -> None: + super().__init__(tts=tts, input_text=input_text, conn_options=conn_options) + self._opts = opts + self._session = session + self._base_url = base_url + self._api_key = api_key + + async def _run(self) -> None: + request_id = utils.shortuuid() + audio_bstream = utils.audio.AudioByteStream( + sample_rate=self._opts.sample_rate, + num_channels=NUM_CHANNELS, + ) + + try: + config = { + "encoding": self._opts.encoding, + "model": self._opts.model, + "sample_rate": self._opts.sample_rate, + } + async with self._session.post( + _to_deepgram_url(config, self._base_url, websocket=False), + headers={ + "Authorization": f"Token {self._api_key}", + "Content-Type": "application/json", + }, + json={"text": self._input_text}, + timeout=self._conn_options.timeout, + ) as res: + if res.status != 200: + raise APIStatusError( + message=res.reason or "Unknown error occurred.", + status_code=res.status, + request_id=request_id, + body=await res.json(), + ) + + async for bytes_data, _ in res.content.iter_chunks(): + for frame in audio_bstream.write(bytes_data): + self._event_ch.send_nowait( + tts.SynthesizedAudio( + request_id=request_id, + frame=frame, + ) + ) + + for frame in audio_bstream.flush(): + self._event_ch.send_nowait( + tts.SynthesizedAudio(request_id=request_id, frame=frame) + ) + + except asyncio.TimeoutError as e: + raise APITimeoutError() from e + except aiohttp.ClientResponseError as e: + raise APIStatusError( + message=e.message, + status_code=e.status, + request_id=request_id, + body=None, + ) from e + except Exception as e: + raise APIConnectionError() from e + + +class SynthesizeStream(tts.SynthesizeStream): + def __init__( + self, + *, + tts: TTS, + base_url: str, + api_key: str, + conn_options: APIConnectOptions, + opts: _TTSOptions, + session: 
aiohttp.ClientSession, + ): + super().__init__(tts=tts, conn_options=conn_options) + self._opts = opts + self._session = session + self._base_url = base_url + self._api_key = api_key + self._segments_ch = utils.aio.Chan[tokenize.WordStream]() + self._reconnect_event = asyncio.Event() + + def update_options( + self, + *, + model: str | None = None, + sample_rate: int | None = None, + ) -> None: + if model is not None: + self._opts.model = model + if sample_rate is not None: + self._opts.sample_rate = sample_rate + + self._reconnect_event.set() + + async def _run(self) -> None: + closing_ws = False + request_id = utils.shortuuid() + segment_id = utils.shortuuid() + audio_bstream = utils.audio.AudioByteStream( + sample_rate=self._opts.sample_rate, + num_channels=NUM_CHANNELS, + ) + + @utils.log_exceptions(logger=logger) + async def _tokenize_input(): + # Converts incoming text into WordStreams and sends them into _segments_ch + word_stream = None + async for input in self._input_ch: + if isinstance(input, str): + if word_stream is None: + word_stream = self._opts.word_tokenizer.stream() + self._segments_ch.send_nowait(word_stream) + word_stream.push_text(input) + elif isinstance(input, self._FlushSentinel): + if word_stream: + word_stream.end_input() + word_stream = None + self._segments_ch.close() + + @utils.log_exceptions(logger=logger) + async def _run_segments(ws: aiohttp.ClientWebSocketResponse): + nonlocal closing_ws + async for word_stream in self._segments_ch: + async for word in word_stream: + speak_msg = {"type": "Speak", "text": f"{word.token} "} + await ws.send_str(json.dumps(speak_msg)) + + # Always flush after a segment + flush_msg = {"type": "Flush"} + await ws.send_str(json.dumps(flush_msg)) + + # after all segments, close + close_msg = {"type": "Close"} + closing_ws = True + await ws.send_str(json.dumps(close_msg)) + + async def recv_task(ws: aiohttp.ClientWebSocketResponse): + last_frame: rtc.AudioFrame | None = None + + def _send_last_frame(*, segment_id: str, is_final: bool) -> None: + nonlocal last_frame + if last_frame is not None: + self._event_ch.send_nowait( + tts.SynthesizedAudio( + request_id=request_id, + segment_id=segment_id, + frame=last_frame, + is_final=is_final, + ) + ) + last_frame = None + + while True: + msg = await ws.receive() + if msg.type in ( + aiohttp.WSMsgType.CLOSE, + aiohttp.WSMsgType.CLOSED, + aiohttp.WSMsgType.CLOSING, + ): + if not closing_ws: + raise APIStatusError( + "Deepgram websocket connection closed unexpectedly", + request_id=request_id, + ) + return + + if msg.type == aiohttp.WSMsgType.BINARY: + data = msg.data + for frame in audio_bstream.write(data): + _send_last_frame(segment_id=segment_id, is_final=False) + last_frame = frame + elif msg.type == aiohttp.WSMsgType.TEXT: + resp = json.loads(msg.data) + mtype = resp.get("type") + if mtype == "Flushed": + for frame in audio_bstream.flush(): + _send_last_frame(segment_id=segment_id, is_final=False) + last_frame = frame + _send_last_frame(segment_id=segment_id, is_final=True) + elif mtype == "Warning": + logger.warning("Deepgram warning: %s", resp.get("warn_msg")) + elif mtype == "Metadata": + pass + else: + logger.debug("Unknown message type: %s", resp) + + async def _connection_timeout(): + # Deepgram has a 60-minute timeout period for websocket connections + await asyncio.sleep(3300) + logger.warning( + "Deepgram TTS maximum connection time reached. Reconnecting..." 
+ ) + self._reconnect_event.set() + + ws: aiohttp.ClientWebSocketResponse | None = None + while True: + try: + config = { + "encoding": self._opts.encoding, + "model": self._opts.model, + "sample_rate": self._opts.sample_rate, + } + ws = await asyncio.wait_for( + self._session.ws_connect( + _to_deepgram_url(config, self._base_url, websocket=True), + headers={"Authorization": f"Token {self._api_key}"}, + ), + self._conn_options.timeout, + ) + closing_ws = False + + tasks = [ + asyncio.create_task(_tokenize_input()), + asyncio.create_task(_run_segments(ws)), + asyncio.create_task(recv_task(ws)), + ] + wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait()) + connection_timeout_task = asyncio.create_task(_connection_timeout()) + + try: + done, _ = await asyncio.wait( + [ + asyncio.gather(*tasks), + wait_reconnect_task, + connection_timeout_task, + ], + return_when=asyncio.FIRST_COMPLETED, + ) # type: ignore + if wait_reconnect_task not in done: + break + self._reconnect_event.clear() + finally: + await utils.aio.gracefully_cancel( + *tasks, wait_reconnect_task, connection_timeout_task + ) + + except asyncio.TimeoutError as e: + raise APITimeoutError() from e + except aiohttp.ClientResponseError as e: + raise APIStatusError( + message=e.message, + status_code=e.status, + request_id=request_id, + body=None, + ) from e + except Exception as e: + raise APIConnectionError() from e + finally: + if ws is not None and not ws.closed: + await ws.close() + + +def _to_deepgram_url( + opts: dict, + base_url: str, + *, + websocket: bool, +) -> str: + if websocket and base_url.startswith("http"): + base_url = base_url.replace("http", "ws", 1) + + elif not websocket and base_url.startswith("ws"): + base_url = base_url.replace("ws", "http", 1) + + return f"{base_url}?{urlencode(opts, doseq=True)}" diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/version.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/version.py index a86319c6e..e1df9b637 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/version.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
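The new deepgram/tts.py above supports one-shot synthesis over HTTP (ChunkedStream) as well as a websocket streaming mode that speaks word segments and flushes them. A minimal sketch of the chunked path, assuming DEEPGRAM_API_KEY is exported; the model names and frame handling are illustrative:

```python
# Illustrative usage of the new Deepgram TTS plugin. Assumes DEEPGRAM_API_KEY
# is set in the environment; the model name and frame handling are examples.
import asyncio

import aiohttp
from livekit.plugins import deepgram


async def main() -> None:
    async with aiohttp.ClientSession() as session:
        tts = deepgram.TTS(
            model="aura-asteria-en",
            sample_rate=24000,
            http_session=session,  # otherwise the plugin falls back to the job's session
        )

        # ChunkedStream is an async iterator of SynthesizedAudio events
        async for audio in tts.synthesize("Hello from the Deepgram TTS plugin"):
            print(audio.request_id, audio.frame.samples_per_channel)

        # options can be swapped later; open SynthesizeStreams reconnect
        tts.update_options(model="aura-luna-en")


if __name__ == "__main__":
    asyncio.run(main())
```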
-__version__ = "0.6.11" +__version__ = "0.6.16" diff --git a/livekit-plugins/livekit-plugins-deepgram/package.json b/livekit-plugins/livekit-plugins-deepgram/package.json index dfdd57a9a..3a0a81159 100644 --- a/livekit-plugins/livekit-plugins-deepgram/package.json +++ b/livekit-plugins/livekit-plugins-deepgram/package.json @@ -1,5 +1,5 @@ { "name": "livekit-plugins-deepgram", "private": true, - "version": "0.6.11" + "version": "0.6.16" } diff --git a/livekit-plugins/livekit-plugins-deepgram/setup.py b/livekit-plugins/livekit-plugins-deepgram/setup.py index 077c6d659..b9316b839 100644 --- a/livekit-plugins/livekit-plugins-deepgram/setup.py +++ b/livekit-plugins/livekit-plugins-deepgram/setup.py @@ -47,7 +47,7 @@ license="Apache-2.0", packages=setuptools.find_namespace_packages(include=["livekit.*"]), python_requires=">=3.9.0", - install_requires=["livekit-agents>=0.11.3", "numpy~=1.21"], + install_requires=["livekit-agents>=0.12.3", "numpy>=1.26"], package_data={"livekit.plugins.deepgram": ["py.typed"]}, project_urls={ "Documentation": "https://docs.livekit.io", diff --git a/livekit-plugins/livekit-plugins-elevenlabs/CHANGELOG.md b/livekit-plugins/livekit-plugins-elevenlabs/CHANGELOG.md index cdadc5f20..a9ce20173 100644 --- a/livekit-plugins/livekit-plugins-elevenlabs/CHANGELOG.md +++ b/livekit-plugins/livekit-plugins-elevenlabs/CHANGELOG.md @@ -1,5 +1,17 @@ # livekit-plugins-elevenlabs +## 0.7.9 + +### Patch Changes + +- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom)) + +## 0.7.8 + +### Patch Changes + +- feat: tts retry & tts.FallbackAdapter - [#1074](https://github.com/livekit/agents/pull/1074) ([@theomonnom](https://github.com/theomonnom)) + ## 0.7.7 ### Patch Changes diff --git a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/tts.py b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/tts.py index 0c5490707..948d42758 100644 --- a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/tts.py +++ b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/tts.py @@ -469,8 +469,9 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None: aiohttp.WSMsgType.CLOSING, ): if not eos_sent: - raise Exception( - "11labs connection closed unexpectedly, not all tokens have been consumed" + raise APIStatusError( + "11labs connection closed unexpectedly, not all tokens have been consumed", + request_id=request_id, ) return diff --git a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/version.py b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/version.py index 32297f2b5..632574328 100644 --- a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/version.py +++ b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "0.7.7" +__version__ = "0.7.9" diff --git a/livekit-plugins/livekit-plugins-elevenlabs/package.json b/livekit-plugins/livekit-plugins-elevenlabs/package.json index 844815f53..272386ca0 100644 --- a/livekit-plugins/livekit-plugins-elevenlabs/package.json +++ b/livekit-plugins/livekit-plugins-elevenlabs/package.json @@ -1,5 +1,5 @@ { "name": "livekit-plugins-elevenlabs", "private": true, - "version": "0.7.7" + "version": "0.7.9" } diff --git a/livekit-plugins/livekit-plugins-elevenlabs/setup.py b/livekit-plugins/livekit-plugins-elevenlabs/setup.py index ba5400e84..829739fe2 100644 --- a/livekit-plugins/livekit-plugins-elevenlabs/setup.py +++ b/livekit-plugins/livekit-plugins-elevenlabs/setup.py @@ -49,7 +49,7 @@ license="Apache-2.0", packages=setuptools.find_namespace_packages(include=["livekit.*"]), python_requires=">=3.9.0", - install_requires=["livekit-agents[codecs]>=0.11"], + install_requires=["livekit-agents[codecs]>=0.12.3"], package_data={"livekit.plugins.elevenlabs": ["py.typed"]}, project_urls={ "Documentation": "https://docs.livekit.io", diff --git a/livekit-plugins/livekit-plugins-fal/CHANGELOG.md b/livekit-plugins/livekit-plugins-fal/CHANGELOG.md index 1703b449b..d0b2ba536 100644 --- a/livekit-plugins/livekit-plugins-fal/CHANGELOG.md +++ b/livekit-plugins/livekit-plugins-fal/CHANGELOG.md @@ -1,5 +1,17 @@ # livekit-plugins-fal +## 0.2.2 + +### Patch Changes + +- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom)) + +## 0.2.1 + +### Patch Changes + +- feat: stt retry & stt.FallbackAdapter - [#1114](https://github.com/livekit/agents/pull/1114) ([@theomonnom](https://github.com/theomonnom)) + ## 0.2.0 ### Minor Changes diff --git a/livekit-plugins/livekit-plugins-fal/livekit/plugins/fal/stt.py b/livekit-plugins/livekit-plugins-fal/livekit/plugins/fal/stt.py index fc275ed21..cca983ef8 100644 --- a/livekit-plugins/livekit-plugins-fal/livekit/plugins/fal/stt.py +++ b/livekit-plugins/livekit-plugins-fal/livekit/plugins/fal/stt.py @@ -1,16 +1,19 @@ +from __future__ import annotations + import dataclasses import os from dataclasses import dataclass from typing import Optional import fal_client +from livekit import rtc from livekit.agents import ( APIConnectionError, + APIConnectOptions, stt, ) from livekit.agents.stt import SpeechEventType, STTCapabilities -from livekit.agents.utils import AudioBuffer, merge_frames -from livekit.rtc import AudioFrame +from livekit.agents.utils import AudioBuffer @dataclass @@ -47,6 +50,9 @@ def __init__( "FAL AI API key is required. 
It should be set with env FAL_KEY" ) + def update_options(self, *, language: Optional[str] = None) -> None: + self._opts.language = language or self._opts.language + def _sanitize_options( self, *, @@ -66,18 +72,14 @@ async def _recognize_impl( self, buffer: AudioBuffer, *, - language: Optional[str] = None, - task: Optional[str] = None, - chunk_level: Optional[str] = None, - version: Optional[str] = None, + language: str | None, + conn_options: APIConnectOptions, ) -> stt.SpeechEvent: try: - config = self._sanitize_options( - language=language, task=task, chunk_level=chunk_level, version=version + config = self._sanitize_options(language=language) + data_uri = fal_client.encode( + rtc.combine_audio_frames(buffer).to_wav_bytes(), "audio/x-wav" ) - buffer = merge_frames(buffer) - wav_bytes = AudioFrame.to_wav_bytes(buffer) - data_uri = fal_client.encode(wav_bytes, "audio/x-wav") response = await self._fal_client.run( "fal-ai/wizper", arguments={ @@ -87,6 +89,7 @@ async def _recognize_impl( "chunk_level": config.chunk_level, "version": config.version, }, + timeout=conn_options.timeout, ) text = response.get("text", "") return self._transcription_to_speech_event(text=text) diff --git a/livekit-plugins/livekit-plugins-fal/livekit/plugins/fal/version.py b/livekit-plugins/livekit-plugins-fal/livekit/plugins/fal/version.py index eaa4231b0..1b935518b 100644 --- a/livekit-plugins/livekit-plugins-fal/livekit/plugins/fal/version.py +++ b/livekit-plugins/livekit-plugins-fal/livekit/plugins/fal/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.2.0" +__version__ = "0.2.2" diff --git a/livekit-plugins/livekit-plugins-fal/package.json b/livekit-plugins/livekit-plugins-fal/package.json index 006e4b94f..0e05e63ec 100644 --- a/livekit-plugins/livekit-plugins-fal/package.json +++ b/livekit-plugins/livekit-plugins-fal/package.json @@ -1,5 +1,5 @@ { "name": "livekit-plugins-fal", "private": true, - "version": "0.2.0" + "version": "0.2.2" } diff --git a/livekit-plugins/livekit-plugins-fal/setup.py b/livekit-plugins/livekit-plugins-fal/setup.py index 014251d0c..760607daf 100644 --- a/livekit-plugins/livekit-plugins-fal/setup.py +++ b/livekit-plugins/livekit-plugins-fal/setup.py @@ -47,7 +47,7 @@ license="Apache-2.0", packages=setuptools.find_namespace_packages(include=["livekit.*"]), python_requires=">=3.9.0", - install_requires=["livekit-agents>=0.11", "fal_client"], + install_requires=["livekit-agents>=0.12.3", "fal_client"], package_data={"livekit.plugins.fal": ["py.typed"]}, project_urls={ "Documentation": "https://docs.livekit.io", diff --git a/livekit-plugins/livekit-plugins-google/CHANGELOG.md b/livekit-plugins/livekit-plugins-google/CHANGELOG.md index dac20edd9..8867829ea 100644 --- a/livekit-plugins/livekit-plugins-google/CHANGELOG.md +++ b/livekit-plugins/livekit-plugins-google/CHANGELOG.md @@ -1,5 +1,37 @@ # livekit-plugins-google +## 0.9.0 + +### Minor Changes + +- make multimodal class generic and support gemini live api - [#1240](https://github.com/livekit/agents/pull/1240) ([@jayeshp19](https://github.com/jayeshp19)) + +### Patch Changes + +- fix: Ensure STT exceptions are being propagated - [#1291](https://github.com/livekit/agents/pull/1291) ([@davidzhao](https://github.com/davidzhao)) + +## 0.8.1 + +### Patch Changes + +- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom)) + +## 0.8.0 + +### Minor Changes + +- Add support for google STT 
chirp_2 model. - [#1089](https://github.com/livekit/agents/pull/1089) ([@brightsparc](https://github.com/brightsparc)) + +### Patch Changes + +- feat: stt retry & stt.FallbackAdapter - [#1114](https://github.com/livekit/agents/pull/1114) ([@theomonnom](https://github.com/theomonnom)) + +- fix: add retry logic for google stt abort exception - [#1100](https://github.com/livekit/agents/pull/1100) ([@jayeshp19](https://github.com/jayeshp19)) + +- feat: tts retry & tts.FallbackAdapter - [#1074](https://github.com/livekit/agents/pull/1074) ([@theomonnom](https://github.com/theomonnom)) + +- google STT - use the baseclass resampler - [#1106](https://github.com/livekit/agents/pull/1106) ([@jayeshp19](https://github.com/jayeshp19)) + ## 0.7.3 ### Patch Changes diff --git a/livekit-plugins/livekit-plugins-google/README.md b/livekit-plugins/livekit-plugins-google/README.md index b0fffb41e..383fe1a62 100644 --- a/livekit-plugins/livekit-plugins-google/README.md +++ b/livekit-plugins/livekit-plugins-google/README.md @@ -11,3 +11,8 @@ pip install livekit-plugins-google ## Pre-requisites For credentials, you'll need a Google Cloud account and obtain the correct credentials. Credentials can be passed directly or via Application Default Credentials as specified in [How Application Default Credentials works](https://cloud.google.com/docs/authentication/application-default-credentials). + +To use the STT and TTS API, you'll need to enable the respective services for your Google Cloud project. + +- Cloud Speech-to-Text API +- Cloud Text-to-Speech API diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/__init__.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/__init__.py index ca754bd30..88e163634 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/__init__.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/__init__.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from . import beta from .stt import STT, SpeechStream from .tts import TTS from .version import __version__ -__all__ = ["STT", "TTS", "SpeechStream", "__version__"] - +__all__ = ["STT", "TTS", "SpeechStream", "__version__", "beta"] from livekit.agents import Plugin from .log import logger diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/__init__.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/__init__.py new file mode 100644 index 000000000..89cb122c8 --- /dev/null +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/__init__.py @@ -0,0 +1,3 @@ +from . 
import realtime + +__all__ = ["realtime"] diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/__init__.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/__init__.py new file mode 100644 index 000000000..e95a86917 --- /dev/null +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/__init__.py @@ -0,0 +1,15 @@ +from .api_proto import ( + ClientEvents, + LiveAPIModels, + ResponseModality, + Voice, +) +from .realtime_api import RealtimeModel + +__all__ = [ + "RealtimeModel", + "ClientEvents", + "LiveAPIModels", + "ResponseModality", + "Voice", +] diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/api_proto.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/api_proto.py new file mode 100644 index 000000000..c02fb3859 --- /dev/null +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/api_proto.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import inspect +from typing import Any, Dict, List, Literal, Sequence, Union + +from google.genai import types # type: ignore + +LiveAPIModels = Literal["gemini-2.0-flash-exp"] + +Voice = Literal["Puck", "Charon", "Kore", "Fenrir", "Aoede"] +ResponseModality = Literal["AUDIO", "TEXT"] + + +ClientEvents = Union[ + types.ContentListUnion, + types.ContentListUnionDict, + types.LiveClientContentOrDict, + types.LiveClientRealtimeInput, + types.LiveClientRealtimeInputOrDict, + types.LiveClientToolResponseOrDict, + types.FunctionResponseOrDict, + Sequence[types.FunctionResponseOrDict], +] + + +JSON_SCHEMA_TYPE_MAP = { + str: "string", + int: "integer", + float: "number", + bool: "boolean", + dict: "object", + list: "array", +} + + +def _build_parameters(arguments: Dict[str, Any]) -> types.SchemaDict: + properties: Dict[str, types.SchemaDict] = {} + required: List[str] = [] + + for arg_name, arg_info in arguments.items(): + py_type = arg_info.type + if py_type not in JSON_SCHEMA_TYPE_MAP: + raise ValueError(f"Unsupported type: {py_type}") + + prop: types.SchemaDict = { + "type": JSON_SCHEMA_TYPE_MAP[py_type], + "description": arg_info.description, + } + + if arg_info.choices: + prop["enum"] = arg_info.choices + + properties[arg_name] = prop + + if arg_info.default is inspect.Parameter.empty: + required.append(arg_name) + + parameters: types.SchemaDict = {"type": "object", "properties": properties} + + if required: + parameters["required"] = required + + return parameters + + +def _build_tools(fnc_ctx: Any) -> List[types.FunctionDeclarationDict]: + function_declarations: List[types.FunctionDeclarationDict] = [] + for fnc_info in fnc_ctx.ai_functions.values(): + parameters = _build_parameters(fnc_info.arguments) + + func_decl: types.FunctionDeclarationDict = { + "name": fnc_info.name, + "description": fnc_info.description, + "parameters": parameters, + } + + function_declarations.append(func_decl) + + return function_declarations diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py new file mode 100644 index 000000000..40bb0d7a1 --- /dev/null +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py @@ -0,0 +1,424 @@ +from __future__ import annotations + +import asyncio +import base64 +import json +import os +from dataclasses import dataclass +from typing import AsyncIterable, Literal + 
+from livekit import rtc +from livekit.agents import llm, utils +from livekit.agents.llm.function_context import _create_ai_function_info + +from google import genai # type: ignore +from google.genai.types import ( # type: ignore + FunctionResponse, + GenerationConfigDict, + LiveClientToolResponse, + LiveConnectConfigDict, + PrebuiltVoiceConfig, + SpeechConfig, + VoiceConfig, ) + +from ...log import logger +from .api_proto import ( + ClientEvents, + LiveAPIModels, + ResponseModality, + Voice, + _build_tools, ) + +EventTypes = Literal[ + "start_session", + "input_speech_started", + "response_content_added", + "response_content_done", + "function_calls_collected", + "function_calls_finished", + "function_calls_cancelled", ] + + +@dataclass +class GeminiContent: + response_id: str + item_id: str + output_index: int + content_index: int + text: str + audio: list[rtc.AudioFrame] + text_stream: AsyncIterable[str] + audio_stream: AsyncIterable[rtc.AudioFrame] + content_type: Literal["text", "audio"] + + +@dataclass +class Capabilities: + supports_truncate: bool + + +@dataclass +class ModelOptions: + model: LiveAPIModels | str + api_key: str | None + voice: Voice | str + response_modalities: ResponseModality + vertexai: bool + project: str | None + location: str | None + candidate_count: int + temperature: float | None + max_output_tokens: int | None + top_p: float | None + top_k: int | None + presence_penalty: float | None + frequency_penalty: float | None + instructions: str + + +class RealtimeModel: + def __init__( + self, + *, + instructions: str = "", + model: LiveAPIModels | str = "gemini-2.0-flash-exp", + api_key: str | None = None, + voice: Voice | str = "Puck", + modalities: ResponseModality = "AUDIO", + vertexai: bool = False, + project: str | None = None, + location: str | None = None, + candidate_count: int = 1, + temperature: float | None = None, + max_output_tokens: int | None = None, + top_p: float | None = None, + top_k: int | None = None, + presence_penalty: float | None = None, + frequency_penalty: float | None = None, + loop: asyncio.AbstractEventLoop | None = None, + ): + """ + Initializes a RealtimeModel instance for interacting with Google's Realtime API. + + Args: + instructions (str, optional): Initial system instructions for the model. Defaults to "". + api_key (str or None, optional): Google API key. If None, will attempt to read from the environment variable GOOGLE_API_KEY + modalities (ResponseModality, optional): Response modality to use, either "AUDIO" or "TEXT". Defaults to "AUDIO". + model (str or None, optional): The name of the model to use. Defaults to "gemini-2.0-flash-exp". + voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "Puck". + temperature (float, optional): Sampling temperature for response generation. Defaults to None. + vertexai (bool, optional): Whether to use VertexAI for the API. Defaults to False. + project (str or None, optional): The project to use for the API. Defaults to None. (for vertexai) + location (str or None, optional): The location to use for the API. Defaults to None. (for vertexai) + candidate_count (int, optional): The number of candidate responses to generate. Defaults to 1. 
+ top_p (float, optional): The top-p value for response generation + top_k (int, optional): The top-k value for response generation + presence_penalty (float, optional): The presence penalty for response generation + frequency_penalty (float, optional): The frequency penalty for response generation + loop (asyncio.AbstractEventLoop or None, optional): Event loop to use for async operations. If None, the current event loop is used. + + Raises: + ValueError: If the API key is not provided and cannot be found in environment variables. + """ + super().__init__() + self._capabilities = Capabilities( + supports_truncate=False, + ) + self._model = model + self._loop = loop or asyncio.get_event_loop() + self._api_key = api_key or os.environ.get("GOOGLE_API_KEY") + self._vertexai = vertexai + self._project_id = project or os.environ.get("GOOGLE_PROJECT") + self._location = location or os.environ.get("GOOGLE_LOCATION") + if self._api_key is None and not self._vertexai: + raise ValueError("GOOGLE_API_KEY is not set") + + self._rt_sessions: list[GeminiRealtimeSession] = [] + self._opts = ModelOptions( + model=model, + api_key=api_key, + voice=voice, + response_modalities=modalities, + vertexai=vertexai, + project=project, + location=location, + candidate_count=candidate_count, + temperature=temperature, + max_output_tokens=max_output_tokens, + top_p=top_p, + top_k=top_k, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + instructions=instructions, + ) + + @property + def sessions(self) -> list[GeminiRealtimeSession]: + return self._rt_sessions + + @property + def capabilities(self) -> Capabilities: + return self._capabilities + + def session( + self, + *, + chat_ctx: llm.ChatContext | None = None, + fnc_ctx: llm.FunctionContext | None = None, + ) -> GeminiRealtimeSession: + session = GeminiRealtimeSession( + opts=self._opts, + chat_ctx=chat_ctx or llm.ChatContext(), + fnc_ctx=fnc_ctx, + loop=self._loop, + ) + self._rt_sessions.append(session) + + return session + + async def aclose(self) -> None: + for session in self._rt_sessions: + await session.aclose() + + +class GeminiRealtimeSession(utils.EventEmitter[EventTypes]): + def __init__( + self, + *, + opts: ModelOptions, + chat_ctx: llm.ChatContext, + fnc_ctx: llm.FunctionContext | None, + loop: asyncio.AbstractEventLoop, + ): + """ + Initializes a GeminiRealtimeSession instance for interacting with Google's Realtime API. + + Args: + opts (ModelOptions): The model options for the session. + chat_ctx (llm.ChatContext): The chat context for the session. + fnc_ctx (llm.FunctionContext or None): The function context for the session. + loop (asyncio.AbstractEventLoop): The event loop for the session. 
+ """ + super().__init__() + self._loop = loop + self._opts = opts + self._chat_ctx = chat_ctx + self._fnc_ctx = fnc_ctx + self._fnc_tasks = utils.aio.TaskSet() + + tools = [] + if self._fnc_ctx is not None: + functions = _build_tools(self._fnc_ctx) + tools.append({"function_declarations": functions}) + + self._config = LiveConnectConfigDict( + model=self._opts.model, + response_modalities=self._opts.response_modalities, + generation_config=GenerationConfigDict( + candidate_count=self._opts.candidate_count, + temperature=self._opts.temperature, + max_output_tokens=self._opts.max_output_tokens, + top_p=self._opts.top_p, + top_k=self._opts.top_k, + presence_penalty=self._opts.presence_penalty, + frequency_penalty=self._opts.frequency_penalty, + ), + system_instruction=self._opts.instructions, + speech_config=SpeechConfig( + voice_config=VoiceConfig( + prebuilt_voice_config=PrebuiltVoiceConfig( + voice_name=self._opts.voice + ) + ) + ), + tools=tools, + ) + self._client = genai.Client( + http_options={"api_version": "v1alpha"}, + api_key=self._opts.api_key, + vertexai=self._opts.vertexai, + project=self._opts.project, + location=self._opts.location, + ) + self._main_atask = asyncio.create_task( + self._main_task(), name="gemini-realtime-session" + ) + # dummy task to wait for the session to be initialized # TODO: sync chat ctx + self._init_sync_task = asyncio.create_task( + asyncio.sleep(0), name="gemini-realtime-session-init" + ) + self._send_ch = utils.aio.Chan[ClientEvents]() + self._active_response_id = None + + async def aclose(self) -> None: + if self._send_ch.closed: + return + + self._send_ch.close() + await self._main_atask + + @property + def fnc_ctx(self) -> llm.FunctionContext | None: + return self._fnc_ctx + + @fnc_ctx.setter + def fnc_ctx(self, value: llm.FunctionContext | None) -> None: + self._fnc_ctx = value + + def _push_audio(self, frame: rtc.AudioFrame) -> None: + data = base64.b64encode(frame.data).decode("utf-8") + self._queue_msg({"mime_type": "audio/pcm", "data": data}) + + def _queue_msg(self, msg: dict) -> None: + self._send_ch.send_nowait(msg) + + def chat_ctx_copy(self) -> llm.ChatContext: + return self._chat_ctx.copy() + + async def set_chat_ctx(self, ctx: llm.ChatContext) -> None: + self._chat_ctx = ctx.copy() + + @utils.log_exceptions(logger=logger) + async def _main_task(self): + @utils.log_exceptions(logger=logger) + async def _send_task(): + async for msg in self._send_ch: + await self._session.send(msg) + + await self._session.send(".", end_of_turn=True) + + @utils.log_exceptions(logger=logger) + async def _recv_task(): + while True: + async for response in self._session.receive(): + if self._active_response_id is None: + self._active_response_id = utils.shortuuid() + text_stream = utils.aio.Chan[str]() + audio_stream = utils.aio.Chan[rtc.AudioFrame]() + content = GeminiContent( + response_id=self._active_response_id, + item_id=self._active_response_id, + output_index=0, + content_index=0, + text="", + audio=[], + text_stream=text_stream, + audio_stream=audio_stream, + content_type=self._opts.response_modalities, + ) + self.emit("response_content_added", content) + + server_content = response.server_content + if server_content: + model_turn = server_content.model_turn + if model_turn: + for part in model_turn.parts: + if part.text: + content.text_stream.send_nowait(part.text) + if part.inline_data: + frame = rtc.AudioFrame( + data=part.inline_data.data, + sample_rate=24000, + num_channels=1, + samples_per_channel=len(part.inline_data.data) + // 2, + ) + 
content.audio_stream.send_nowait(frame) + + if server_content.interrupted or server_content.turn_complete: + for stream in (content.text_stream, content.audio_stream): + if isinstance(stream, utils.aio.Chan): + stream.close() + + if server_content.interrupted: + self.emit("input_speech_started") + elif server_content.turn_complete: + self.emit("response_content_done", content) + + self._active_response_id = None + + if response.tool_call: + if self._fnc_ctx is None: + raise ValueError("Function context is not set") + fnc_calls = [] + for fnc_call in response.tool_call.function_calls: + fnc_call_info = _create_ai_function_info( + self._fnc_ctx, + fnc_call.id, + fnc_call.name, + json.dumps(fnc_call.args), + ) + fnc_calls.append(fnc_call_info) + + self.emit("function_calls_collected", fnc_calls) + + for fnc_call_info in fnc_calls: + self._fnc_tasks.create_task( + self._run_fnc_task(fnc_call_info, content.item_id) + ) + + # Handle function call cancellations + if response.tool_call_cancellation: + logger.warning( + "function call cancelled", + extra={ + "function_call_ids": response.tool_call_cancellation.function_call_ids, + }, + ) + self.emit( + "function_calls_cancelled", + response.tool_call_cancellation.function_call_ids, + ) + + async with self._client.aio.live.connect( + model=self._opts.model, config=self._config + ) as session: + self._session = session + tasks = [ + asyncio.create_task(_send_task(), name="gemini-realtime-send"), + asyncio.create_task(_recv_task(), name="gemini-realtime-recv"), + ] + + try: + await asyncio.gather(*tasks) + finally: + await utils.aio.gracefully_cancel(*tasks) + await self._session.close() + + @utils.log_exceptions(logger=logger) + async def _run_fnc_task(self, fnc_call_info: llm.FunctionCallInfo, item_id: str): + logger.debug( + "executing ai function", + extra={ + "function": fnc_call_info.function_info.name, + }, + ) + + called_fnc = fnc_call_info.execute() + try: + await called_fnc.task + except Exception as e: + logger.exception( + "error executing ai function", + extra={ + "function": fnc_call_info.function_info.name, + }, + exc_info=e, + ) + tool_call = llm.ChatMessage.create_tool_from_called_function(called_fnc) + if tool_call.content is not None: + tool_response = LiveClientToolResponse( + function_responses=[ + FunctionResponse( + name=tool_call.name, + id=tool_call.tool_call_id, + response={"result": tool_call.content}, + ) + ] + ) + await self._session.send(tool_response) + + self.emit("function_calls_finished", [called_fnc]) diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py index 3272fcfcc..7fe2a527d 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py @@ -16,12 +16,15 @@ import asyncio import dataclasses +import weakref from dataclasses import dataclass -from typing import AsyncIterable, List, Union +from typing import List, Union -from livekit import agents, rtc +from livekit import rtc from livekit.agents import ( + DEFAULT_API_CONNECT_OPTIONS, APIConnectionError, + APIConnectOptions, APIStatusError, APITimeoutError, stt, @@ -29,7 +32,7 @@ ) from google.api_core.client_options import ClientOptions -from google.api_core.exceptions import Aborted, DeadlineExceeded, GoogleAPICallError +from google.api_core.exceptions import DeadlineExceeded, GoogleAPICallError from google.auth import default as gauth_default from google.auth.exceptions 
import DefaultCredentialsError from google.cloud.speech_v2 import SpeechAsyncClient @@ -51,6 +54,7 @@ class STTOptions: punctuate: bool spoken_punctuation: bool model: SpeechModels + sample_rate: int keywords: List[tuple[str, float]] | None def build_adaptation(self) -> cloud_speech.SpeechAdaptation | None: @@ -83,6 +87,7 @@ def __init__( spoken_punctuation: bool = True, model: SpeechModels = "long", location: str = "global", + sample_rate: int = 16000, credentials_info: dict | None = None, credentials_file: str | None = None, keywords: List[tuple[str, float]] | None = None, @@ -123,8 +128,10 @@ def __init__( punctuate=punctuate, spoken_punctuation=spoken_punctuation, model=model, + sample_rate=sample_rate, keywords=keywords, ) + self._streams = weakref.WeakSet[SpeechStream]() def _ensure_client(self) -> SpeechAsyncClient: if self._credentials_info: @@ -183,10 +190,11 @@ async def _recognize_impl( self, buffer: utils.AudioBuffer, *, - language: SpeechLanguages | str | None = None, + language: SpeechLanguages | str | None, + conn_options: APIConnectOptions, ) -> stt.SpeechEvent: config = self._sanitize_options(language=language) - frame = agents.utils.merge_frames(buffer) + frame = rtc.combine_audio_frames(buffer) config = cloud_speech.RecognitionConfig( explicit_decoding_config=cloud_speech.ExplicitDecodingConfig( @@ -210,7 +218,8 @@ async def _recognize_impl( recognizer=self._recognizer, config=config, content=frame.data.tobytes(), - ) + ), + timeout=conn_options.timeout, ) return _recognize_response_to_speech_event(raw) @@ -227,149 +236,223 @@ async def _recognize_impl( raise APIConnectionError() from e def stream( - self, *, language: SpeechLanguages | str | None = None + self, + *, + language: SpeechLanguages | str | None = None, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, ) -> "SpeechStream": config = self._sanitize_options(language=language) - return SpeechStream(self, self._ensure_client(), self._recognizer, config) + stream = SpeechStream( + stt=self, + client=self._ensure_client(), + recognizer=self._recognizer, + config=config, + conn_options=conn_options, + ) + self._streams.add(stream) + return stream + + def update_options( + self, + *, + languages: LanguageCode | None = None, + detect_language: bool | None = None, + interim_results: bool | None = None, + punctuate: bool | None = None, + spoken_punctuation: bool | None = None, + model: SpeechModels | None = None, + location: str | None = None, + keywords: List[tuple[str, float]] | None = None, + ): + if languages is not None: + if isinstance(languages, str): + languages = [languages] + self._config.languages = languages + if detect_language is not None: + self._config.detect_language = detect_language + if interim_results is not None: + self._config.interim_results = interim_results + if punctuate is not None: + self._config.punctuate = punctuate + if spoken_punctuation is not None: + self._config.spoken_punctuation = spoken_punctuation + if model is not None: + self._config.model = model + if keywords is not None: + self._config.keywords = keywords + + for stream in self._streams: + stream.update_options( + languages=languages, + detect_language=detect_language, + interim_results=interim_results, + punctuate=punctuate, + spoken_punctuation=spoken_punctuation, + model=model, + location=location, + keywords=keywords, + ) class SpeechStream(stt.SpeechStream): def __init__( self, + *, stt: STT, + conn_options: APIConnectOptions, client: SpeechAsyncClient, recognizer: str, config: STTOptions, - sample_rate: 
int = 16000, - num_channels: int = 1, - max_retry: int = 32, ) -> None: - super().__init__(stt, sample_rate=sample_rate) + super().__init__( + stt=stt, conn_options=conn_options, sample_rate=config.sample_rate + ) self._client = client self._recognizer = recognizer self._config = config - self._sample_rate = sample_rate - self._num_channels = num_channels - self._max_retry = max_retry - - self._streaming_config = cloud_speech.StreamingRecognitionConfig( - config=cloud_speech.RecognitionConfig( - explicit_decoding_config=cloud_speech.ExplicitDecodingConfig( - encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16, - sample_rate_hertz=self._sample_rate, - audio_channel_count=self._num_channels, - ), - adaptation=config.build_adaptation(), - language_codes=self._config.languages, - model=self._config.model, - features=cloud_speech.RecognitionFeatures( - enable_automatic_punctuation=self._config.punctuate, - enable_word_time_offsets=True, - ), - ), - streaming_features=cloud_speech.StreamingRecognitionFeatures( - enable_voice_activity_events=True, - interim_results=self._config.interim_results, - ), - ) - - @utils.log_exceptions(logger=logger) - async def _main_task(self) -> None: - await self._run(self._max_retry) + self._reconnect_event = asyncio.Event() - async def _run(self, max_retry: int) -> None: - retry_count = 0 - while self._input_ch.qsize() or not self._input_ch.closed: + def update_options( + self, + *, + languages: LanguageCode | None = None, + detect_language: bool | None = None, + interim_results: bool | None = None, + punctuate: bool | None = None, + spoken_punctuation: bool | None = None, + model: SpeechModels | None = None, + location: str | None = None, + keywords: List[tuple[str, float]] | None = None, + ): + if languages is not None: + if isinstance(languages, str): + languages = [languages] + self._config.languages = languages + if detect_language is not None: + self._config.detect_language = detect_language + if interim_results is not None: + self._config.interim_results = interim_results + if punctuate is not None: + self._config.punctuate = punctuate + if spoken_punctuation is not None: + self._config.spoken_punctuation = spoken_punctuation + if model is not None: + self._config.model = model + if keywords is not None: + self._config.keywords = keywords + + self._reconnect_event.set() + + async def _run(self) -> None: + # google requires a async generator when calling streaming_recognize + # this function basically convert the queue into a async generator + async def input_generator(): try: - # google requires a async generator when calling streaming_recognize - # this function basically convert the queue into a async generator - async def input_generator(): - try: - # first request should contain the config + # first request should contain the config + yield cloud_speech.StreamingRecognizeRequest( + recognizer=self._recognizer, + streaming_config=self._streaming_config, + ) + + async for frame in self._input_ch: + if isinstance(frame, rtc.AudioFrame): yield cloud_speech.StreamingRecognizeRequest( - recognizer=self._recognizer, - streaming_config=self._streaming_config, + audio=frame.data.tobytes() ) - async for frame in self._input_ch: - if isinstance(frame, rtc.AudioFrame): - yield cloud_speech.StreamingRecognizeRequest( - audio=frame.data.tobytes() - ) + except Exception: + logger.exception( + "an error occurred while streaming input to google STT" + ) - except Exception: - logger.exception( - "an error occurred while streaming input to google STT" + 
async def process_stream(stream): + async for resp in stream: + if ( + resp.speech_event_type + == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN + ): + self._event_ch.send_nowait( + stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH) + ) + + if ( + resp.speech_event_type + == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_EVENT_TYPE_UNSPECIFIED + ): + result = resp.results[0] + speech_data = _streaming_recognize_response_to_speech_data(resp) + if speech_data is None: + continue + + if not result.is_final: + self._event_ch.send_nowait( + stt.SpeechEvent( + type=stt.SpeechEventType.INTERIM_TRANSCRIPT, + alternatives=[speech_data], + ) + ) + else: + self._event_ch.send_nowait( + stt.SpeechEvent( + type=stt.SpeechEventType.FINAL_TRANSCRIPT, + alternatives=[speech_data], + ) ) - # try to connect - stream = await self._client.streaming_recognize( - requests=input_generator() - ) - retry_count = 0 # connection successful, reset retry count - - await self._run_stream(stream) - except Exception as e: - error_type = "Aborted" if isinstance(e, Aborted) else "Error" - if retry_count >= max_retry: - logger.error( - f"failed to connect to google stt after {max_retry} tries due to {error_type}", - exc_info=e, + if ( + resp.speech_event_type + == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_END + ): + self._event_ch.send_nowait( + stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH) ) - break - retry_delay = min(retry_count * 2, 5) # max 5s - retry_count += 1 - logger.warning( - f"google stt connection {error_type.lower()}, retrying in {retry_delay}s", - exc_info=e, + while True: + try: + self._streaming_config = cloud_speech.StreamingRecognitionConfig( + config=cloud_speech.RecognitionConfig( + explicit_decoding_config=cloud_speech.ExplicitDecodingConfig( + encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16, + sample_rate_hertz=self._config.sample_rate, + audio_channel_count=1, + ), + adaptation=self._config.build_adaptation(), + language_codes=self._config.languages, + model=self._config.model, + features=cloud_speech.RecognitionFeatures( + enable_automatic_punctuation=self._config.punctuate, + enable_word_time_offsets=True, + ), + ), + streaming_features=cloud_speech.StreamingRecognitionFeatures( + enable_voice_activity_events=True, + interim_results=self._config.interim_results, + ), ) - await asyncio.sleep(retry_delay) - async def _run_stream( - self, stream: AsyncIterable[cloud_speech.StreamingRecognizeResponse] - ): - async for resp in stream: - if ( - resp.speech_event_type - == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN - ): - self._event_ch.send_nowait( - stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH) + stream = await self._client.streaming_recognize( + requests=input_generator(), ) - if ( - resp.speech_event_type - == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_EVENT_TYPE_UNSPECIFIED - ): - result = resp.results[0] - speech_data = _streaming_recognize_response_to_speech_data(resp) - if speech_data is None: - continue - - if not result.is_final: - self._event_ch.send_nowait( - stt.SpeechEvent( - type=stt.SpeechEventType.INTERIM_TRANSCRIPT, - alternatives=[speech_data], - ) + process_stream_task = asyncio.create_task(process_stream(stream)) + wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait()) + try: + done, _ = await asyncio.wait( + [process_stream_task, wait_reconnect_task], + 
return_when=asyncio.FIRST_COMPLETED, ) - else: - self._event_ch.send_nowait( - stt.SpeechEvent( - type=stt.SpeechEventType.FINAL_TRANSCRIPT, - alternatives=[speech_data], - ) + for task in done: + if task != wait_reconnect_task: + task.result() + finally: + await utils.aio.gracefully_cancel( + process_stream_task, wait_reconnect_task ) - - if ( - resp.speech_event_type - == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_END - ): - self._event_ch.send_nowait( - stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH) - ) + finally: + if not self._reconnect_event.is_set(): + break + self._reconnect_event.clear() def _recognize_response_to_speech_event( diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/version.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/version.py index 20d8a2226..654ad56ec 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/version.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.7.3" +__version__ = "0.9.0" diff --git a/livekit-plugins/livekit-plugins-google/package.json b/livekit-plugins/livekit-plugins-google/package.json index 38ac1d046..17bc59ac6 100644 --- a/livekit-plugins/livekit-plugins-google/package.json +++ b/livekit-plugins/livekit-plugins-google/package.json @@ -1,5 +1,5 @@ { "name": "livekit-plugins-google", "private": true, - "version": "0.7.3" + "version": "0.9.0" } diff --git a/livekit-plugins/livekit-plugins-google/setup.py b/livekit-plugins/livekit-plugins-google/setup.py index b6e72949b..0db8addce 100644 --- a/livekit-plugins/livekit-plugins-google/setup.py +++ b/livekit-plugins/livekit-plugins-google/setup.py @@ -51,7 +51,8 @@ "google-auth >= 2, < 3", "google-cloud-speech >= 2, < 3", "google-cloud-texttospeech >= 2, < 3", - "livekit-agents>=0.11", + "google-genai >= 0.3.0", + "livekit-agents>=0.12.3", ], package_data={"livekit.plugins.google": ["py.typed"]}, project_urls={ diff --git a/livekit-plugins/livekit-plugins-llama-index/CHANGELOG.md b/livekit-plugins/livekit-plugins-llama-index/CHANGELOG.md index 952e0a4ff..6600ee049 100644 --- a/livekit-plugins/livekit-plugins-llama-index/CHANGELOG.md +++ b/livekit-plugins/livekit-plugins-llama-index/CHANGELOG.md @@ -1,5 +1,19 @@ # livekit-plugins-llama-index +## 0.2.2 + +### Patch Changes + +- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom)) + +## 0.2.1 + +### Patch Changes + +- support for custom tool use in LLMs - [#1102](https://github.com/livekit/agents/pull/1102) ([@jayeshp19](https://github.com/jayeshp19)) + +- feat: llm retry & llm.FallbackAdapter - [#1132](https://github.com/livekit/agents/pull/1132) ([@theomonnom](https://github.com/theomonnom)) + ## 0.2.0 ### Minor Changes diff --git a/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/llm.py b/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/llm.py index f07c989df..9f674717d 100644 --- a/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/llm.py +++ b/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/llm.py @@ -1,9 +1,13 @@ from __future__ import annotations +from typing import Literal, Union + from livekit.agents import ( APIConnectionError, llm, ) +from livekit.agents.llm import ToolChoice +from livekit.agents.types import 
DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions from llama_index.core.chat_engine.types import ( BaseChatEngine, @@ -27,10 +31,13 @@ def chat( self, *, chat_ctx: llm.ChatContext, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, fnc_ctx: llm.FunctionContext | None = None, temperature: float | None = None, n: int | None = 1, parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] + | None = None, ) -> "LLMStream": if fnc_ctx is not None: logger.warning("fnc_ctx is currently not supported with llama_index.LLM") @@ -39,6 +46,7 @@ def chat( self, chat_engine=self._chat_engine, chat_ctx=chat_ctx, + conn_options=conn_options, ) @@ -49,12 +57,15 @@ def __init__( *, chat_engine: BaseChatEngine, chat_ctx: llm.ChatContext, + conn_options: APIConnectOptions, ) -> None: - super().__init__(llm, chat_ctx=chat_ctx, fnc_ctx=None) + super().__init__( + llm, chat_ctx=chat_ctx, fnc_ctx=None, conn_options=conn_options + ) self._chat_engine = chat_engine self._stream: StreamingAgentChatResponse | None = None - async def _main_task(self) -> None: + async def _run(self) -> None: chat_ctx = self._chat_ctx.copy() user_msg = chat_ctx.messages.pop() diff --git a/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/version.py b/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/version.py index eaa4231b0..1b935518b 100644 --- a/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/version.py +++ b/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.2.0" +__version__ = "0.2.2" diff --git a/livekit-plugins/livekit-plugins-llama-index/package.json b/livekit-plugins/livekit-plugins-llama-index/package.json index ca7c6577c..67848bd35 100644 --- a/livekit-plugins/livekit-plugins-llama-index/package.json +++ b/livekit-plugins/livekit-plugins-llama-index/package.json @@ -1,5 +1,5 @@ { "name": "livekit-plugins-llama-index", "private": true, - "version": "0.2.0" + "version": "0.2.2" } diff --git a/livekit-plugins/livekit-plugins-llama-index/setup.py b/livekit-plugins/livekit-plugins-llama-index/setup.py index 98b0babab..acc39333d 100644 --- a/livekit-plugins/livekit-plugins-llama-index/setup.py +++ b/livekit-plugins/livekit-plugins-llama-index/setup.py @@ -49,7 +49,7 @@ license="Apache-2.0", packages=setuptools.find_namespace_packages(include=["livekit.*"]), python_requires=">=3.9.0", - install_requires=["livekit-agents>=0.11"], + install_requires=["livekit-agents>=0.12.3"], package_data={"livekit.plugins.llama_index": ["py.typed"]}, project_urls={ "Documentation": "https://docs.livekit.io", diff --git a/livekit-plugins/livekit-plugins-minimal/CHANGELOG.md b/livekit-plugins/livekit-plugins-minimal/CHANGELOG.md index 535fb2bec..88fdfa9e0 100644 --- a/livekit-plugins/livekit-plugins-minimal/CHANGELOG.md +++ b/livekit-plugins/livekit-plugins-minimal/CHANGELOG.md @@ -1,5 +1,11 @@ # livekit-plugins-minimal +## 0.2.1 + +### Patch Changes + +- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom)) + ## 0.2.0 ### Minor Changes diff --git a/livekit-plugins/livekit-plugins-minimal/livekit/plugins/minimal/version.py b/livekit-plugins/livekit-plugins-minimal/livekit/plugins/minimal/version.py index eaa4231b0..ae5785b8d 100644 --- 
a/livekit-plugins/livekit-plugins-minimal/livekit/plugins/minimal/version.py +++ b/livekit-plugins/livekit-plugins-minimal/livekit/plugins/minimal/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.2.0" +__version__ = "0.2.1" diff --git a/livekit-plugins/livekit-plugins-minimal/package.json b/livekit-plugins/livekit-plugins-minimal/package.json index f48849e1a..48cfb1da2 100644 --- a/livekit-plugins/livekit-plugins-minimal/package.json +++ b/livekit-plugins/livekit-plugins-minimal/package.json @@ -1,5 +1,5 @@ { "name": "livekit-plugins-minimal", "private": true, - "version": "0.2.0" + "version": "0.2.1" } diff --git a/livekit-plugins/livekit-plugins-nltk/CHANGELOG.md b/livekit-plugins/livekit-plugins-nltk/CHANGELOG.md index 6ee2124fe..9d8b746d6 100644 --- a/livekit-plugins/livekit-plugins-nltk/CHANGELOG.md +++ b/livekit-plugins/livekit-plugins-nltk/CHANGELOG.md @@ -1,5 +1,11 @@ # livekit-plugins-nltk +## 0.7.3 + +### Patch Changes + +- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom)) + ## 0.7.2 ### Patch Changes diff --git a/livekit-plugins/livekit-plugins-nltk/livekit/plugins/nltk/version.py b/livekit-plugins/livekit-plugins-nltk/livekit/plugins/nltk/version.py index d40c15247..20d8a2226 100644 --- a/livekit-plugins/livekit-plugins-nltk/livekit/plugins/nltk/version.py +++ b/livekit-plugins/livekit-plugins-nltk/livekit/plugins/nltk/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.7.2" +__version__ = "0.7.3" diff --git a/livekit-plugins/livekit-plugins-nltk/package.json b/livekit-plugins/livekit-plugins-nltk/package.json index 66a8eb3fa..d0a24735b 100644 --- a/livekit-plugins/livekit-plugins-nltk/package.json +++ b/livekit-plugins/livekit-plugins-nltk/package.json @@ -1,5 +1,5 @@ { "name": "livekit-plugins-nltk", "private": true, - "version": "0.7.2" + "version": "0.7.3" } diff --git a/livekit-plugins/livekit-plugins-openai/CHANGELOG.md b/livekit-plugins/livekit-plugins-openai/CHANGELOG.md index 82f96710b..1e363b412 100644 --- a/livekit-plugins/livekit-plugins-openai/CHANGELOG.md +++ b/livekit-plugins/livekit-plugins-openai/CHANGELOG.md @@ -1,5 +1,95 @@ # livekit-plugins-openai +## 0.10.13 + +### Patch Changes + +- improved handling of LLM errors, do not retry if already began - [#1298](https://github.com/livekit/agents/pull/1298) ([@davidzhao](https://github.com/davidzhao)) + +- make multimodal class generic and support gemini live api - [#1240](https://github.com/livekit/agents/pull/1240) ([@jayeshp19](https://github.com/jayeshp19)) + +## 0.10.12 + +### Patch Changes + +- fix unknown `metadata` & `store` fields on OpenAI-like API - [#1276](https://github.com/livekit/agents/pull/1276) ([@theomonnom](https://github.com/theomonnom)) + +## 0.10.11 + +### Patch Changes + +- Moved create_ai_function_info to function_context.py for better reusability and reduce repetation - [#1260](https://github.com/livekit/agents/pull/1260) ([@jayeshp19](https://github.com/jayeshp19)) + +- add on_duplicate option for multimodal agent response create - [#1204](https://github.com/livekit/agents/pull/1204) ([@longcw](https://github.com/longcw)) + +- Add support for OpenAI's "detail" parameter to ChatImage - [#1213](https://github.com/livekit/agents/pull/1213) ([@bcherry](https://github.com/bcherry)) + + Add support for data URLs on ChatImage in the 
Anthropic plugin. + +- filter out empty message for set chat ctx in realtime model - [#1245](https://github.com/livekit/agents/pull/1245) ([@longcw](https://github.com/longcw)) + +- fix: correctly parse function argument types - [#1221](https://github.com/livekit/agents/pull/1221) ([@jayeshp19](https://github.com/jayeshp19)) + +- add session_updated event for RealtimeSession - [#1253](https://github.com/livekit/agents/pull/1253) ([@longcw](https://github.com/longcw)) + +- added llama 3.3 70b to model definitions - [#1233](https://github.com/livekit/agents/pull/1233) ([@davidzhao](https://github.com/davidzhao)) + +- update default realtime model to gpt-4o-realtime-preview-2024-12-17 - [#1250](https://github.com/livekit/agents/pull/1250) ([@davidzhao](https://github.com/davidzhao)) + +- Fix center_aspect_fit bug, add scale_aspect_fit and scale_aspect_fill resizing options. - [#1222](https://github.com/livekit/agents/pull/1222) ([@bcherry](https://github.com/bcherry)) + + Make scale_aspect_fit the new default resizing option for video frames. + +## 0.10.10 + +### Patch Changes + +- add `google/gemini-2.0-flash-exp` as default model for vertex - [#1214](https://github.com/livekit/agents/pull/1214) ([@jayeshp19](https://github.com/jayeshp19)) + +- emit error event for realtime model - [#1200](https://github.com/livekit/agents/pull/1200) ([@longcw](https://github.com/longcw)) + +- fix: return structured output from func calls - [#1187](https://github.com/livekit/agents/pull/1187) ([@jayeshp19](https://github.com/jayeshp19)) + +- Handle optional func args in tool calls when set to `None` - [#1211](https://github.com/livekit/agents/pull/1211) ([@jayeshp19](https://github.com/jayeshp19)) + +- fix: openai llm retries - [#1196](https://github.com/livekit/agents/pull/1196) ([@theomonnom](https://github.com/theomonnom)) + +- Improvements to end of turn plugin, ensure STT language settings. 
- [#1195](https://github.com/livekit/agents/pull/1195) ([@davidzhao](https://github.com/davidzhao)) + +- fix: Handle optional func args in tool calls when set to `None` - [#1211](https://github.com/livekit/agents/pull/1211) ([@jayeshp19](https://github.com/jayeshp19)) + +## 0.10.9 + +### Patch Changes + +- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom)) + +## 0.10.8 + +### Patch Changes + +- fix uncatched OAI errors - [#1158](https://github.com/livekit/agents/pull/1158) ([@theomonnom](https://github.com/theomonnom)) + +- feat: stt retry & stt.FallbackAdapter - [#1114](https://github.com/livekit/agents/pull/1114) ([@theomonnom](https://github.com/theomonnom)) + +- project id fix for google - [#1115](https://github.com/livekit/agents/pull/1115) ([@jayeshp19](https://github.com/jayeshp19)) + +- Add retries to recover from text mode to audio model for realtime API - [#1121](https://github.com/livekit/agents/pull/1121) ([@longcw](https://github.com/longcw)) + +- Support for Python 3.13, relaxed Pillow version requirement for 10.x - [#1127](https://github.com/livekit/agents/pull/1127) ([@davidzhao](https://github.com/davidzhao)) + +- support for custom tool use in LLMs - [#1102](https://github.com/livekit/agents/pull/1102) ([@jayeshp19](https://github.com/jayeshp19)) + +- feat: tts retry & tts.FallbackAdapter - [#1074](https://github.com/livekit/agents/pull/1074) ([@theomonnom](https://github.com/theomonnom)) + +- Add new OpenAI realtime voices - [#1116](https://github.com/livekit/agents/pull/1116) ([@bcherry](https://github.com/bcherry)) + +- Expose multimodal agent metrics - [#1080](https://github.com/livekit/agents/pull/1080) ([@longcw](https://github.com/longcw)) + +- feat: llm retry & llm.FallbackAdapter - [#1132](https://github.com/livekit/agents/pull/1132) ([@theomonnom](https://github.com/theomonnom)) + +- vertex ai support with openai library - [#1084](https://github.com/livekit/agents/pull/1084) ([@jayeshp19](https://github.com/jayeshp19)) + ## 0.10.7 ### Patch Changes diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/_oai_api.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/_oai_api.py index a19829685..8dbc3a33e 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/_oai_api.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/_oai_api.py @@ -15,72 +15,13 @@ from __future__ import annotations import inspect -import json import typing from typing import Any from livekit.agents.llm import function_context, llm +from livekit.agents.llm.function_context import _is_optional_type -__all__ = ["build_oai_function_description", "create_ai_function_info"] - - -def create_ai_function_info( - fnc_ctx: function_context.FunctionContext, - tool_call_id: str, - fnc_name: str, - raw_arguments: str, # JSON string -) -> function_context.FunctionCallInfo: - if fnc_name not in fnc_ctx.ai_functions: - raise ValueError(f"AI function {fnc_name} not found") - - parsed_arguments: dict[str, Any] = {} - try: - if raw_arguments: # ignore empty string - parsed_arguments = json.loads(raw_arguments) - except json.JSONDecodeError: - raise ValueError( - f"AI function {fnc_name} received invalid JSON arguments - {raw_arguments}" - ) - - fnc_info = fnc_ctx.ai_functions[fnc_name] - - # Ensure all necessary arguments are present and of the correct type. 
- sanitized_arguments: dict[str, Any] = {} - for arg_info in fnc_info.arguments.values(): - if arg_info.name not in parsed_arguments: - if arg_info.default is inspect.Parameter.empty: - raise ValueError( - f"AI function {fnc_name} missing required argument {arg_info.name}" - ) - continue - - arg_value = parsed_arguments[arg_info.name] - if typing.get_origin(arg_info.type) is not None: - if not isinstance(arg_value, list): - raise ValueError( - f"AI function {fnc_name} argument {arg_info.name} should be a list" - ) - - inner_type = typing.get_args(arg_info.type)[0] - sanitized_value = [ - _sanitize_primitive( - value=v, expected_type=inner_type, choices=arg_info.choices - ) - for v in arg_value - ] - else: - sanitized_value = _sanitize_primitive( - value=arg_value, expected_type=arg_info.type, choices=arg_info.choices - ) - - sanitized_arguments[arg_info.name] = sanitized_value - - return function_context.FunctionCallInfo( - tool_call_id=tool_call_id, - raw_arguments=raw_arguments, - function_info=fnc_info, - arguments=sanitized_arguments, - ) +__all__ = ["build_oai_function_description"] def build_oai_function_description( @@ -103,8 +44,10 @@ def type2str(t: type) -> str: if arg_info.description: p["description"] = arg_info.description - if typing.get_origin(arg_info.type) is list: - inner_type = typing.get_args(arg_info.type)[0] + is_optional, inner_th = _is_optional_type(arg_info.type) + + if typing.get_origin(inner_th) is list: + inner_type = typing.get_args(inner_th)[0] p["type"] = "array" p["items"] = {} p["items"]["type"] = type2str(inner_type) @@ -112,11 +55,14 @@ def type2str(t: type) -> str: if arg_info.choices: p["items"]["enum"] = arg_info.choices else: - p["type"] = type2str(arg_info.type) + p["type"] = type2str(inner_th) if arg_info.choices: p["enum"] = arg_info.choices - if arg_info.type is int and arg_info.choices and capabilities is not None: - if not capabilities.supports_choices_on_int: + if ( + inner_th is int + and capabilities + and not capabilities.supports_choices_on_int + ): raise ValueError( f"Parameter '{arg_info.name}' uses 'choices' with 'int', which is not supported by this model." 
) @@ -144,31 +90,3 @@ def type2str(t: type) -> str: }, }, } - - -def _sanitize_primitive( - *, value: Any, expected_type: type, choices: tuple | None -) -> Any: - if expected_type is str: - if not isinstance(value, str): - raise ValueError(f"expected str, got {type(value)}") - elif expected_type in (int, float): - if not isinstance(value, (int, float)): - raise ValueError(f"expected number, got {type(value)}") - - if expected_type is int: - if value % 1 != 0: - raise ValueError("expected int, got float") - - value = int(value) - elif expected_type is float: - value = float(value) - - elif expected_type is bool: - if not isinstance(value, bool): - raise ValueError(f"expected bool, got {type(value)}") - - if choices and value not in choices: - raise ValueError(f"invalid value {value}, not in {choices}") - - return value diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/assistant_llm.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/assistant_llm.py index 33235f6f3..7df336e89 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/assistant_llm.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/assistant_llm.py @@ -18,11 +18,13 @@ import json import uuid from dataclasses import dataclass -from typing import Any, Callable, Dict, Literal, MutableSet +from typing import Any, Callable, Dict, Literal, MutableSet, Union import httpx from livekit import rtc from livekit.agents import llm, utils +from livekit.agents.llm import ToolChoice +from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions from openai import AsyncAssistantEventHandler, AsyncClient from openai.types.beta.threads import Text, TextDelta @@ -166,10 +168,13 @@ def chat( self, *, chat_ctx: llm.ChatContext, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, fnc_ctx: llm.FunctionContext | None = None, temperature: float | None = None, n: int | None = None, parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] + | None = None, ): if n is not None: logger.warning("OpenAI Assistants does not support the 'n' parameter") @@ -190,6 +195,7 @@ def chat( chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, on_file_uploaded=self._on_file_uploaded, + conn_options=conn_options, ) async def _register_tool_call(self, tool_call_id: str, run_id: str) -> None: @@ -301,8 +307,11 @@ def __init__( fnc_ctx: llm.FunctionContext | None, temperature: float | None, on_file_uploaded: OnFileUploaded | None, + conn_options: APIConnectOptions, ) -> None: - super().__init__(assistant_llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx) + super().__init__( + assistant_llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options + ) self._client = client self._temperature = temperature self._on_file_uploaded = on_file_uploaded @@ -317,7 +326,7 @@ def __init__( # Running stream is used to ensure that we only have one stream running at a time self._done_future: asyncio.Future[None] = asyncio.Future() - async def _main_task(self) -> None: + async def _run(self) -> None: assert isinstance(self._llm, AssistantLLM) # This function's complexity is due to the fact that we need to sync chat_ctx messages with OpenAI. 
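Editorial note, not part of the patch: the `_oai_api.py` hunks above move argument parsing into `livekit.agents.llm.function_context._create_ai_function_info` and teach `build_oai_function_description` to unwrap `Optional[...]` annotations via `_is_optional_type`. A minimal sketch of what that enables, assuming the `ai_callable`/`TypeInfo` function-context decorators from `livekit-agents` (those names are not shown in this diff):

```python
# Sketch only: an AI-callable function with a nullable argument. With the
# Optional-unwrapping added above, `unit` should be advertised to the model
# as a plain "string" parameter rather than tripping over the Union type.
from typing import Annotated, Optional

from livekit.agents import llm


class WeatherFunctions(llm.FunctionContext):
    @llm.ai_callable(description="Get the current weather for a location")
    def get_weather(
        self,
        location: Annotated[str, llm.TypeInfo(description="City name")],
        unit: Annotated[
            Optional[str], llm.TypeInfo(description="'celsius' or 'fahrenheit'")
        ] = None,
    ) -> str:
        return f"It is sunny in {location}"
```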
@@ -523,7 +532,7 @@ async def _upload_frame( opts.resize_options = utils.images.ResizeOptions( width=inference_width, height=inference_height, - strategy="center_aspect_fit", + strategy="scale_aspect_fit", ) encoded_data = utils.images.encode(frame, opts) diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py index bdb3f4234..37526dd4b 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py @@ -19,7 +19,7 @@ import datetime import os from dataclasses import dataclass -from typing import Any, Awaitable, MutableSet +from typing import Any, Literal, MutableSet, Union import aiohttp import httpx @@ -29,15 +29,14 @@ APITimeoutError, llm, ) +from livekit.agents.llm import ToolChoice, _create_ai_function_info +from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions import openai from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam from openai.types.chat.chat_completion_chunk import Choice -from ._oai_api import ( - build_oai_function_description, - create_ai_function_info, -) +from ._oai_api import build_oai_function_description from .log import logger from .models import ( CerebrasChatModels, @@ -59,6 +58,10 @@ class LLMOptions: model: str | ChatModels user: str | None temperature: float | None + parallel_tool_calls: bool | None + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto" + store: bool | None = None + metadata: dict[str, str] | None = None class LLM(llm.LLM): @@ -71,6 +74,10 @@ def __init__( user: str | None = None, client: openai.AsyncClient | None = None, temperature: float | None = None, + parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto", + store: bool | None = None, + metadata: dict[str, str] | None = None, ) -> None: """ Create a new instance of OpenAI LLM. 
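Editorial note, not part of the patch: the new `LLMOptions` fields above (`parallel_tool_calls`, `tool_choice`, `store`, `metadata`) are plumbed through the `LLM` constructor. A hedged usage sketch; parameter names are taken from the signature in this hunk, everything else (model id, values) is assumed:

```python
# Sketch only: constructing the OpenAI LLM with the options introduced above.
from livekit.plugins import openai

my_llm = openai.LLM(
    model="gpt-4o",
    temperature=0.7,
    parallel_tool_calls=False,   # new option: disable parallel tool calls
    tool_choice="auto",          # new option: ToolChoice or "auto"/"required"/"none"
    store=True,                  # new option: forwarded to the Chat Completions API
    metadata={"app": "demo"},    # new option: request metadata
)
```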
@@ -81,10 +88,19 @@ def __init__( super().__init__() self._capabilities = llm.LLMCapabilities(supports_choices_on_int=True) - self._opts = LLMOptions(model=model, user=user, temperature=temperature) - self._client: openai.AsyncClient = client or openai.AsyncClient( + self._opts = LLMOptions( + model=model, + user=user, + temperature=temperature, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, + store=store, + metadata=metadata, + ) + self._client = client or openai.AsyncClient( api_key=api_key, base_url=base_url, + max_retries=0, http_client=httpx.AsyncClient( timeout=httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0), follow_redirects=True, @@ -112,6 +128,8 @@ def with_azure( base_url: str | None = None, user: str | None = None, temperature: float | None = None, + parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto", ) -> LLM: """ This automatically infers the following arguments from their corresponding environment variables if they are not provided: @@ -124,6 +142,7 @@ def with_azure( """ azure_client = openai.AsyncAzureOpenAI( + max_retries=0, azure_endpoint=azure_endpoint, azure_deployment=azure_deployment, api_version=api_version, @@ -135,7 +154,14 @@ def with_azure( base_url=base_url, ) # type: ignore - return LLM(model=model, client=azure_client, user=user, temperature=temperature) + return LLM( + model=model, + client=azure_client, + user=user, + temperature=temperature, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, + ) @staticmethod def with_cerebras( @@ -146,12 +172,15 @@ def with_cerebras( client: openai.AsyncClient | None = None, user: str | None = None, temperature: float | None = None, + parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto", ) -> LLM: """ Create a new instance of Cerebras LLM. ``api_key`` must be set to your Cerebras API key, either using the argument or by setting the ``CEREBRAS_API_KEY`` environmental variable. + @integrations:cerebras:llm """ api_key = api_key or os.environ.get("CEREBRAS_API_KEY") @@ -167,16 +196,20 @@ def with_cerebras( client=client, user=user, temperature=temperature, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, ) @staticmethod def with_vertex( *, - model: str | VertexModels = "google/gemini-1.5-pro", + model: str | VertexModels = "google/gemini-2.0-flash-exp", project_id: str | None = None, location: str = "us-central1", user: str | None = None, temperature: float | None = None, + parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto", ) -> LLM: """ Create a new instance of VertexAI LLM. @@ -187,8 +220,8 @@ def with_vertex( location = location _gac = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") if _gac is None: - raise ValueError( - "`GOOGLE_APPLICATION_CREDENTIALS` environment variable is not set. please set it to the path of the service account key file." + logger.warning( + "`GOOGLE_APPLICATION_CREDENTIALS` environment variable is not set. please set it to the path of the service account key file. Otherwise, use any of the other Google Cloud auth methods." 
) try: @@ -228,6 +261,7 @@ async def _refresh_credentials(self) -> None: self.api_key = self.creds.token client = AuthTokenRefresher( + max_retries=0, http_client=httpx.AsyncClient( timeout=httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0), follow_redirects=True, @@ -244,6 +278,8 @@ async def _refresh_credentials(self) -> None: client=client, user=user, temperature=temperature, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, ) vertex_llm._capabilities = llm.LLMCapabilities(supports_choices_on_int=False) return vertex_llm @@ -251,12 +287,14 @@ async def _refresh_credentials(self) -> None: @staticmethod def with_fireworks( *, - model: str = "accounts/fireworks/models/llama-v3p1-70b-instruct", + model: str = "accounts/fireworks/models/llama-v3p3-70b-instruct", api_key: str | None = None, base_url: str | None = "https://api.fireworks.ai/inference/v1", client: openai.AsyncClient | None = None, user: str | None = None, temperature: float | None = None, + parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto", ) -> LLM: """ Create a new instance of Fireworks LLM. @@ -278,6 +316,8 @@ def with_fireworks( client=client, user=user, temperature=temperature, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, ) @staticmethod @@ -289,6 +329,8 @@ def with_x_ai( client: openai.AsyncClient | None = None, user: str | None = None, temperature: float | None = None, + parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto", ): """ Create a new instance of XAI LLM. @@ -309,6 +351,8 @@ def with_x_ai( client=client, user=user, temperature=temperature, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, ) @staticmethod @@ -320,6 +364,8 @@ def with_groq( client: openai.AsyncClient | None = None, user: str | None = None, temperature: float | None = None, + parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto", ) -> LLM: """ Create a new instance of Groq LLM. @@ -341,6 +387,8 @@ def with_groq( client=client, user=user, temperature=temperature, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, ) @staticmethod @@ -352,6 +400,8 @@ def with_deepseek( client: openai.AsyncClient | None = None, user: str | None = None, temperature: float | None = None, + parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto", ) -> LLM: """ Create a new instance of DeepSeek LLM. @@ -373,6 +423,8 @@ def with_deepseek( client=client, user=user, temperature=temperature, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, ) @staticmethod @@ -384,6 +436,8 @@ def with_octo( client: openai.AsyncClient | None = None, user: str | None = None, temperature: float | None = None, + parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto", ) -> LLM: """ Create a new instance of OctoAI LLM. 
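Per the `with_vertex` hunk above, a missing `GOOGLE_APPLICATION_CREDENTIALS` now only logs a warning instead of raising, so other Google Cloud auth methods (for example Application Default Credentials) can be used. A rough sketch with placeholder project values:

```python
from livekit.plugins import openai

# Sketch only: project_id/location are placeholders; credentials are expected to
# come from GOOGLE_APPLICATION_CREDENTIALS or another Google Cloud auth method.
vertex_llm = openai.LLM.with_vertex(
    model="google/gemini-2.0-flash-exp",  # new default model in this diff
    project_id="my-gcp-project",
    location="us-central1",
    tool_choice="auto",
)
```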
@@ -405,6 +459,8 @@ def with_octo( client=client, user=user, temperature=temperature, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, ) @staticmethod @@ -414,6 +470,8 @@ def with_ollama( base_url: str | None = "http://localhost:11434/v1", client: openai.AsyncClient | None = None, temperature: float | None = None, + parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto", ) -> LLM: """ Create a new instance of Ollama LLM. @@ -425,6 +483,8 @@ def with_ollama( base_url=base_url, client=client, temperature=temperature, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, ) @staticmethod @@ -436,6 +496,8 @@ def with_perplexity( client: openai.AsyncClient | None = None, user: str | None = None, temperature: float | None = None, + parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto", ) -> LLM: """ Create a new instance of PerplexityAI LLM. @@ -457,6 +519,8 @@ def with_perplexity( client=client, user=user, temperature=temperature, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, ) @staticmethod @@ -468,6 +532,8 @@ def with_together( client: openai.AsyncClient | None = None, user: str | None = None, temperature: float | None = None, + parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto", ) -> LLM: """ Create a new instance of TogetherAI LLM. @@ -489,6 +555,8 @@ def with_together( client=client, user=user, temperature=temperature, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, ) @staticmethod @@ -500,6 +568,8 @@ def with_telnyx( client: openai.AsyncClient | None = None, user: str | None = None, temperature: float | None = None, + parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto", ) -> LLM: """ Create a new instance of Telnyx LLM. @@ -521,6 +591,8 @@ def with_telnyx( client=client, user=user, temperature=temperature, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, ) @staticmethod @@ -538,6 +610,8 @@ def create_azure_client( base_url: str | None = None, user: str | None = None, temperature: float | None = None, + parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto", ) -> LLM: logger.warning("This alias is deprecated. 
Use LLM.with_azure() instead") return LLM.with_azure( @@ -552,80 +626,137 @@ def create_azure_client( base_url=base_url, user=user, temperature=temperature, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, ) def chat( self, *, chat_ctx: llm.ChatContext, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, fnc_ctx: llm.FunctionContext | None = None, temperature: float | None = None, n: int | None = 1, parallel_tool_calls: bool | None = None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] + | None = None, ) -> "LLMStream": - opts: dict[str, Any] = dict() - if fnc_ctx and len(fnc_ctx.ai_functions) > 0: - fncs_desc = [] - for fnc in fnc_ctx.ai_functions.values(): - fncs_desc.append(build_oai_function_description(fnc, self.capabilities)) - - opts["tools"] = fncs_desc + if parallel_tool_calls is None: + parallel_tool_calls = self._opts.parallel_tool_calls - if fnc_ctx and parallel_tool_calls is not None: - opts["parallel_tool_calls"] = parallel_tool_calls + if tool_choice is None: + tool_choice = self._opts.tool_choice - user = self._opts.user or openai.NOT_GIVEN if temperature is None: temperature = self._opts.temperature - messages = _build_oai_context(chat_ctx, id(self)) - - cmp = self._client.chat.completions.create( - messages=messages, + return LLMStream( + self, + client=self._client, model=self._opts.model, + user=self._opts.user, + chat_ctx=chat_ctx, + fnc_ctx=fnc_ctx, + conn_options=conn_options, n=n, temperature=temperature, - stream_options={"include_usage": True}, - stream=True, - user=user, - **opts, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, ) - return LLMStream(self, oai_stream=cmp, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx) - class LLMStream(llm.LLMStream): def __init__( self, llm: LLM, *, - oai_stream: Awaitable[openai.AsyncStream[ChatCompletionChunk]], + client: openai.AsyncClient, + model: str | ChatModels, + user: str | None, chat_ctx: llm.ChatContext, + conn_options: APIConnectOptions, fnc_ctx: llm.FunctionContext | None, + temperature: float | None, + n: int | None, + parallel_tool_calls: bool | None, + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]], ) -> None: - super().__init__(llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx) + super().__init__( + llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options + ) + self._client = client + self._model = model self._llm: LLM = llm - self._awaitable_oai_stream = oai_stream - self._oai_stream: openai.AsyncStream[ChatCompletionChunk] | None = None + + self._user = user + self._temperature = temperature + self._n = n + self._parallel_tool_calls = parallel_tool_calls + self._tool_choice = tool_choice + + async def _run(self) -> None: + if hasattr(self._llm._client, "_refresh_credentials"): + await self._llm._client._refresh_credentials() # current function call that we're waiting for full completion (args are streamed) + # (defined inside the _run method to make sure the state is reset for each run/attempt) + self._oai_stream: openai.AsyncStream[ChatCompletionChunk] | None = None self._tool_call_id: str | None = None self._fnc_name: str | None = None self._fnc_raw_arguments: str | None = None self._tool_index: int | None = None - - async def _main_task(self) -> None: - if hasattr(self._llm._client, "_refresh_credentials"): - await self._llm._client._refresh_credentials() - if not self._oai_stream: - self._oai_stream = await self._awaitable_oai_stream + retryable = True try: - async with self._oai_stream as stream: + opts: dict[str, Any] 
= dict() + if self._fnc_ctx and len(self._fnc_ctx.ai_functions) > 0: + fncs_desc = [] + for fnc in self._fnc_ctx.ai_functions.values(): + fncs_desc.append( + build_oai_function_description(fnc, self._llm._capabilities) + ) + + opts["tools"] = fncs_desc + if self._parallel_tool_calls is not None: + opts["parallel_tool_calls"] = self._parallel_tool_calls + + if self._tool_choice is not None: + if isinstance(self._tool_choice, ToolChoice): + # specific function + opts["tool_choice"] = { + "type": "function", + "function": {"name": self._tool_choice.name}, + } + else: + opts["tool_choice"] = self._tool_choice + + if self._llm._opts.metadata is not None: + # some OpenAI-like API doesn't support having a `metadata` field. (Even None) + opts["metadata"] = self._llm._opts.metadata + + if self._llm._opts.store is not None: + opts["store"] = self._llm._opts.store + + user = self._user or openai.NOT_GIVEN + messages = _build_oai_context(self._chat_ctx, id(self)) + stream = await self._client.chat.completions.create( + messages=messages, + model=self._model, + n=self._n, + temperature=self._temperature, + stream_options={"include_usage": True}, + stream=True, + user=user, + **opts, + ) + + async with stream: async for chunk in stream: for choice in chunk.choices: chat_chunk = self._parse_choice(chunk.id, choice) if chat_chunk is not None: + retryable = False self._event_ch.send_nowait(chat_chunk) if chunk.usage is not None: @@ -642,7 +773,7 @@ async def _main_task(self) -> None: ) except openai.APITimeoutError: - raise APITimeoutError() + raise APITimeoutError(retryable=retryable) except openai.APIStatusError as e: raise APIStatusError( e.message, @@ -651,7 +782,7 @@ async def _main_task(self) -> None: body=e.body, ) except Exception as e: - raise APIConnectionError() from e + raise APIConnectionError(retryable=retryable) from e def _parse_choice(self, id: str, choice: Choice) -> llm.ChatChunk | None: delta = choice.delta @@ -713,7 +844,7 @@ def _try_build_function(self, id: str, choice: Choice) -> llm.ChatChunk | None: ) return None - fnc_info = create_ai_function_info( + fnc_info = _create_ai_function_info( self._fnc_ctx, self._tool_call_id, self._fnc_name, self._fnc_raw_arguments ) diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/models.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/models.py index c2667665d..a2ef15854 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/models.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/models.py @@ -45,6 +45,7 @@ CerebrasChatModels = Literal[ "llama3.1-8b", "llama3.1-70b", + "llama-3.3-70b", ] PerplexityChatModels = Literal[ @@ -58,8 +59,8 @@ GroqChatModels = Literal[ "llama-3.1-405b-reasoning", - "llama-3.1-70b-versatile", "llama-3.1-8b-instant", + "llama-3.3-70b-versatile", "llama3-groq-70b-8192-tool-use-preview", "llama3-groq-8b-8192-tool-use-preview", "llama-guard-3-8b", @@ -80,6 +81,7 @@ ] VertexModels = Literal[ + "google/gemini-2.0-flash-exp", "google/gemini-1.5-flash", "google/gemini-1.5-pro", "google/gemini-1.0-pro-vision", @@ -143,6 +145,7 @@ "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + "meta-llama/Llama-3.3-70B-Instruct-Turbo", "mistralai/Mistral-7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.3", diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py 
b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py index 6852c3bf6..fbb453609 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py @@ -2,27 +2,27 @@ from .realtime_model import ( DEFAULT_INPUT_AUDIO_TRANSCRIPTION, DEFAULT_SERVER_VAD_OPTIONS, - InputTranscriptionCompleted, - InputTranscriptionFailed, InputTranscriptionOptions, RealtimeContent, + RealtimeError, RealtimeModel, RealtimeOutput, RealtimeResponse, RealtimeSession, + RealtimeSessionOptions, RealtimeToolCall, ServerVadOptions, ) __all__ = [ - "InputTranscriptionCompleted", - "InputTranscriptionFailed", "RealtimeContent", "RealtimeOutput", "RealtimeResponse", "RealtimeToolCall", "RealtimeSession", "RealtimeModel", + "RealtimeError", + "RealtimeSessionOptions", "ServerVadOptions", "InputTranscriptionOptions", "ConversationItemCreated", diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/api_proto.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/api_proto.py index 0a022ad03..2bf9778d3 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/api_proto.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/api_proto.py @@ -16,7 +16,7 @@ class FunctionToolChoice(TypedDict): name: str -Voice = Literal["alloy", "echo", "shimmer"] +Voice = Literal["alloy", "echo", "shimmer", "ash", "ballad", "coral", "sage", "verse"] ToolChoice = Union[Literal["auto", "none", "required"], FunctionToolChoice] Role = Literal["system", "assistant", "user", "tool"] GenerationFinishedReason = Literal["stop", "max_tokens", "content_filter", "interrupt"] @@ -27,6 +27,16 @@ class FunctionToolChoice(TypedDict): "in_progress", "completed", "incomplete", "cancelled", "failed" ] +# https://platform.openai.com/docs/models/gp#gpt-4o-realtime +OpenAIModel = Literal[ + "gpt-4o-realtime-preview", + "gpt-4o-realtime-preview-2024-10-01", + "gpt-4o-realtime-preview-2024-12-17", + "gpt-4o-mini-realtime-preview", + "gpt-4o-mini-realtime-preview-2024-12-17", +] +DefaultOpenAIModel = "gpt-4o-realtime-preview" + class TextContent(TypedDict): type: Literal["text"] @@ -145,6 +155,12 @@ class InputTokenDetails(TypedDict): cached_tokens: int text_tokens: int audio_tokens: int + cached_tokens_details: CachedTokenDetails + + +class CachedTokenDetails(TypedDict): + text_tokens: int + audio_tokens: int class OutputTokenDetails(TypedDict): diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py index 46151242f..10d7abc1f 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py @@ -4,23 +4,26 @@ import base64 import os import time +import weakref from copy import deepcopy from dataclasses import dataclass -from typing import AsyncIterable, Literal, Union, cast, overload +from typing import AsyncIterable, Literal, Optional, Union, cast, overload from urllib.parse import urlencode import aiohttp from livekit import rtc from livekit.agents import llm, utils +from livekit.agents.llm.function_context import _create_ai_function_info from livekit.agents.metrics import MultimodalLLMError, MultimodalLLMMetrics from typing_extensions import TypedDict -from 
.._oai_api import build_oai_function_description, create_ai_function_info +from .._oai_api import build_oai_function_description from . import api_proto, remote_items from .log import logger EventTypes = Literal[ "start_session", + "session_updated", "error", "input_speech_started", "input_speech_stopped", @@ -103,8 +106,11 @@ class RealtimeToolCall: """id of the tool call""" -# TODO(theomonnom): add the content type directly inside RealtimeContent? -# text/audio/transcript? +@dataclass +class Capabilities: + supports_truncate: bool + + @dataclass class RealtimeContent: response_id: str @@ -142,18 +148,31 @@ class InputTranscriptionOptions: @dataclass -class _ModelOptions: - model: str | None +class RealtimeError: + event_id: str + type: str + message: str + code: Optional[str] + param: Optional[str] + + +@dataclass +class RealtimeSessionOptions: + model: api_proto.OpenAIModel | str modalities: list[api_proto.Modality] instructions: str voice: api_proto.Voice input_audio_format: api_proto.AudioFormat output_audio_format: api_proto.AudioFormat - input_audio_transcription: InputTranscriptionOptions - turn_detection: ServerVadOptions + input_audio_transcription: InputTranscriptionOptions | None + turn_detection: ServerVadOptions | None tool_choice: api_proto.ToolChoice temperature: float max_response_output_tokens: int | Literal["inf"] + + +@dataclass +class _ModelOptions(RealtimeSessionOptions): api_key: str | None base_url: str entra_token: str | None @@ -173,6 +192,7 @@ class _ContentPtr(TypedDict): prefix_padding_ms=300, silence_duration_ms=500, ) + DEFAULT_INPUT_AUDIO_TRANSCRIPTION = InputTranscriptionOptions(model="whisper-1") @@ -183,7 +203,7 @@ def __init__( *, instructions: str = "", modalities: list[api_proto.Modality] = ["text", "audio"], - model: str = "gpt-4o-realtime-preview-2024-10-01", + model: api_proto.OpenAIModel | str = api_proto.DefaultOpenAIModel, voice: api_proto.Voice = "alloy", input_audio_format: api_proto.AudioFormat = "pcm16", output_audio_format: api_proto.AudioFormat = "pcm16", @@ -226,7 +246,7 @@ def __init__( *, instructions: str = "", modalities: list[api_proto.Modality] = ["text", "audio"], - model: str | None = "gpt-4o-realtime-preview-2024-10-01", + model: api_proto.OpenAIModel | str = api_proto.DefaultOpenAIModel, voice: api_proto.Voice = "alloy", input_audio_format: api_proto.AudioFormat = "pcm16", output_audio_format: api_proto.AudioFormat = "pcm16", @@ -268,6 +288,9 @@ def __init__( ValueError: If the API key is not provided and cannot be found in environment variables. 
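The constructor documented here now accepts the widened `api_proto.OpenAIModel` literal and the extra voices added in `api_proto.py`. A minimal sketch, assuming `OPENAI_API_KEY` is set in the environment and using placeholder instructions:

```python
from livekit.plugins import openai

# Sketch only: model/voice values come from the literals added in api_proto.py.
realtime_model = openai.realtime.RealtimeModel(
    model="gpt-4o-mini-realtime-preview",
    voice="coral",
    instructions="You are a helpful voice assistant.",
    modalities=["text", "audio"],
)
```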
""" super().__init__() + self._capabilities = Capabilities( + supports_truncate=True, + ) self._base_url = base_url is_azure = ( @@ -306,7 +329,7 @@ def __init__( ) self._loop = loop or asyncio.get_event_loop() - self._rt_sessions: list[RealtimeSession] = [] + self._rt_sessions = weakref.WeakSet[RealtimeSession]() self._http_session = http_session @classmethod @@ -411,9 +434,13 @@ def _ensure_session(self) -> aiohttp.ClientSession: return self._http_session @property - def sessions(self) -> list[RealtimeSession]: + def sessions(self) -> weakref.WeakSet[RealtimeSession]: return self._rt_sessions + @property + def capabilities(self) -> Capabilities: + return self._capabilities + def session( self, *, @@ -459,7 +486,7 @@ def session( http_session=self._ensure_session(), loop=self._loop, ) - self._rt_sessions.append(new_session) + self._rt_sessions.add(new_session) return new_session async def aclose(self) -> None: @@ -497,10 +524,6 @@ def create( message_content = message.content tool_call_id = message.tool_call_id - if not tool_call_id and message_content is None: - # not a function call while the message content is None - fut.set_result(False) - return fut event: api_proto.ClientEvent.ConversationItemCreate | None = None if tool_call_id: if message.role == "tool": @@ -686,8 +709,94 @@ class Response: def __init__(self, sess: RealtimeSession) -> None: self._sess = sess - def create(self) -> None: - self._sess._queue_msg({"type": "response.create"}) + def create( + self, + *, + on_duplicate: Literal[ + "cancel_existing", "cancel_new", "keep_both" + ] = "keep_both", + ) -> asyncio.Future[bool]: + """Creates a new response. + + Args: + on_duplicate: How to handle when there is an existing response in progress: + - "cancel_existing": Cancel the existing response before creating new one + - "cancel_new": Skip creating new response if one is in progress + - "keep_both": Wait for the existing response to be done and then create a new one + + Returns: + Future that resolves when the response create request is queued + """ + if on_duplicate not in ("cancel_existing", "cancel_new", "keep_both"): + raise ValueError( + "invalid on_duplicate value, must be one of: " + "cancel_existing, cancel_new, keep_both" + ) + + # check if there is a pending response creation request sent + pending_create_fut = self._sess._response_create_fut + if pending_create_fut is not None: + if on_duplicate == "cancel_new": + logger.warning( + "skip new response creation due to previous pending response creation", + extra=self._sess.logging_extra(), + ) + _fut = asyncio.Future[bool]() + _fut.set_result(False) + return _fut + + active_resp_id = self._sess._active_response_id + _logging_extra = { + "existing_response_id": active_resp_id, + **self._sess.logging_extra(), + } + + if ( + not active_resp_id + or self._sess._pending_responses[active_resp_id].done_fut.done() + ): + # no active response in progress, create a new one + self._sess._queue_msg({"type": "response.create"}) + _fut = asyncio.Future[bool]() + _fut.set_result(True) + return _fut + + # there is an active response in progress + if on_duplicate == "cancel_new": + logger.warning( + "skip new response creation due to active response in progress", + extra=_logging_extra, + ) + _fut = asyncio.Future[bool]() + _fut.set_result(False) + return _fut + + if on_duplicate == "cancel_existing": + self.cancel() + logger.warning( + "cancelling in-progress response to create a new one", + extra=_logging_extra, + ) + elif on_duplicate == "keep_both": + logger.warning( + "waiting 
for in-progress response to be done " + "before creating a new one", + extra=_logging_extra, + ) + + # create a task to wait for the previous response and then create new one + async def wait_and_create() -> bool: + await self._sess._pending_responses[active_resp_id].done_fut + logger.info( + "in-progress response is done, creating a new one", + extra=_logging_extra, + ) + new_create_fut = asyncio.Future[None]() + self._sess._response_create_fut = new_create_fut + self._sess._queue_msg({"type": "response.create"}) + return True + + return asyncio.create_task(wait_and_create()) def cancel(self) -> None: self._sess._queue_msg({"type": "response.cancel"}) @@ -707,7 +816,7 @@ def __init__( self._main_task(), name="openai-realtime-session" ) # manage conversation items internally - self._remote_converstation_items = remote_items._RemoteConversationItems() + self._remote_conversation_items = remote_items._RemoteConversationItems() # wait for the item to be created or deleted self._item_created_futs: dict[str, asyncio.Future[bool]] = {} @@ -722,6 +831,8 @@ def __init__( self._http_session = http_session self._pending_responses: dict[str, RealtimeResponse] = {} + self._active_response_id: str | None = None + self._response_create_fut: asyncio.Future[None] | None = None self._session_id = "not-connected" self.session_update() # initial session init @@ -754,6 +865,9 @@ def conversation(self) -> Conversation: def input_audio_buffer(self) -> InputAudioBuffer: return RealtimeSession.InputAudioBuffer(self) + def _push_audio(self, frame: rtc.AudioFrame) -> None: + self.input_audio_buffer.append(frame) + @property def response(self) -> Response: return RealtimeSession.Response(self) @@ -803,12 +917,19 @@ def session_update( function_data["type"] = "function" tools.append(function_data) - server_vad_opts: api_proto.ServerVad = { - "type": "server_vad", - "threshold": self._opts.turn_detection.threshold, - "prefix_padding_ms": self._opts.turn_detection.prefix_padding_ms, - "silence_duration_ms": self._opts.turn_detection.silence_duration_ms, - } + server_vad_opts: api_proto.ServerVad | None = None + if self._opts.turn_detection is not None: + server_vad_opts = { + "type": "server_vad", + "threshold": self._opts.turn_detection.threshold, + "prefix_padding_ms": self._opts.turn_detection.prefix_padding_ms, + "silence_duration_ms": self._opts.turn_detection.silence_duration_ms, + } + input_audio_transcription_opts: api_proto.InputAudioTranscription | None = None + if self._opts.input_audio_transcription is not None: + input_audio_transcription_opts = { + "model": self._opts.input_audio_transcription.model, + } session_data: api_proto.ClientEvent.SessionUpdateData = { "modalities": self._opts.modalities, @@ -816,9 +937,7 @@ def session_update( "voice": self._opts.voice, "input_audio_format": self._opts.input_audio_format, "output_audio_format": self._opts.output_audio_format, - "input_audio_transcription": { - "model": self._opts.input_audio_transcription.model, - }, + "input_audio_transcription": input_audio_transcription_opts, "turn_detection": server_vad_opts, "tools": tools, "tool_choice": self._opts.tool_choice, @@ -844,7 +963,7 @@ def session_update( ) def chat_ctx_copy(self) -> llm.ChatContext: - return self._remote_converstation_items.to_chat_context() + return self._remote_conversation_items.to_chat_context() async def set_chat_ctx(self, new_ctx: llm.ChatContext) -> None: """Sync the chat context with the agent's chat context. 
@@ -852,10 +971,16 @@ async def set_chat_ctx(self, new_ctx: llm.ChatContext) -> None: Compute the minimum number of insertions and deletions to transform the old chat context messages to the new chat context messages. """ - original_ctx = self._remote_converstation_items.to_chat_context() + original_ctx = self._remote_conversation_items.to_chat_context() + # filter out messages that are not function calls and content is None + filtered_messages = [ + msg + for msg in new_ctx.messages + if msg.tool_call_id or msg.content is not None + ] changes = utils._compute_changes( - original_ctx.messages, new_ctx.messages, key_fnc=lambda x: x.id + original_ctx.messages, filtered_messages, key_fnc=lambda x: x.id ) logger.debug( "sync chat context", @@ -871,24 +996,8 @@ async def set_chat_ctx(self, new_ctx: llm.ChatContext) -> None: if changes.to_add and not any( isinstance(msg.content, llm.ChatAudio) for _, msg in changes.to_add ): - # Patch: add an empty audio message to the chat context - # to set the API in audio mode - data = b"\x00\x00" * api_proto.SAMPLE_RATE - _empty_audio = rtc.AudioFrame( - data=data, - sample_rate=api_proto.SAMPLE_RATE, - num_channels=api_proto.NUM_CHANNELS, - samples_per_channel=len(data) // 2, - ) - changes.to_add.append( - ( - None, - llm.ChatMessage( - role="user", content=llm.ChatAudio(frame=_empty_audio) - ), - ) - ) - logger.debug("added empty audio message to the chat context") + # Patch: append an empty audio message to set the API in audio mode + changes.to_add.append((None, self._create_empty_user_audio_message(1.0))) _futs = [ self.conversation.item.delete(item_id=msg.id) for msg in changes.to_delete @@ -900,10 +1009,47 @@ async def set_chat_ctx(self, new_ctx: llm.ChatContext) -> None: # wait for all the futures to complete await asyncio.gather(*_futs) - def _update_converstation_item_content( + def _create_empty_user_audio_message(self, duration: float) -> llm.ChatMessage: + """Create an empty audio message with the given duration.""" + samples = int(duration * api_proto.SAMPLE_RATE) + return llm.ChatMessage( + role="user", + content=llm.ChatAudio( + frame=rtc.AudioFrame( + data=b"\x00\x00" * (samples * api_proto.NUM_CHANNELS), + sample_rate=api_proto.SAMPLE_RATE, + num_channels=api_proto.NUM_CHANNELS, + samples_per_channel=samples, + ) + ), + ) + + def _recover_from_text_response(self, item_id: str | None = None) -> None: + """Try to recover from a text response to audio mode. + + Sometimes the OpenAI Realtime API returns text instead of audio responses. + This method tries to recover from this by requesting a new response after + deleting the text response and creating an empty user audio message. 
+ """ + if item_id: + # remove the text response if needed + self.conversation.item.delete(item_id=item_id) + self.conversation.item.create(self._create_empty_user_audio_message(1.0)) + self.response.create(on_duplicate="keep_both") + + def _truncate_conversation_item( + self, item_id: str, content_index: int, audio_end_ms: int + ) -> None: + self.conversation.item.truncate( + item_id=item_id, + content_index=content_index, + audio_end_ms=audio_end_ms, + ) + + def _update_conversation_item_content( self, item_id: str, content: llm.ChatContent | list[llm.ChatContent] | None ) -> None: - item = self._remote_converstation_items.get(item_id) + item = self._remote_conversation_items.get(item_id) if item is None: logger.warning( "conversation item not found, skipping update", @@ -993,6 +1139,8 @@ async def _recv_task(): if event == "session.created": self._handle_session_created(data) + if event == "session.updated": + self._handle_session_updated(data) elif event == "error": self._handle_error(data) elif event == "input_audio_buffer.speech_started": @@ -1029,6 +1177,8 @@ async def _recv_task(): self._handle_response_audio_transcript_delta(data) elif event == "response.audio.done": self._handle_response_audio_done(data) + elif event == "response.text.done": + self._handle_response_text_done(data) elif event == "response.audio_transcript.done": self._handle_response_audio_transcript_done(data) elif event == "response.content_part.done": @@ -1059,12 +1209,59 @@ def _handle_session_created( ): self._session_id = session_created["session"]["id"] + def _handle_session_updated( + self, session_updated: api_proto.ServerEvent.SessionUpdated + ): + session = session_updated["session"] + if session["turn_detection"] is None: + turn_detection = None + else: + turn_detection = ServerVadOptions( + threshold=session["turn_detection"]["threshold"], + prefix_padding_ms=session["turn_detection"]["prefix_padding_ms"], + silence_duration_ms=session["turn_detection"]["silence_duration_ms"], + ) + if session["input_audio_transcription"] is None: + input_audio_transcription = None + else: + input_audio_transcription = InputTranscriptionOptions( + model=session["input_audio_transcription"]["model"], + ) + + self.emit( + "session_updated", + RealtimeSessionOptions( + model=session["model"], + modalities=session["modalities"], + instructions=session["instructions"], + voice=session["voice"], + input_audio_format=session["input_audio_format"], + output_audio_format=session["output_audio_format"], + input_audio_transcription=input_audio_transcription, + turn_detection=turn_detection, + tool_choice=session["tool_choice"], + temperature=session["temperature"], + max_response_output_tokens=session["max_response_output_tokens"], + ), + ) + def _handle_error(self, error: api_proto.ServerEvent.Error): logger.error( "OpenAI S2S error %s", error, extra=self.logging_extra(), ) + error_content = error["error"] + self.emit( + "error", + RealtimeError( + event_id=error["event_id"], + type=error_content["type"], + message=error_content["message"], + code=error_content.get("code"), + param=error_content.get("param"), + ), + ) def _handle_input_audio_buffer_speech_started( self, speech_started: api_proto.ServerEvent.InputAudioBufferSpeechStarted @@ -1171,7 +1368,7 @@ def _handle_conversation_item_created( return # Insert into conversation items - self._remote_converstation_items.insert_after(previous_item_id, message) + self._remote_conversation_items.insert_after(previous_item_id, message) if item_id in self._item_created_futs: 
self._item_created_futs[item_id].set_result(True) del self._item_created_futs[item_id] @@ -1182,7 +1379,7 @@ def _handle_conversation_item_deleted( ): # Delete from conversation items item_id = item_deleted["item_id"] - self._remote_converstation_items.delete(item_id) + self._remote_conversation_items.delete(item_id) if item_id in self._item_deleted_futs: self._item_deleted_futs[item_id].set_result(True) del self._item_deleted_futs[item_id] @@ -1212,6 +1409,13 @@ def _handle_response_created( _created_timestamp=time.time(), ) self._pending_responses[new_response.id] = new_response + self._active_response_id = new_response.id + + # complete the create future if it exists + if self._response_create_fut is not None: + self._response_create_fut.set_result(None) + self._response_create_fut = None + self.emit("response_created", new_response) def _handle_response_output_item_added( @@ -1301,6 +1505,12 @@ def _handle_response_audio_done( assert isinstance(content.audio_stream, utils.aio.Chan) content.audio_stream.close() + def _handle_response_text_done( + self, response_text_done: api_proto.ServerEvent.ResponseTextDone + ): + content = self._get_content(response_text_done) + content.text = response_text_done["text"] + def _handle_response_audio_transcript_done( self, response_audio_transcript_done: api_proto.ServerEvent.ResponseAudioTranscriptDone, @@ -1335,14 +1545,14 @@ def _handle_response_output_item_done( item = response_output_done["item"] assert item["type"] == "function_call" - fnc_call_info = create_ai_function_info( + fnc_call_info = _create_ai_function_info( self._fnc_ctx, item["call_id"], item["name"], item["arguments"], ) - msg = self._remote_converstation_items.get(output.item_id) + msg = self._remote_conversation_items.get(output.item_id) if msg is not None: # update the content of the message assert msg.tool_call_id == item["call_id"] @@ -1363,6 +1573,7 @@ def _handle_response_done(self, response_done: api_proto.ServerEvent.ResponseDon response_data = response_done["response"] response_id = response_data["id"] response = self._pending_responses[response_id] + self._active_response_id = None response.done_fut.set_result(None) response.status = response_data["status"] @@ -1412,6 +1623,7 @@ def _handle_response_done(self, response_done: api_proto.ServerEvent.ResponseDon duration = time.time() - response._created_timestamp usage = response.usage or {} # type: ignore + input_token_details = usage.get("input_token_details", {}) metrics = MultimodalLLMMetrics( timestamp=response._created_timestamp, request_id=response.id, @@ -1425,13 +1637,19 @@ def _handle_response_done(self, response_done: api_proto.ServerEvent.ResponseDon tokens_per_second=usage.get("output_tokens", 0) / duration, error=metrics_error, input_token_details=MultimodalLLMMetrics.InputTokenDetails( - cached_tokens=usage.get("input_token_details", {}).get( - "cached_tokens", 0 - ), + cached_tokens=input_token_details.get("cached_tokens", 0), text_tokens=usage.get("input_token_details", {}).get("text_tokens", 0), audio_tokens=usage.get("input_token_details", {}).get( "audio_tokens", 0 ), + cached_tokens_details=MultimodalLLMMetrics.CachedTokenDetails( + text_tokens=input_token_details.get( + "cached_tokens_details", {} + ).get("text_tokens", 0), + audio_tokens=input_token_details.get( + "cached_tokens_details", {} + ).get("audio_tokens", 0), + ), ), output_token_details=MultimodalLLMMetrics.OutputTokenDetails( text_tokens=usage.get("output_token_details", {}).get("text_tokens", 0), @@ -1461,17 +1679,22 @@ async def 
_run_fnc_task(self, fnc_call_info: llm.FunctionCallInfo, item_id: str) await called_fnc.task tool_call = llm.ChatMessage.create_tool_from_called_function(called_fnc) - - if called_fnc.result is not None: + logger.info( + "creating response for tool call", + extra={ + "function": fnc_call_info.function_info.name, + }, + ) + if tool_call.content is not None: create_fut = self.conversation.item.create( tool_call, previous_item_id=item_id, ) - self.response.create() + await self.response.create(on_duplicate="keep_both") await create_fut # update the message with the tool call result - msg = self._remote_converstation_items.get(tool_call.id) + msg = self._remote_conversation_items.get(tool_call.id) if msg is not None: assert msg.tool_call_id == tool_call.tool_call_id assert msg.role == "tool" diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py index 4b79ba038..e3f19972a 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py @@ -15,18 +15,17 @@ from __future__ import annotations import dataclasses -import io import os -import wave from dataclasses import dataclass import httpx +from livekit import rtc from livekit.agents import ( APIConnectionError, + APIConnectOptions, APIStatusError, APITimeoutError, stt, - utils, ) from livekit.agents.utils import AudioBuffer @@ -73,6 +72,7 @@ def __init__( ) self._client = client or openai.AsyncClient( + max_retries=0, api_key=api_key, base_url=base_url, http_client=httpx.AsyncClient( @@ -86,6 +86,15 @@ def __init__( ), ) + def update_options( + self, + *, + model: WhisperModels | GroqAudioModels | None = None, + language: str | None = None, + ) -> None: + self._opts.model = model or self._opts.model + self._opts.language = language or self._opts.language + @staticmethod def with_groq( *, @@ -103,12 +112,10 @@ def with_groq( the ``GROQ_API_KEY`` environmental variable. 
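As the docstring above notes, `with_groq` falls back to `GROQ_API_KEY` when no key is passed; combined with the new `update_options` helper it can be reconfigured at runtime. A rough sketch (the model name is an assumption, not shown in this hunk):

```python
from livekit.plugins import openai

# Sketch only: relies on GROQ_API_KEY in the environment; model name assumed.
stt = openai.STT.with_groq(model="whisper-large-v3")

# Switch the transcription language later via the update_options() added above.
stt.update_options(language="en")
```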
""" - # Use environment variable if API key is not provided api_key = api_key or os.environ.get("GROQ_API_KEY") if api_key is None: raise ValueError("Groq API key is required") - # Instantiate and return a configured STT instance return STT( model=model, api_key=api_key, @@ -124,29 +131,35 @@ def _sanitize_options(self, *, language: str | None = None) -> _STTOptions: return config async def _recognize_impl( - self, buffer: AudioBuffer, *, language: str | None = None + self, + buffer: AudioBuffer, + *, + language: str | None, + conn_options: APIConnectOptions, ) -> stt.SpeechEvent: try: config = self._sanitize_options(language=language) - buffer = utils.merge_frames(buffer) - io_buffer = io.BytesIO() - with wave.open(io_buffer, "wb") as wav: - wav.setnchannels(buffer.num_channels) - wav.setsampwidth(2) # 16-bit - wav.setframerate(buffer.sample_rate) - wav.writeframes(buffer.data) - + data = rtc.combine_audio_frames(buffer).to_wav_bytes() resp = await self._client.audio.transcriptions.create( - file=("file.wav", io_buffer.getvalue(), "audio/wav"), + file=( + "file.wav", + data, + "audio/wav", + ), model=self._opts.model, language=config.language, - response_format="json", + # verbose_json returns language and other details + response_format="verbose_json", + timeout=httpx.Timeout(30, connect=conn_options.timeout), ) return stt.SpeechEvent( type=stt.SpeechEventType.FINAL_TRANSCRIPT, alternatives=[ - stt.SpeechData(text=resp.text or "", language=language or "") + stt.SpeechData( + text=resp.text or "", + language=resp.language or config.language or "", + ) ], ) diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/tts.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/tts.py index 3fac85cc1..ce7741eb8 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/tts.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/tts.py @@ -76,6 +76,7 @@ def __init__( ) self._client = client or openai.AsyncClient( + max_retries=0, api_key=api_key, base_url=base_url, http_client=httpx.AsyncClient( @@ -123,6 +124,7 @@ def create_azure_client( """ azure_client = openai.AsyncAzureOpenAI( + max_retries=0, azure_endpoint=azure_endpoint, azure_deployment=azure_deployment, api_version=api_version, diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/utils.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/utils.py index 40d95037f..278c499a1 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/utils.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import base64 +import json import os from typing import Any, Awaitable, Callable, Optional, Union @@ -25,6 +26,8 @@ def build_oai_message(msg: llm.ChatMessage, cache_key: Any): # add content if provided if isinstance(msg.content, str): oai_msg["content"] = msg.content + elif isinstance(msg.content, dict): + oai_msg["content"] = json.dumps(msg.content) elif isinstance(msg.content, list): oai_content: list[dict[str, Any]] = [] for cnt in msg.content: @@ -64,7 +67,7 @@ def _build_oai_image_content(image: llm.ChatImage, cache_key: Any): if isinstance(image.image, str): # image url return { "type": "image_url", - "image_url": {"url": image.image, "detail": "auto"}, + "image_url": {"url": image.image, "detail": image.inference_detail}, } elif isinstance(image.image, rtc.VideoFrame): # VideoFrame if cache_key not in image._cache: @@ -75,7 +78,7 @@ def 
_build_oai_image_content(image: llm.ChatImage, cache_key: Any): opts.resize_options = utils.images.ResizeOptions( width=image.inference_width, height=image.inference_height, - strategy="center_aspect_fit", + strategy="scale_aspect_fit", ) encoded_data = utils.images.encode(image.image, opts) @@ -83,7 +86,12 @@ def _build_oai_image_content(image: llm.ChatImage, cache_key: Any): return { "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{image._cache[cache_key]}"}, + "image_url": { + "url": f"data:image/jpeg;base64,{image._cache[cache_key]}", + "detail": image.inference_detail, + }, } - raise ValueError(f"unknown image type {type(image.image)}") + raise ValueError( + "LiveKit OpenAI Plugin: ChatImage must be an rtc.VideoFrame or a URL" + ) diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/version.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/version.py index b55d1d86b..c1fcb43b8 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/version.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.10.7" +__version__ = "0.10.13" diff --git a/livekit-plugins/livekit-plugins-openai/package.json b/livekit-plugins/livekit-plugins-openai/package.json index 8019c50f3..e23704cba 100644 --- a/livekit-plugins/livekit-plugins-openai/package.json +++ b/livekit-plugins/livekit-plugins-openai/package.json @@ -1,5 +1,5 @@ { "name": "livekit-plugins-openai", "private": true, - "version": "0.10.7" + "version": "0.10.13" } diff --git a/livekit-plugins/livekit-plugins-openai/setup.py b/livekit-plugins/livekit-plugins-openai/setup.py index a7b6cdf19..eb9d6d0fe 100644 --- a/livekit-plugins/livekit-plugins-openai/setup.py +++ b/livekit-plugins/livekit-plugins-openai/setup.py @@ -48,7 +48,7 @@ packages=setuptools.find_namespace_packages(include=["livekit.*"]), python_requires=">=3.9.0", install_requires=[ - "livekit-agents[codecs, images]>=0.11", + "livekit-agents[codecs, images]>=0.12.3", "openai>=1.50", ], extras_require={ diff --git a/livekit-plugins/livekit-plugins-playai/CHANGELOG.md b/livekit-plugins/livekit-plugins-playai/CHANGELOG.md new file mode 100644 index 000000000..8fd61d2cf --- /dev/null +++ b/livekit-plugins/livekit-plugins-playai/CHANGELOG.md @@ -0,0 +1,31 @@ +# livekit-plugins-playht + +## 1.0.4 + +### Patch Changes + +- Support PlayAI TTS engine. 
- [#1174](https://github.com/livekit/agents/pull/1174) ([@jayeshp19](https://github.com/jayeshp19)) + +## 1.0.3 + +### Patch Changes + +- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom)) + +## 1.0.2 + +### Patch Changes + +- fix(playht): add sample_rate parameter to JSON payload - [#1141](https://github.com/livekit/agents/pull/1141) ([@imsakg](https://github.com/imsakg)) + +- feat: tts retry & tts.FallbackAdapter - [#1074](https://github.com/livekit/agents/pull/1074) ([@theomonnom](https://github.com/theomonnom)) + +- feat(playht): add Play3.0-mini engine support - [#1140](https://github.com/livekit/agents/pull/1140) ([@imsakg](https://github.com/imsakg)) + +## 1.0.1 + +### Patch Changes + +- pipelineagent: expose timing metrics & api errors wip - [#957](https://github.com/livekit/agents/pull/957) ([@theomonnom](https://github.com/theomonnom)) + +- expose usage metrics - [#984](https://github.com/livekit/agents/pull/984) ([@theomonnom](https://github.com/theomonnom)) diff --git a/livekit-plugins/livekit-plugins-playai/README.md b/livekit-plugins/livekit-plugins-playai/README.md new file mode 100644 index 000000000..5561dbe66 --- /dev/null +++ b/livekit-plugins/livekit-plugins-playai/README.md @@ -0,0 +1,13 @@ +# LiveKit Plugins PlayAI/PlayHT + +Agent Framework plugin for voice synthesis with [PlayAI](https://play.ai/) API. + +## Installation + +```bash +pip install livekit-plugins-playai +``` + +## Pre-requisites + +You'll need USER ID and API Secret KEY from PlayHT. It can be set as an environment variable: `PLAYHT_USER_ID`, `PLAYHT_API_KEY` get it from [here](https://play.ht/studio/api-access) diff --git a/livekit-plugins/livekit-plugins-playht/livekit/plugins/playht/__init__.py b/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/__init__.py similarity index 58% rename from livekit-plugins/livekit-plugins-playht/livekit/plugins/playht/__init__.py rename to livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/__init__.py index 82229c316..033d9363e 100644 --- a/livekit-plugins/livekit-plugins-playht/livekit/plugins/playht/__init__.py +++ b/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/__init__.py @@ -1,27 +1,20 @@ -from .models import TTSEngines -from .tts import DEFAULT_VOICE, TTS, Voice +from .tts import TTS from .version import __version__ __all__ = [ "TTS", - "Voice", - "DEFAULT_VOICE", - "TTSEngines", "__version__", ] from livekit.agents import Plugin -class PlayHTPlugin(Plugin): +class PlayAIPlugin(Plugin): def __init__(self) -> None: super().__init__(__name__, __version__, __package__) - def download_files(self) -> None: - self.download_files(self) - -Plugin.register_plugin(PlayHTPlugin()) +Plugin.register_plugin(PlayAIPlugin()) # Cleanup docs of unexported modules _module = dir() diff --git a/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/log.py b/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/log.py new file mode 100644 index 000000000..decd14a99 --- /dev/null +++ b/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/log.py @@ -0,0 +1,5 @@ +import logging + +logger = logging.getLogger("livekit.plugins.playai") +# suppress verbose websocket logs +logging.getLogger("websockets.client").setLevel(logging.INFO) diff --git a/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/models.py b/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/models.py new file mode 100644 index 000000000..1dc6dfce8 --- /dev/null +++ 
b/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/models.py @@ -0,0 +1,9 @@ +from typing import Literal + +from pyht.client import Format # type: ignore + +TTSModel = Literal["Play3.0-mini-ws", "PlayDialog-ws", "Play3.0-mini", "PlayDialog"] +FORMAT = Literal["mp3"] +format_mapping = { + "mp3": Format.FORMAT_MP3, +} diff --git a/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/py.typed b/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/tts.py b/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/tts.py new file mode 100644 index 000000000..464f3f418 --- /dev/null +++ b/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/tts.py @@ -0,0 +1,296 @@ +from __future__ import annotations + +import asyncio +import os +import weakref +from dataclasses import dataclass, fields + +from livekit import rtc +from livekit.agents import ( + DEFAULT_API_CONNECT_OPTIONS, + APIConnectionError, + APIConnectOptions, + tokenize, + tts, + utils, +) +from pyht import AsyncClient as PlayHTAsyncClient # type: ignore +from pyht.client import Format, Language, TTSOptions # type: ignore + +from .log import logger +from .models import TTSModel + +NUM_CHANNELS = 1 + + +@dataclass +class _Options: + model: TTSModel | str + tts_options: TTSOptions + word_tokenizer: tokenize.WordTokenizer + + +class TTS(tts.TTS): + def __init__( + self, + *, + api_key: str | None = None, + user_id: str | None = None, + voice: str = "s3://voice-cloning-zero-shot/d9ff78ba-d016-47f6-b0ef-dd630f59414e/female-cs/manifest.json", + language: str = "english", + sample_rate: int = 24000, + model: TTSModel | str = "Play3.0-mini-ws", + word_tokenizer: tokenize.WordTokenizer = tokenize.basic.WordTokenizer( + ignore_punctuation=False + ), + **kwargs, + ) -> None: + """ + Initialize the PlayAI TTS engine. + + Args: + api_key (str): PlayAI API key. + user_id (str): PlayAI user ID. + voice (str): Voice manifest URL. + model (TTSModel): TTS model, defaults to "Play3.0-mini-ws". + language (str): language, defaults to "english". + sample_rate (int): sample rate (Hz), A number greater than or equal to 8000, and must be less than or equal to 48000 + word_tokenizer (tokenize.WordTokenizer): Tokenizer for processing text. Defaults to basic WordTokenizer. + **kwargs: Additional options. + """ + + super().__init__( + capabilities=tts.TTSCapabilities( + streaming=False, + ), + sample_rate=sample_rate, + num_channels=1, + ) + + api_key = api_key or os.environ.get("PLAYHT_API_KEY") + user_id = user_id or os.environ.get("PLAYHT_USER_ID") + + if not api_key or not user_id: + raise ValueError( + "PlayHT API key and user ID are required. Set environment variables PLAYHT_API_KEY and PLAYHT_USER_ID or pass them explicitly." + ) + _validate_kwargs(kwargs) + self._config = TTSOptions( + voice=voice, + format=Format.FORMAT_MP3, # Default format for now + sample_rate=sample_rate, + language=Language(language), + **kwargs, + ) + + self._opts = _Options( + model=model, + tts_options=self._config, + word_tokenizer=word_tokenizer, + ) + + # Initialize client + self._client = PlayHTAsyncClient( + user_id=user_id, + api_key=api_key, + ) + self._streams = weakref.WeakSet[SynthesizeStream]() + + def update_options( + self, + *, + voice: str | None = None, + model: TTSModel | str | None = None, + language: str | None = None, + **kwargs, + ) -> None: + """ + Update the TTS options. 
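`update_options` below patches the shared `TTSOptions` and propagates model changes to any live streams. A short sketch, assuming `PLAYHT_API_KEY` and `PLAYHT_USER_ID` are set in the environment:

```python
from livekit.plugins import playai

# Sketch only: credentials come from PLAYHT_API_KEY / PLAYHT_USER_ID.
tts = playai.TTS(model="Play3.0-mini-ws", language="english")

# Switch the engine at runtime; existing streams pick up the new model too.
tts.update_options(model="PlayDialog-ws")
```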
+ """ + updates = {} + if voice is not None: + updates["voice"] = voice + if language is not None: + updates["language"] = Language(language) + tts_kwargs = {k: v for k, v in kwargs.items()} + + self._config = _update_options(self._config, **updates, **tts_kwargs) + + if model is not None: + self._opts.model = model + + for stream in self._streams: + stream._config = _update_options(stream._config, **updates, **tts_kwargs) + if model is not None: + stream._opts.model = model + + def synthesize( + self, + text: str, + *, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + ) -> "ChunkedStream": + return ChunkedStream( + tts=self, + input_text=text, + conn_options=conn_options, + opts=self._opts, + ) + + def stream( + self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS + ) -> "SynthesizeStream": + stream = SynthesizeStream( + tts=self, + conn_options=conn_options, + opts=self._opts, + ) + self._streams.add(stream) + return stream + + +class ChunkedStream(tts.ChunkedStream): + def __init__( + self, + *, + tts: TTS, + input_text: str, + conn_options: APIConnectOptions, + opts: _Options, + ) -> None: + super().__init__(tts=tts, input_text=input_text, conn_options=conn_options) + self._client = tts._client + self._opts = opts + self._config = self._opts.tts_options + self._mp3_decoder = utils.codecs.Mp3StreamDecoder() + + async def _run(self) -> None: + request_id = utils.shortuuid() + bstream = utils.audio.AudioByteStream( + sample_rate=self._config.sample_rate, num_channels=NUM_CHANNELS + ) + + try: + async for chunk in self._client.tts( + text=self._input_text, + options=self._config, + voice_engine=self._opts.model, + streaming=True, + ): + for frame in self._mp3_decoder.decode_chunk(chunk): + for frame in bstream.write(frame.data.tobytes()): + self._event_ch.send_nowait( + tts.SynthesizedAudio( + request_id=request_id, + frame=frame, + ) + ) + for frame in bstream.flush(): + self._event_ch.send_nowait( + tts.SynthesizedAudio(request_id=request_id, frame=frame) + ) + except Exception as e: + raise APIConnectionError() from e + + +class SynthesizeStream(tts.SynthesizeStream): + def __init__( + self, + *, + tts: TTS, + conn_options: APIConnectOptions, + opts: _Options, + ): + super().__init__(tts=tts, conn_options=conn_options) + self._client = tts._client + self._opts = opts + self._config = self._opts.tts_options + self._segments_ch = utils.aio.Chan[tokenize.WordStream]() + self._mp3_decoder = utils.codecs.Mp3StreamDecoder() + + async def _run(self) -> None: + request_id = utils.shortuuid() + segment_id = utils.shortuuid() + bstream = utils.audio.AudioByteStream( + sample_rate=self._config.sample_rate, + num_channels=NUM_CHANNELS, + ) + last_frame: rtc.AudioFrame | None = None + + def _send_last_frame(*, segment_id: str, is_final: bool) -> None: + nonlocal last_frame + if last_frame is not None: + self._event_ch.send_nowait( + tts.SynthesizedAudio( + request_id=request_id, + segment_id=segment_id, + frame=last_frame, + is_final=is_final, + ) + ) + last_frame = None + + input_task = asyncio.create_task(self._tokenize_input()) + try: + text_stream = await self._create_text_stream() + async for chunk in self._client.stream_tts_input( + text_stream=text_stream, + options=self._config, + voice_engine=self._opts.model, + ): + for frame in self._mp3_decoder.decode_chunk(chunk): + for frame in bstream.write(frame.data.tobytes()): + _send_last_frame(segment_id=segment_id, is_final=False) + last_frame = frame + + for frame in bstream.flush(): + 
_send_last_frame(segment_id=segment_id, is_final=False) + last_frame = frame + _send_last_frame(segment_id=segment_id, is_final=True) + except Exception as e: + raise APIConnectionError() from e + finally: + await utils.aio.gracefully_cancel(input_task) + self._client.close() + + @utils.log_exceptions(logger=logger) + async def _tokenize_input(self): + # Converts incoming text into WordStreams and sends them into _segments_ch + word_stream = None + async for input in self._input_ch: + if isinstance(input, str): + if word_stream is None: + word_stream = self._opts.word_tokenizer.stream() + self._segments_ch.send_nowait(word_stream) + word_stream.push_text(input) + elif isinstance(input, self._FlushSentinel): + if word_stream: + word_stream.end_input() + word_stream = None + self._segments_ch.close() + + @utils.log_exceptions(logger=logger) + async def _create_text_stream(self): + async def text_stream(): + async for word_stream in self._segments_ch: + async for word in word_stream: + yield word.token + + return text_stream() + + +def _update_options(config: TTSOptions, **kwargs) -> TTSOptions: + _validate_kwargs(kwargs) + for k, v in kwargs.items(): + if v is not None: + setattr(config, k, v) + return config + + +def _validate_kwargs(kwargs: dict) -> None: + valid_keys = {field.name for field in fields(TTSOptions)} + invalid_keys = set(kwargs.keys()) - valid_keys + if invalid_keys: + raise ValueError( + f"Invalid parameters: {invalid_keys}. Allowed parameters: {valid_keys}" + ) diff --git a/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/version.py b/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/version.py new file mode 100644 index 000000000..92192eed4 --- /dev/null +++ b/livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/version.py @@ -0,0 +1 @@ +__version__ = "1.0.4" diff --git a/livekit-plugins/livekit-plugins-playai/package.json b/livekit-plugins/livekit-plugins-playai/package.json new file mode 100644 index 000000000..a4879d16b --- /dev/null +++ b/livekit-plugins/livekit-plugins-playai/package.json @@ -0,0 +1,5 @@ +{ + "name": "livekit-plugins-playai", + "private": true, + "version": "1.0.4" +} diff --git a/livekit-plugins/livekit-plugins-playht/pyproject.toml b/livekit-plugins/livekit-plugins-playai/pyproject.toml similarity index 100% rename from livekit-plugins/livekit-plugins-playht/pyproject.toml rename to livekit-plugins/livekit-plugins-playai/pyproject.toml diff --git a/livekit-plugins/livekit-plugins-playht/setup.py b/livekit-plugins/livekit-plugins-playai/setup.py similarity index 83% rename from livekit-plugins/livekit-plugins-playht/setup.py rename to livekit-plugins/livekit-plugins-playai/setup.py index ea5c7bf77..76c2d2ba5 100644 --- a/livekit-plugins/livekit-plugins-playht/setup.py +++ b/livekit-plugins/livekit-plugins-playai/setup.py @@ -6,14 +6,14 @@ here = pathlib.Path(__file__).parent.resolve() about = {} -with open(os.path.join(here, "livekit", "plugins", "playht", "version.py"), "r") as f: +with open(os.path.join(here, "livekit", "plugins", "playai", "version.py"), "r") as f: exec(f.read(), about) setuptools.setup( - name="livekit-plugins-playht", + name="livekit-plugins-playai", version=about["__version__"], - description="Agent Framework plugin for voice synthesis with PlayHT's API.", + description="Agent Framework plugin for voice synthesis with PlayAI's API.", long_description=(here / "README.md").read_text(encoding="utf-8"), long_description_content_type="text/markdown", url="https://github.com/livekit/agents", @@ 
-27,17 +27,17 @@ "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3 :: Only", ], - keywords=["webrtc", "realtime", "audio", "livekit", "playHT"], + keywords=["webrtc", "realtime", "audio", "livekit", "playHT", "playAI"], license="Apache-2.0", packages=setuptools.find_namespace_packages(include=["livekit.*"]), python_requires=">=3.9.0", install_requires=[ - "livekit-agents[codecs]>=0.11", - "pyht", + "livekit-agents[codecs]>=0.12.3", + "pyht>=0.1.10", "aiohttp", "livekit", ], - package_data={"livekit.plugins.playht": ["py.typed"]}, + package_data={"livekit.plugins.playai": ["py.typed"]}, project_urls={ "Documentation": "https://docs.livekit.io", "Website": "https://livekit.io/", diff --git a/livekit-plugins/livekit-plugins-playht/CHANGELOG.md b/livekit-plugins/livekit-plugins-playht/CHANGELOG.md deleted file mode 100644 index 92030c7fd..000000000 --- a/livekit-plugins/livekit-plugins-playht/CHANGELOG.md +++ /dev/null @@ -1,9 +0,0 @@ -# livekit-plugins-playht - -## 1.0.1 - -### Patch Changes - -- pipelineagent: expose timing metrics & api errors wip - [#957](https://github.com/livekit/agents/pull/957) ([@theomonnom](https://github.com/theomonnom)) - -- expose usage metrics - [#984](https://github.com/livekit/agents/pull/984) ([@theomonnom](https://github.com/theomonnom)) diff --git a/livekit-plugins/livekit-plugins-playht/README.md b/livekit-plugins/livekit-plugins-playht/README.md deleted file mode 100644 index 53badc144..000000000 --- a/livekit-plugins/livekit-plugins-playht/README.md +++ /dev/null @@ -1,13 +0,0 @@ -# LiveKit Plugins PlayHT - -Agent Framework plugin for voice synthesis with [PlayHT](https://play.ht/) API. - -## Installation - -```bash -pip install livekit-plugins-playht -``` - -## Pre-requisites - -You'll need USER ID and API Secret KEY from PlayHT. 
It can be set as an environment variable: `PLAYHT_USER_ID`, `PLAYHT_API_KEY` \ No newline at end of file diff --git a/livekit-plugins/livekit-plugins-playht/livekit/plugins/playht/log.py b/livekit-plugins/livekit-plugins-playht/livekit/plugins/playht/log.py deleted file mode 100644 index 18a81836e..000000000 --- a/livekit-plugins/livekit-plugins-playht/livekit/plugins/playht/log.py +++ /dev/null @@ -1,3 +0,0 @@ -import logging - -logger = logging.getLogger("livekit.custom_tts_plugins.playht") diff --git a/livekit-plugins/livekit-plugins-playht/livekit/plugins/playht/models.py b/livekit-plugins/livekit-plugins-playht/livekit/plugins/playht/models.py deleted file mode 100644 index f872f1601..000000000 --- a/livekit-plugins/livekit-plugins-playht/livekit/plugins/playht/models.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Literal - -TTSEngines = Literal["PlayHT2.0", "PlayHT1.0", "PlayHT2.0-turbo"] - -TTSEncoding = Literal[ - "mp3_22050_32", - "mp3_44100_32", - "mp3_44100_64", - "mp3_44100_96", - "mp3_44100_128", - "mp3_44100_192", - "pcm_16000", - "pcm_22050", - "pcm_44100", -] diff --git a/livekit-plugins/livekit-plugins-playht/livekit/plugins/playht/tts.py b/livekit-plugins/livekit-plugins-playht/livekit/plugins/playht/tts.py deleted file mode 100644 index c411c8315..000000000 --- a/livekit-plugins/livekit-plugins-playht/livekit/plugins/playht/tts.py +++ /dev/null @@ -1,234 +0,0 @@ -from __future__ import annotations - -import asyncio -import os -from dataclasses import dataclass -from typing import Any, List, Literal - -import aiohttp -from livekit.agents import ( - DEFAULT_API_CONNECT_OPTIONS, - APIConnectionError, - APIConnectOptions, - APIStatusError, - APITimeoutError, - tts, - utils, -) - -from .log import logger -from .models import TTSEncoding, TTSEngines - -_Encoding = Literal["mp3", "pcm"] - - -def _sample_rate_from_format(output_format: TTSEncoding) -> int: - split = output_format.split("_") - return int(split[1]) - - -def _encoding_from_format(output_format: TTSEncoding) -> _Encoding: - if output_format.startswith("mp3"): - return "mp3" - elif output_format.startswith("pcm"): - return "pcm" - - raise ValueError(f"Unknown format: {output_format}") - - -@dataclass -class Voice: - id: str - name: str - voice_engine: TTSEngines - - -DEFAULT_VOICE = Voice( - id="s3://peregrine-voices/mel22/manifest.json", - name="Will", - voice_engine="PlayHT2.0", -) - -ACCEPT_HEADER = { - "mp3": "audio/mpeg", - "wav": "audio/wav", - "ogg": "audio/ogg", - "flac": "audio/flac", - "mulaw": "audio/basic", # commonly used for mulaw -} - -API_BASE_URL_V2 = "https://api.play.ht/api/v2" -AUTHORIZATION_HEADER = "AUTHORIZATION" -USERID_HEADER = "X-USER-ID" -PLAYHT_TTS_SAMPLE_RATE = 48000 -PLAYHT_TTS_CHANNELS = 1 - -_TTSEncoding = Literal["mp3", "wav", "ogg", "flac", "mulaw"] - - -@dataclass -class _TTSOptions: - api_key: str - user_id: str - voice: Voice - base_url: str - sample_rate: int - encoding: _TTSEncoding - - -class TTS(tts.TTS): - def __init__( - self, - *, - voice: Voice = DEFAULT_VOICE, - api_key: str | None = None, - user_id: str | None = None, - base_url: str | None = None, - encoding: _TTSEncoding = "wav", - http_session: aiohttp.ClientSession | None = None, - ) -> None: - super().__init__( - capabilities=tts.TTSCapabilities( - streaming=False, - ), - sample_rate=PLAYHT_TTS_SAMPLE_RATE, - num_channels=PLAYHT_TTS_CHANNELS, - ) - api_key = api_key or os.environ.get("PLAYHT_API_KEY") - if not api_key: - raise ValueError("PLAYHT_API_KEY must be set") - - user_id = user_id or 
os.environ.get("PLAYHT_USER_ID") - if not user_id: - raise ValueError("PLAYHT_USER_ID mus be set") - - self._opts = _TTSOptions( - voice=voice, - user_id=user_id, - api_key=api_key, - base_url=base_url or API_BASE_URL_V2, - sample_rate=PLAYHT_TTS_SAMPLE_RATE, - encoding=encoding, - ) - self._session = http_session - - def _ensure_session(self) -> aiohttp.ClientSession: - if not self._session: - self._session = utils.http_context.http_session() - - return self._session - - async def list_voices(self) -> List[Voice]: - async with self._ensure_session().get( - f"{self._opts.base_url}/voices", - headers={ - "accept": "application/json", - AUTHORIZATION_HEADER: self._opts.api_key, - USERID_HEADER: self._opts.user_id, - }, - ) as resp: - return _dict_to_voices_list(await resp.json()) - - def synthesize( - self, - text: str, - *, - conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, - ) -> "ChunkedStream": - return ChunkedStream( - tts=self, - input_text=text, - conn_options=conn_options, - opts=self._opts, - session=self._ensure_session(), - ) - - -class ChunkedStream(tts.ChunkedStream): - """Synthesize using the chunked api endpoint""" - - def __init__( - self, - tts: TTS, - input_text: str, - opts: _TTSOptions, - conn_options: APIConnectOptions, - session: aiohttp.ClientSession, - ) -> None: - super().__init__(tts=tts, input_text=input_text, conn_options=conn_options) - self._opts, self._session = opts, session - - async def _run(self) -> None: - stream = utils.audio.AudioByteStream( - sample_rate=self._opts.sample_rate, num_channels=1 - ) - self._mp3_decoder = utils.codecs.Mp3StreamDecoder() - request_id = utils.shortuuid() - url = f"{API_BASE_URL_V2}/tts/stream" - headers = { - "accept": ACCEPT_HEADER[self._opts.encoding], - "content-type": "application/json", - AUTHORIZATION_HEADER: self._opts.api_key, - USERID_HEADER: self._opts.user_id, - } - json_data = { - "text": self._input_text, - "output_format": self._opts.encoding, - "voice": self._opts.voice.id, - } - try: - async with self._session.post( - url=url, headers=headers, json=json_data - ) as resp: - if not resp.content_type.startswith("audio/"): - content = await resp.text() - logger.error("playHT returned non-audio data: %s", content) - return - - encoding = _encoding_from_format(self._opts.encoding) - if encoding == "mp3": - async for bytes_data, _ in resp.content.iter_chunks(): - for frame in self._mp3_decoder.decode_chunk(bytes_data): - self._event_ch.send_nowait( - tts.SynthesizedAudio( - request_id=request_id, - frame=frame, - ) - ) - else: - async for bytes_data, _ in resp.content.iter_chunks(): - for frame in stream.write(bytes_data): - self._event_ch.send_nowait( - tts.SynthesizedAudio( - request_id=request_id, - frame=frame, - ) - ) - - for frame in stream.flush(): - self._event_ch.send_nowait( - tts.SynthesizedAudio(request_id=request_id, frame=frame) - ) - - except asyncio.TimeoutError as e: - raise APITimeoutError() from e - except aiohttp.ClientResponseError as e: - raise APIStatusError( - message=e.message, - status_code=e.status, - request_id=None, - body=None, - ) from e - except Exception as e: - raise APIConnectionError() from e - - -def _dict_to_voices_list(data: dict[str, Any]): - voices: List[Voice] = [] - for voice in data["text"]: - voices.append( - Voice( - id=voice["id"], name=voice["name"], voice_engine=voice["voice_engine"] - ) - ) - return voices diff --git a/livekit-plugins/livekit-plugins-playht/livekit/plugins/playht/version.py 
b/livekit-plugins/livekit-plugins-playht/livekit/plugins/playht/version.py deleted file mode 100644 index 5c4105cd3..000000000 --- a/livekit-plugins/livekit-plugins-playht/livekit/plugins/playht/version.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "1.0.1" diff --git a/livekit-plugins/livekit-plugins-playht/package.json b/livekit-plugins/livekit-plugins-playht/package.json deleted file mode 100644 index 7ca0e17e4..000000000 --- a/livekit-plugins/livekit-plugins-playht/package.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "name": "livekit-plugins-playht", - "private": true, - "version": "1.0.1" -} \ No newline at end of file diff --git a/livekit-plugins/livekit-plugins-rag/CHANGELOG.md b/livekit-plugins/livekit-plugins-rag/CHANGELOG.md index 6a7effef5..eaf9b2011 100644 --- a/livekit-plugins/livekit-plugins-rag/CHANGELOG.md +++ b/livekit-plugins/livekit-plugins-rag/CHANGELOG.md @@ -1,5 +1,11 @@ # livekit-plugins-rag +## 0.2.3 + +### Patch Changes + +- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom)) + ## 0.2.2 ### Patch Changes diff --git a/livekit-plugins/livekit-plugins-rag/livekit/plugins/rag/version.py b/livekit-plugins/livekit-plugins-rag/livekit/plugins/rag/version.py index 2985d9da1..0f3f2ddd4 100644 --- a/livekit-plugins/livekit-plugins-rag/livekit/plugins/rag/version.py +++ b/livekit-plugins/livekit-plugins-rag/livekit/plugins/rag/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.2.2" +__version__ = "0.2.3" diff --git a/livekit-plugins/livekit-plugins-rag/package.json b/livekit-plugins/livekit-plugins-rag/package.json index 897e16552..1b93c9070 100644 --- a/livekit-plugins/livekit-plugins-rag/package.json +++ b/livekit-plugins/livekit-plugins-rag/package.json @@ -1,5 +1,5 @@ { "name": "livekit-plugins-rag", "private": true, - "version": "0.2.2" + "version": "0.2.3" } diff --git a/livekit-plugins/livekit-plugins-rag/setup.py b/livekit-plugins/livekit-plugins-rag/setup.py index 55c8223a8..00ae59c86 100644 --- a/livekit-plugins/livekit-plugins-rag/setup.py +++ b/livekit-plugins/livekit-plugins-rag/setup.py @@ -47,7 +47,7 @@ license="Apache-2.0", packages=setuptools.find_namespace_packages(include=["livekit.*"]), python_requires=">=3.9.0", - install_requires=["livekit-agents>=0.11", "annoy>=1.17"], + install_requires=["livekit-agents>=0.12.3", "annoy>=1.17"], package_data={"livekit.plugins.rag": ["py.typed"]}, project_urls={ "Documentation": "https://docs.livekit.io", diff --git a/livekit-plugins/livekit-plugins-silero/CHANGELOG.md b/livekit-plugins/livekit-plugins-silero/CHANGELOG.md index affcb3671..09c704068 100644 --- a/livekit-plugins/livekit-plugins-silero/CHANGELOG.md +++ b/livekit-plugins/livekit-plugins-silero/CHANGELOG.md @@ -1,5 +1,11 @@ # livekit-plugins-silero +## 0.7.4 + +### Patch Changes + +- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom)) + ## 0.7.3 ### Patch Changes diff --git a/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/version.py b/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/version.py index 20d8a2226..6b43cc50e 100644 --- a/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/version.py +++ b/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "0.7.3" +__version__ = "0.7.4" diff --git a/livekit-plugins/livekit-plugins-silero/package.json b/livekit-plugins/livekit-plugins-silero/package.json index ae0525c9d..e341b9826 100644 --- a/livekit-plugins/livekit-plugins-silero/package.json +++ b/livekit-plugins/livekit-plugins-silero/package.json @@ -1,5 +1,5 @@ { "name": "livekit-plugins-silero", "private": true, - "version": "0.7.3" + "version": "0.7.4" } diff --git a/livekit-plugins/livekit-plugins-silero/setup.py b/livekit-plugins/livekit-plugins-silero/setup.py index c5202db9c..52bc41ba2 100644 --- a/livekit-plugins/livekit-plugins-silero/setup.py +++ b/livekit-plugins/livekit-plugins-silero/setup.py @@ -47,7 +47,7 @@ license="Apache-2.0", packages=setuptools.find_namespace_packages(include=["livekit.*"]), python_requires=">=3.9.0", - install_requires=["livekit-agents>=0.11", "onnxruntime>=1.18", "numpy>=1.26"], + install_requires=["livekit-agents>=0.12.3", "onnxruntime>=1.18", "numpy>=1.26"], package_data={ "livekit.plugins.silero.resources": ["silero_vad.onnx"], "livekit.plugins.silero": ["py.typed"], diff --git a/livekit-plugins/livekit-plugins-turn-detector/CHANGELOG.md b/livekit-plugins/livekit-plugins-turn-detector/CHANGELOG.md new file mode 100644 index 000000000..46a9a7fe5 --- /dev/null +++ b/livekit-plugins/livekit-plugins-turn-detector/CHANGELOG.md @@ -0,0 +1,41 @@ +# livekit-plugins-eou + +## 0.3.5 + +### Patch Changes + +- fix int32/64 errors on Windows - [#1285](https://github.com/livekit/agents/pull/1285) ([@nbsp](https://github.com/nbsp)) + +## 0.3.4 + +### Patch Changes + +- add jinja2 dependency to turn detector - [#1277](https://github.com/livekit/agents/pull/1277) ([@davidzhao](https://github.com/davidzhao)) + +## 0.3.3 + +### Patch Changes + +- use quantized onnx version of turn detector model - [#1231](https://github.com/livekit/agents/pull/1231) ([@jeradf](https://github.com/jeradf)) + +- use onnxruntime for turn detection and remove pytorch dependency - [#1257](https://github.com/livekit/agents/pull/1257) ([@jeradf](https://github.com/jeradf)) + +## 0.3.2 + +### Patch Changes + +- improvements to endpointing latency - [#1212](https://github.com/livekit/agents/pull/1212) ([@davidzhao](https://github.com/davidzhao)) + +- Improvements to end of turn plugin, ensure STT language settings. - [#1195](https://github.com/livekit/agents/pull/1195) ([@davidzhao](https://github.com/davidzhao)) + +## 0.3.1 + +### Patch Changes + +- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom)) + +## 0.3.0 + +### Minor Changes + +- feat: inference process & end of utterance plugin - [#1133](https://github.com/livekit/agents/pull/1133) ([@theomonnom](https://github.com/theomonnom)) diff --git a/livekit-plugins/livekit-plugins-turn-detector/README.md b/livekit-plugins/livekit-plugins-turn-detector/README.md new file mode 100644 index 000000000..859b803cf --- /dev/null +++ b/livekit-plugins/livekit-plugins-turn-detector/README.md @@ -0,0 +1,48 @@ +# LiveKit Plugins Turn Detector + +This plugin introduces end-of-turn detection for LiveKit Agents using a custom open-weight model to determine when a user has finished speaking. + +Traditional voice agents use VAD (voice activity detection) for end-of-turn detection. However, VAD models lack language understanding, often causing false positives where the agent interrupts the user before they finish speaking. 
+ +By leveraging a language model specifically trained for this task, this plugin offers a more accurate and robust method for detecting end-of-turns. The current version supports English only and should not be used when targeting other languages. + +## Installation + +```bash +pip install livekit-plugins-turn-detector +``` + +## Usage + +This plugin is designed to be used with the `VoicePipelineAgent`: + +```python +from livekit.plugins import turn_detector + +agent = VoicePipelineAgent( + ... + turn_detector=turn_detector.EOUModel(), +) +``` + +## Running your agent + +This plugin requires model files. Before starting your agent for the first time, or when building Docker images for deployment, run the following command to download the model files: + +```bash +python my_agent.py download-files +``` + +## Model system requirements + +The end-of-turn model is optimized to run on CPUs with modest system requirements. It is designed to run on the same server hosting your agents. On a 4-core server instance, it completes inference in ~50ms with minimal CPU usage. + +The model requires 1.5GB of RAM and runs within a shared inference server, supporting multiple concurrent sessions. + +We are working to reduce the CPU and memory requirements in future releases. + +## License + +The plugin source code is licensed under the Apache-2.0 license. + +The end-of-turn model is licensed under the [LiveKit Model License](https://huggingface.co/livekit/turn-detector/blob/main/LICENSE). diff --git a/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/__init__.py b/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/__init__.py new file mode 100644 index 000000000..54d7a90af --- /dev/null +++ b/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2023 LiveKit, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
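+
+# This module wires the plugin into the agents framework: EOUPlugin.download_files()
+# fetches the tokenizer and the quantized ONNX model from the Hugging Face Hub, and the
+# registrations at the bottom expose the plugin and its _EUORunner inference runner.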
+ +from livekit.agents import Plugin +from livekit.agents.inference_runner import _InferenceRunner + +from .eou import EOUModel, _EUORunner +from .log import logger +from .version import __version__ + +__all__ = ["EOUModel", "__version__"] + + +class EOUPlugin(Plugin): + def __init__(self): + super().__init__(__name__, __version__, __package__, logger) + + def download_files(self) -> None: + from transformers import AutoTokenizer + + from .eou import HG_MODEL, ONNX_FILENAME, _download_from_hf_hub + + AutoTokenizer.from_pretrained(HG_MODEL) + _download_from_hf_hub(HG_MODEL, ONNX_FILENAME) + + +Plugin.register_plugin(EOUPlugin()) +_InferenceRunner.register_runner(_EUORunner) diff --git a/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py b/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py new file mode 100644 index 000000000..8c8090946 --- /dev/null +++ b/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +import json +import string +import time + +import numpy as np +from livekit.agents import llm +from livekit.agents.inference_runner import _InferenceRunner +from livekit.agents.ipc.inference_executor import InferenceExecutor +from livekit.agents.job import get_current_job_context + +from .log import logger + +HG_MODEL = "livekit/turn-detector" +ONNX_FILENAME = "model_quantized.onnx" +PUNCS = string.punctuation.replace("'", "") +MAX_HISTORY = 4 + + +def _download_from_hf_hub(repo_id, filename, **kwargs): + from huggingface_hub import hf_hub_download + + local_path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs) + return local_path + + +def _softmax(logits: np.ndarray) -> np.ndarray: + exp_logits = np.exp(logits - np.max(logits)) + return exp_logits / np.sum(exp_logits) + + +class _EUORunner(_InferenceRunner): + INFERENCE_METHOD = "lk_end_of_utterance" + + def _normalize(self, text): + def strip_puncs(text): + return text.translate(str.maketrans("", "", PUNCS)) + + return " ".join(strip_puncs(text).lower().split()) + + def _format_chat_ctx(self, chat_ctx: dict): + new_chat_ctx = [] + for msg in chat_ctx: + content = self._normalize(msg["content"]) + + if not content: + continue + + msg["content"] = content + new_chat_ctx.append(msg) + + convo_text = self._tokenizer.apply_chat_template( + new_chat_ctx, + add_generation_prompt=False, + add_special_tokens=False, + tokenize=False, + ) + + # remove the EOU token from current utterance + ix = convo_text.rfind("<|im_end|>") + text = convo_text[:ix] + return text + + def initialize(self) -> None: + import onnxruntime as ort + from huggingface_hub import errors + from transformers import AutoTokenizer + + try: + local_path_onnx = _download_from_hf_hub( + HG_MODEL, ONNX_FILENAME, local_files_only=True + ) + self._session = ort.InferenceSession( + local_path_onnx, providers=["CPUExecutionProvider"] + ) + + self._tokenizer = AutoTokenizer.from_pretrained( + HG_MODEL, local_files_only=True + ) + self._eou_index = self._tokenizer.encode("<|im_end|>")[-1] + except (errors.LocalEntryNotFoundError, OSError): + logger.error( + ( + f"Could not find model {HG_MODEL}. Make sure you have downloaded the model before running the agent. " + "Use `python3 your_agent.py download-files` to download the models." + ) + ) + raise RuntimeError( + f"livekit-plugins-turn-detector initialization failed. Could not find model {HG_MODEL}." 
+ ) from None + + def run(self, data: bytes) -> bytes | None: + data_json = json.loads(data) + chat_ctx = data_json.get("chat_ctx", None) + + if not chat_ctx: + raise ValueError("chat_ctx is required on the inference input data") + + start_time = time.perf_counter() + + text = self._format_chat_ctx(chat_ctx) + inputs = self._tokenizer( + text, + add_special_tokens=False, + return_tensors="np", + ) + + input_dict = {"input_ids": np.array(inputs["input_ids"], dtype=np.int64)} + + # Run inference + outputs = self._session.run(["logits"], input_dict) + + logits = outputs[0][0, -1, :] + probs = _softmax(logits) + eou_probability = probs[self._eou_index] + + end_time = time.perf_counter() + + logger.debug( + "eou prediction", + extra={ + "eou_probability": eou_probability, + "input": text, + "duration": round(end_time - start_time, 3), + }, + ) + return json.dumps({"eou_probability": float(eou_probability)}).encode() + + +class EOUModel: + def __init__( + self, + inference_executor: InferenceExecutor | None = None, + unlikely_threshold: float = 0.15, + ) -> None: + self._executor = ( + inference_executor or get_current_job_context().inference_executor + ) + self._unlikely_threshold = unlikely_threshold + + def unlikely_threshold(self) -> float: + return self._unlikely_threshold + + def supports_language(self, language: str | None) -> bool: + if language is None: + return False + parts = language.lower().split("-") + # certain models use language codes (DG, AssemblyAI), others use full names (like OAI) + return parts[0] == "en" or parts[0] == "english" + + async def predict_eou(self, chat_ctx: llm.ChatContext) -> float: + return await self.predict_end_of_turn(chat_ctx) + + async def predict_end_of_turn(self, chat_ctx: llm.ChatContext) -> float: + messages = [] + + for msg in chat_ctx.messages: + if msg.role not in ("user", "assistant"): + continue + + if isinstance(msg.content, str): + messages.append( + { + "role": msg.role, + "content": msg.content, + } + ) + elif isinstance(msg.content, list): + for cnt in msg.content: + if isinstance(cnt, str): + messages.append( + { + "role": msg.role, + "content": cnt, + } + ) + break + + messages = messages[-MAX_HISTORY:] + + json_data = json.dumps({"chat_ctx": messages}).encode() + result = await self._executor.do_inference( + _EUORunner.INFERENCE_METHOD, json_data + ) + + assert ( + result is not None + ), "end_of_utterance prediction should always returns a result" + + result_json = json.loads(result.decode()) + return result_json["eou_probability"] diff --git a/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/log.py b/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/log.py new file mode 100644 index 000000000..2b29634ad --- /dev/null +++ b/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/log.py @@ -0,0 +1,3 @@ +import logging + +logger = logging.getLogger("livekit.plugins.turn_detector") diff --git a/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/version.py b/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/version.py new file mode 100644 index 000000000..4be9d79b7 --- /dev/null +++ b/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/version.py @@ -0,0 +1,15 @@ +# Copyright 2023 LiveKit, Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__version__ = "0.3.5" diff --git a/livekit-plugins/livekit-plugins-turn-detector/package.json b/livekit-plugins/livekit-plugins-turn-detector/package.json new file mode 100644 index 000000000..264da83bf --- /dev/null +++ b/livekit-plugins/livekit-plugins-turn-detector/package.json @@ -0,0 +1,5 @@ +{ + "name": "livekit-plugins-turn-detector", + "private": true, + "version": "0.3.5" +} diff --git a/livekit-plugins/livekit-plugins-turn-detector/pyproject.toml b/livekit-plugins/livekit-plugins-turn-detector/pyproject.toml new file mode 100644 index 000000000..8cf32563a --- /dev/null +++ b/livekit-plugins/livekit-plugins-turn-detector/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/livekit-plugins/livekit-plugins-turn-detector/setup.py b/livekit-plugins/livekit-plugins-turn-detector/setup.py new file mode 100644 index 000000000..1585ed0cf --- /dev/null +++ b/livekit-plugins/livekit-plugins-turn-detector/setup.py @@ -0,0 +1,65 @@ +# Copyright 2023 LiveKit, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import pathlib + +import setuptools +import setuptools.command.build_py + +here = pathlib.Path(__file__).parent.resolve() +about = {} +with open( + os.path.join(here, "livekit", "plugins", "turn_detector", "version.py"), "r" +) as f: + exec(f.read(), about) + + +setuptools.setup( + name="livekit-plugins-turn-detector", + version=about["__version__"], + description="End of utterance detection for LiveKit Agents", + long_description=(here / "README.md").read_text(encoding="utf-8"), + long_description_content_type="text/markdown", + url="https://github.com/livekit/agents", + cmdclass={}, + classifiers=[ + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Topic :: Multimedia :: Sound/Audio", + "Topic :: Multimedia :: Video", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3 :: Only", + ], + keywords=["webrtc", "realtime", "audio", "video", "livekit"], + license="Apache-2.0", + packages=setuptools.find_namespace_packages(include=["livekit.*"]), + python_requires=">=3.9.0", + install_requires=[ + "livekit-agents>=0.12.3", + "transformers>=4.47.1", + "numpy>=1.26", + "onnxruntime>=1.18", + "jinja2", + ], + package_data={"livekit.plugins.turn_detector": ["py.typed"]}, + project_urls={ + "Documentation": "https://docs.livekit.io", + "Website": "https://livekit.io/", + "Source": "https://github.com/livekit/agents", + }, +) diff --git a/tests/.gitattributes b/tests/.gitattributes index 0fd91ce6d..83117e69b 100644 --- a/tests/.gitattributes +++ b/tests/.gitattributes @@ -1,2 +1,5 @@ long.mp3 filter=lfs diff=lfs merge=lfs -text change-sophie.wav filter=lfs diff=lfs merge=lfs -text +change-sophie.opus filter=lfs diff=lfs merge=lfs -text +hearts.rgba filter=lfs diff=lfs merge=lfs -text +hearts.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/tests/change-sophie.opus b/tests/change-sophie.opus new file mode 100644 index 000000000..5112fcab5 --- /dev/null +++ b/tests/change-sophie.opus @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a2eb5667dc35714b4cb70324d3722f89580885ee5e51be5f2c793e7893d9a24 +size 48905 diff --git a/tests/fake_stt.py b/tests/fake_stt.py new file mode 100644 index 000000000..0b365100c --- /dev/null +++ b/tests/fake_stt.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import asyncio + +from livekit.agents import NOT_GIVEN, NotGivenOr, utils +from livekit.agents.stt import ( + STT, + RecognizeStream, + SpeechData, + SpeechEvent, + SpeechEventType, + STTCapabilities, +) +from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions +from livekit.agents.utils.audio import AudioBuffer + + +class RecognizeSentinel: ... 
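+# RecognizeSentinel is pushed onto FakeSTT.recognize_ch for every recognize() call, and each
+# stream() call pushes the created FakeRecognizeStream onto stream_ch, so the fallback tests
+# can observe how each fake STT instance is exercised.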
+ + +class FakeSTT(STT): + def __init__( + self, + *, + fake_exception: Exception | None = None, + fake_transcript: str | None = None, + fake_timeout: float | None = None, + ) -> None: + super().__init__( + capabilities=STTCapabilities(streaming=True, interim_results=False), + ) + + self._fake_exception = fake_exception + self._fake_transcript = fake_transcript + self._fake_timeout = fake_timeout + + self._recognize_ch = utils.aio.Chan[RecognizeSentinel]() + self._stream_ch = utils.aio.Chan[FakeRecognizeStream]() + + def update_options( + self, + *, + fake_exception: NotGivenOr[Exception | None] = NOT_GIVEN, + fake_transcript: NotGivenOr[str | None] = NOT_GIVEN, + fake_timeout: NotGivenOr[float | None] = NOT_GIVEN, + ) -> None: + if utils.is_given(fake_exception): + self._fake_exception = fake_exception + + if utils.is_given(fake_transcript): + self._fake_transcript = fake_transcript + + if utils.is_given(fake_timeout): + self._fake_timeout = fake_timeout + + @property + def recognize_ch(self) -> utils.aio.ChanReceiver[RecognizeSentinel]: + return self._recognize_ch + + @property + def stream_ch(self) -> utils.aio.ChanReceiver["FakeRecognizeStream"]: + return self._stream_ch + + async def _recognize_impl( + self, + buffer: AudioBuffer, + *, + language: str | None, + conn_options: APIConnectOptions, + ) -> SpeechEvent: + if self._fake_timeout is not None: + await asyncio.sleep(self._fake_timeout) + + if self._fake_exception is not None: + raise self._fake_exception + + return SpeechEvent( + type=SpeechEventType.FINAL_TRANSCRIPT, + alternatives=[ + SpeechData(text=self._fake_transcript or "", language=language or "") + ], + ) + + async def recognize( + self, + buffer: AudioBuffer, + *, + language: str | None = None, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + ): + self._recognize_ch.send_nowait(RecognizeSentinel()) + return await super().recognize( + buffer, language=language, conn_options=conn_options + ) + + def stream( + self, + *, + language: str | None = None, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + ) -> "FakeRecognizeStream": + stream = FakeRecognizeStream( + stt=self, + conn_options=conn_options, + ) + self._stream_ch.send_nowait(stream) + return stream + + +class FakeRecognizeStream(RecognizeStream): + def __init__( + self, + *, + stt: STT, + conn_options: APIConnectOptions, + ): + super().__init__(stt=stt, conn_options=conn_options) + self._attempt = 0 + + @property + def attempt(self) -> int: + return self._attempt + + def send_fake_transcript(self, transcript: str) -> None: + self._event_ch.send_nowait( + SpeechEvent( + type=SpeechEventType.FINAL_TRANSCRIPT, + alternatives=[SpeechData(text=transcript, language="")], + ) + ) + + async def _run(self) -> None: + self._attempt += 1 + assert isinstance(self._stt, FakeSTT) + + if self._stt._fake_timeout is not None: + await asyncio.sleep(self._stt._fake_timeout) + + if self._stt._fake_transcript is not None: + self.send_fake_transcript(self._stt._fake_transcript) + + async for _ in self._input_ch: + pass + + if self._stt._fake_exception is not None: + raise self._stt._fake_exception diff --git a/tests/fake_tts.py b/tests/fake_tts.py index 0ae4f6d2d..9d4b5f70c 100644 --- a/tests/fake_tts.py +++ b/tests/fake_tts.py @@ -149,13 +149,25 @@ async def _run(self) -> None: assert isinstance(self._tts, FakeTTS) - request_id = utils.shortuuid("fake_tts_") - segment_id = utils.shortuuid("fake_segment_") - if self._tts._fake_timeout is not None: await asyncio.sleep(self._tts._fake_timeout) - if 
self._tts._fake_audio_duration is not None: + has_data = False + async for data in self._input_ch: + if isinstance(data, str): + has_data = True + continue + elif isinstance(data, SynthesizeStream._FlushSentinel) and not has_data: + continue + + has_data = False + + if self._tts._fake_audio_duration is None: + continue + + request_id = utils.shortuuid("fake_tts_") + segment_id = utils.shortuuid("fake_segment_") + pushed_samples = 0 max_samples = ( int(self._tts.sample_rate * self._tts._fake_audio_duration + 0.5) @@ -180,8 +192,5 @@ async def _run(self) -> None: ) pushed_samples += num_samples - async for _ in self._input_ch: - pass - if self._tts._fake_exception is not None: raise self._tts._fake_exception diff --git a/tests/hearts.jpg b/tests/hearts.jpg new file mode 100644 index 000000000..23ecdb8d1 --- /dev/null +++ b/tests/hearts.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d133e5535163b77b4ea65d4ca7c9dbe81f4a24fad530f24b9a31b3bde1e1c38 +size 151017 diff --git a/tests/hearts.rgba b/tests/hearts.rgba new file mode 100644 index 000000000..d40a5334b --- /dev/null +++ b/tests/hearts.rgba @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:06087a10c1864e6644d16a6e508852e678ad1a96e4d99bd8056bb7f60ab765cc +size 1048576 diff --git a/tests/pytest.ini b/tests/pytest.ini index 145cb7ebb..eb002c8ce 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -1,5 +1,6 @@ [pytest] asyncio_mode = auto +timeout = 120 asyncio_default_fixture_loop_scope = "function" log_cli = 1 log_cli_level = INFO diff --git a/tests/test_build_func_desc.py b/tests/test_build_func_desc.py new file mode 100644 index 000000000..67659df3b --- /dev/null +++ b/tests/test_build_func_desc.py @@ -0,0 +1,51 @@ +import sys +from inspect import _empty +from typing import List, Optional, Union + +import pytest +from livekit.agents.llm import FunctionArgInfo, FunctionInfo +from livekit.agents.llm.function_context import _is_optional_type +from livekit.plugins.openai import _oai_api + + +def test_typing(): + assert _is_optional_type(Optional[int]) == (True, int) + assert _is_optional_type(Union[str, None]) == (True, str) + if sys.version_info >= (3, 10): + assert _is_optional_type(float | None) == (True, float) + assert _is_optional_type(Union[str, int]) == (False, None) + + +@pytest.mark.parametrize( + ("arg_typ", "oai_type"), + [ + pytest.param(int, "number", id="int"), + pytest.param(Optional[int], "number", id="optional[int]"), + pytest.param(Union[None, int], "number", id="union[none, int]"), + pytest.param(Union[str, None], "string", id="union[str, none]"), + pytest.param(List[int], "array", id="list[int]"), + pytest.param(Optional[List[int]], "array", id="optional[list[int]]"), + ], +) +def test_description_building(arg_typ: type, oai_type: str): + fi = FunctionInfo( + name="foo", + description="foo", + auto_retry=False, + callable=lambda: None, + arguments={ + "arg": FunctionArgInfo( + name="foo", + description="foo", + type=arg_typ, + default=_empty, + choices=(), + ), + }, + ) + assert ( + _oai_api.build_oai_function_description(fi)["function"]["parameters"][ + "properties" + ]["foo"]["type"] + == oai_type + ) diff --git a/tests/test_create_func.py b/tests/test_create_func.py new file mode 100644 index 000000000..a81d31d93 --- /dev/null +++ b/tests/test_create_func.py @@ -0,0 +1,228 @@ +import enum +from inspect import _empty +from typing import Annotated, List, Optional + +import pytest +from livekit.agents import llm +from livekit.plugins.openai import _oai_api + + +def 
test_func_basic(): + class TestFunctionContext(llm.FunctionContext): + @llm.ai_callable(name="test_function", description="A simple test function") + def test_fn( + self, param: Annotated[str, llm.TypeInfo(description="A string parameter")] + ): + pass + + fnc_ctx = TestFunctionContext() + assert ( + "test_function" in fnc_ctx.ai_functions + ), "Function should be registered in ai_functions" + + fnc_info = fnc_ctx.ai_functions["test_function"] + build_info = _oai_api.build_oai_function_description(fnc_info) + assert fnc_info.name == build_info["function"]["name"] + assert fnc_info.description == build_info["function"]["description"] + assert not fnc_info.auto_retry + assert "param" in fnc_info.arguments + assert "param" in build_info["function"]["parameters"]["properties"] + assert "param" in build_info["function"]["parameters"]["required"] + + arg_info = fnc_info.arguments["param"] + build_arg_info = build_info["function"]["parameters"]["properties"]["param"] + + assert arg_info.name == "param" + assert arg_info.description == "A string parameter" + assert arg_info.type is str + assert arg_info.default is _empty + assert arg_info.choices == () + assert build_arg_info["description"] == arg_info.description + assert build_arg_info["type"] == "string" + + +def test_func_duplicate(): + class TestFunctionContext(llm.FunctionContext): + @llm.ai_callable( + name="duplicate_function", description="A simple test function" + ) + def fn1(self): + pass + + @llm.ai_callable( + name="duplicate_function", description="A simple test function" + ) + def fn2(self): + pass + + with pytest.raises( + ValueError, match="duplicate ai_callable name: duplicate_function" + ): + TestFunctionContext() + + +def test_func_with_docstring(): + class TestFunctionContext(llm.FunctionContext): + @llm.ai_callable() + def test_fn(self): + """A simple test function""" + pass + + fnc_ctx = TestFunctionContext() + assert ( + "test_fn" in fnc_ctx.ai_functions + ), "Function should be registered in ai_functions" + + assert fnc_ctx.ai_functions["test_fn"].description == "A simple test function" + + +def test_func_with_optional_parameter(): + class TestFunctionContext(llm.FunctionContext): + @llm.ai_callable( + name="optional_function", description="Function with optional parameter" + ) + def optional_fn( + self, + param: Annotated[ + Optional[int], llm.TypeInfo(description="An optional integer parameter") + ] = None, + param2: Optional[List[str]] = None, + param3: str = "A string", + ): + pass + + fnc_ctx = TestFunctionContext() + assert ( + "optional_function" in fnc_ctx.ai_functions + ), "Function should be registered in ai_functions" + + fnc_info = fnc_ctx.ai_functions["optional_function"] + build_info = _oai_api.build_oai_function_description(fnc_info) + print(build_info) + assert fnc_info.name == build_info["function"]["name"] + assert fnc_info.description == build_info["function"]["description"] + assert "param" in fnc_info.arguments + assert "param2" in fnc_info.arguments + assert "param3" in fnc_info.arguments + assert "param" in build_info["function"]["parameters"]["properties"] + assert "param2" in build_info["function"]["parameters"]["properties"] + assert "param3" in build_info["function"]["parameters"]["properties"] + assert "param" not in build_info["function"]["parameters"]["required"] + assert "param2" not in build_info["function"]["parameters"]["required"] + assert "param3" not in build_info["function"]["parameters"]["required"] + + # Check 'param' + arg_info = fnc_info.arguments["param"] + build_arg_info = 
build_info["function"]["parameters"]["properties"]["param"] + + assert arg_info.name == "param" + assert arg_info.description == "An optional integer parameter" + assert arg_info.type == Optional[int] + assert arg_info.default is None + assert arg_info.choices == () + assert build_arg_info["description"] == arg_info.description + assert build_arg_info["type"] == "number" + + # Check 'param2' + arg_info = fnc_info.arguments["param2"] + build_arg_info = build_info["function"]["parameters"]["properties"]["param2"] + + assert arg_info.name == "param2" + assert arg_info.description == "" + assert arg_info.type == Optional[List[str]] + assert arg_info.default is None + assert arg_info.choices == () + assert build_arg_info["type"] == "array" + assert build_arg_info["items"]["type"] == "string" + + # check 'param3' + arg_info = fnc_info.arguments["param3"] + build_arg_info = build_info["function"]["parameters"]["properties"]["param3"] + + assert arg_info.name == "param3" + assert arg_info.description == "" + assert arg_info.type is str + assert arg_info.default == "A string" + assert arg_info.choices == () + assert build_arg_info["type"] == "string" + + +def test_func_with_list_parameter(): + class TestFunctionContext(llm.FunctionContext): + @llm.ai_callable( + name="list_function", description="Function with list parameter" + ) + def list_fn( + self, + items: Annotated[List[str], llm.TypeInfo(description="A list of strings")], + ): + pass + + fnc_ctx = TestFunctionContext() + assert ( + "list_function" in fnc_ctx.ai_functions + ), "Function should be registered in ai_functions" + + fnc_info = fnc_ctx.ai_functions["list_function"] + build_info = _oai_api.build_oai_function_description(fnc_info) + assert fnc_info.name == build_info["function"]["name"] + assert fnc_info.description == build_info["function"]["description"] + assert not fnc_info.auto_retry + assert "items" in fnc_info.arguments + assert "items" in build_info["function"]["parameters"]["properties"] + assert "items" in build_info["function"]["parameters"]["required"] + + arg_info = fnc_info.arguments["items"] + build_arg_info = build_info["function"]["parameters"]["properties"]["items"] + + assert arg_info.name == "items" + assert arg_info.description == "A list of strings" + assert arg_info.type is List[str] + assert arg_info.default is _empty + assert arg_info.choices == () + assert build_arg_info["description"] == arg_info.description + assert build_arg_info["type"] == "array" + assert build_arg_info["items"]["type"] == "string" + + +def test_func_with_enum_parameter(): + class Status(enum.Enum): + ACTIVE = "active" + INACTIVE = "inactive" + PENDING = "pending" + + class TestFunctionContext(llm.FunctionContext): + @llm.ai_callable( + name="enum_function", description="Function with enum parameter" + ) + def enum_fn( + self, + status: Annotated[Status, llm.TypeInfo(description="Status of the entity")], + ): + pass + + fnc_ctx = TestFunctionContext() + assert ( + "enum_function" in fnc_ctx.ai_functions + ), "Function should be registered in ai_functions" + + fnc_info = fnc_ctx.ai_functions["enum_function"] + build_info = _oai_api.build_oai_function_description(fnc_info) + assert fnc_info.name == build_info["function"]["name"] + assert fnc_info.description == build_info["function"]["description"] + assert not fnc_info.auto_retry + assert "status" in fnc_info.arguments + assert "status" in build_info["function"]["parameters"]["properties"] + assert "status" in build_info["function"]["parameters"]["required"] + + arg_info = 
fnc_info.arguments["status"] + build_arg_info = build_info["function"]["parameters"]["properties"]["status"] + + assert arg_info.name == "status" + assert arg_info.description == "Status of the entity" + assert arg_info.type is str # Enum values are converted to their underlying type + assert arg_info.default is _empty + assert arg_info.choices == ("active", "inactive", "pending") + assert build_arg_info["description"] == arg_info.description + assert build_arg_info["type"] == "string" + assert build_arg_info["enum"] == arg_info.choices diff --git a/tests/test_decoder.py b/tests/test_decoder.py new file mode 100644 index 000000000..10b5b521d --- /dev/null +++ b/tests/test_decoder.py @@ -0,0 +1,149 @@ +import os +import threading +import time +from concurrent.futures import ThreadPoolExecutor + +import aiohttp +import pytest +from livekit.agents.stt import SpeechEventType +from livekit.agents.utils.codecs import AudioStreamDecoder, StreamBuffer +from livekit.plugins import deepgram + +from .utils import wer + +TEST_AUDIO_FILEPATH = os.path.join(os.path.dirname(__file__), "change-sophie.opus") + + +@pytest.mark.asyncio +async def test_decode_and_transcribe(): + # Skip if test file doesn't exist + if not os.path.exists(TEST_AUDIO_FILEPATH): + pytest.skip(f"Test file not found: {TEST_AUDIO_FILEPATH}") + + decoder = AudioStreamDecoder() + with open(TEST_AUDIO_FILEPATH, "rb") as f: + opus_data = f.read() + decoder.push(opus_data) + decoder.end_input() + + session = aiohttp.ClientSession() + stt = deepgram.STT(http_session=session) + stream = stt.stream() + + # Push frames to STT + async for frame in decoder: + stream.push_frame(frame) + + # Mark end of input + stream.end_input() + + # Collect results + final_text = "" + async for event in stream: + if event.type == SpeechEventType.FINAL_TRANSCRIPT: + if event.alternatives: + if final_text: + final_text += " " + final_text += event.alternatives[0].text + + await decoder.aclose() + await stream.aclose() + await session.close() + + # Verify the transcription + expected_text = "the people that are crazy enough to think they can change the world are the ones who do" + assert wer(final_text, expected_text) < 0.2 + + +def test_stream_buffer(): + buffer = StreamBuffer() + data_chunks = [b"hello", b"world", b"test", b"data"] + received_data = bytearray() + write_completed = threading.Event() + + def writer(): + for chunk in data_chunks: + buffer.write(chunk) + time.sleep(0.01) # Simulate some processing time + buffer.end_input() + write_completed.set() + + def reader(): + while True: + data = buffer.read(4) # Read in small chunks + if not data: # EOF + break + received_data.extend(data) + + # Run writer and reader in separate threads + with ThreadPoolExecutor(max_workers=2) as executor: + reader_future = executor.submit(reader) + writer_future = executor.submit(writer) + + # Wait for both threads to complete + writer_future.result() + reader_future.result() + + # Verify that all data was received correctly + expected_data = b"".join(data_chunks) + assert bytes(received_data) == expected_data + + +def test_stream_buffer_large_chunks(): + import hashlib + + buffer = StreamBuffer() + large_chunk = os.urandom(1024 * 1024) # 1MB of random bytes + num_chunks = 5 + total_size = 0 + write_completed = threading.Event() + input_hasher = hashlib.sha256() + + def writer(): + nonlocal total_size + for _ in range(num_chunks): + buffer.write(large_chunk) + total_size += len(large_chunk) + input_hasher.update(large_chunk) + buffer.end_input() + write_completed.set() + + 
received_size = 0 + output_hasher = hashlib.sha256() + + def reader(): + nonlocal received_size + # allow writer to start first + time.sleep(1) + while True: + chunk = buffer.read(8192) # Read in 8KB chunks + if not chunk: + break + received_size += len(chunk) + output_hasher.update(chunk) + + # Run writer and reader in separate threads + with ThreadPoolExecutor(max_workers=2) as executor: + reader_future = executor.submit(reader) + writer_future = executor.submit(writer) + + # Wait for both threads to complete + writer_future.result() + reader_future.result() + + assert received_size == total_size + assert total_size == num_chunks * len(large_chunk) + assert input_hasher.hexdigest() == output_hasher.hexdigest() + + +def test_stream_buffer_early_close(): + buffer = StreamBuffer() + + # Write some data + buffer.write(b"test data") + + # Close the buffer + buffer.close() + + # Reading from closed buffer should return empty bytes + assert buffer.read() == b"" diff --git a/tests/test_ipc.py b/tests/test_ipc.py index d77715dde..4e1fd4fe7 100644 --- a/tests/test_ipc.py +++ b/tests/test_ipc.py @@ -114,6 +114,7 @@ def _generate_fake_job() -> job.RunningJobInfo: url="fake_url", token="fake_token", accept_arguments=job.JobAcceptArguments(name="", identity="", metadata=""), + worker_id="fake_id", ) @@ -141,7 +142,7 @@ def _new_start_args(mp_ctx: BaseContext) -> _StartArgs: def _initialize_proc(proc: JobProcess) -> None: - start_args: _StartArgs = proc.start_arguments + start_args: _StartArgs = proc.user_arguments # incrementing isn't atomic (the lock should be reentrant by default) with start_args.initialize_counter.get_lock(): @@ -154,7 +155,7 @@ def _initialize_proc(proc: JobProcess) -> None: async def _job_entrypoint(job_ctx: JobContext) -> None: - start_args: _StartArgs = job_ctx.proc.start_arguments + start_args: _StartArgs = job_ctx.proc.user_arguments async def _job_shutdown() -> None: with start_args.shutdown_counter.get_lock(): @@ -196,6 +197,9 @@ async def test_proc_pool(): job_executor_type=job.JobExecutorType.PROCESS, initialize_timeout=20.0, close_timeout=20.0, + inference_executor=None, + memory_warn_mb=0, + memory_limit_mb=0, mp_ctx=mp_ctx, loop=loop, ) @@ -210,21 +214,21 @@ async def test_proc_pool(): exitcodes = [] @pool.on("process_created") - def _process_created(proc: ipc.proc_job_executor.ProcJobExecutor): + def _process_created(proc: ipc.job_proc_executor.ProcJobExecutor): created_q.put_nowait(None) - proc.start_arguments = start_args + proc.user_arguments = start_args @pool.on("process_started") - def _process_started(proc: ipc.proc_job_executor.ProcJobExecutor): + def _process_started(proc: ipc.job_proc_executor.ProcJobExecutor): start_q.put_nowait(None) pids.append(proc.pid) @pool.on("process_ready") - def _process_ready(proc: ipc.proc_job_executor.ProcJobExecutor): + def _process_ready(proc: ipc.job_proc_executor.ProcJobExecutor): ready_q.put_nowait(None) @pool.on("process_closed") - def _process_closed(proc: ipc.proc_job_executor.ProcJobExecutor): + def _process_closed(proc: ipc.job_proc_executor.ProcJobExecutor): close_q.put_nowait(None) exitcodes.append(proc.exitcode) @@ -272,6 +276,9 @@ async def test_slow_initialization(): num_idle_processes=num_idle_processes, initialize_timeout=1.0, close_timeout=20.0, + inference_executor=None, + memory_warn_mb=0, + memory_limit_mb=0, mp_ctx=mp_ctx, loop=loop, ) @@ -285,12 +292,12 @@ async def test_slow_initialization(): exitcodes = [] @pool.on("process_created") - def _process_created(proc: ipc.proc_job_executor.ProcJobExecutor): - 
proc.start_arguments = start_args + def _process_created(proc: ipc.job_proc_executor.ProcJobExecutor): + proc.user_arguments = start_args start_q.put_nowait(None) @pool.on("process_closed") - def _process_closed(proc: ipc.proc_job_executor.ProcJobExecutor): + def _process_closed(proc: ipc.job_proc_executor.ProcJobExecutor): close_q.put_nowait(None) pids.append(proc.pid) exitcodes.append(proc.exitcode) @@ -316,18 +323,24 @@ def _create_proc( close_timeout: float, mp_ctx: BaseContext, initialize_timeout: float = 20.0, -) -> tuple[ipc.proc_job_executor.ProcJobExecutor, _StartArgs]: +) -> tuple[ipc.job_proc_executor.ProcJobExecutor, _StartArgs]: start_args = _new_start_args(mp_ctx) loop = asyncio.get_running_loop() - proc = ipc.proc_job_executor.ProcJobExecutor( + proc = ipc.job_proc_executor.ProcJobExecutor( initialize_process_fnc=_initialize_proc, job_entrypoint_fnc=_job_entrypoint, initialize_timeout=initialize_timeout, close_timeout=close_timeout, + memory_warn_mb=0, + memory_limit_mb=0, + ping_interval=2.5, + ping_timeout=10.0, + high_ping_threshold=1.0, + inference_executor=None, mp_ctx=mp_ctx, loop=loop, ) - proc.start_arguments = start_args + proc.user_arguments = start_args return proc, start_args diff --git a/tests/test_llm.py b/tests/test_llm.py index 3a0e1ea68..4b71c0324 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -1,13 +1,16 @@ from __future__ import annotations import asyncio +import base64 from enum import Enum -from typing import Annotated, Callable, Optional +from pathlib import Path +from typing import Annotated, Callable, Literal, Optional, Union import pytest -from livekit.agents import llm +from livekit.agents import APIConnectionError, llm from livekit.agents.llm import ChatContext, FunctionContext, TypeInfo, ai_callable from livekit.plugins import anthropic, openai +from livekit.rtc import VideoBufferType, VideoFrame class Unit(Enum): @@ -32,9 +35,7 @@ def get_weather( @ai_callable(description="Play a music") def play_music( self, - name: Annotated[ - str, TypeInfo(description="The artist and the name of the song") - ], + name: Annotated[str, TypeInfo(description="the name of the Artist")], ) -> None: ... # test for cancelled calls @@ -47,7 +48,7 @@ async def toggle_light( await asyncio.sleep(60) # used to test arrays as arguments - @ai_callable(description="Currencies of a specific area") + @ai_callable(description="Select currencies of a specific area") def select_currencies( self, currencies: Annotated[ @@ -79,7 +80,7 @@ def test_hashable_typeinfo(): LLMS: list[Callable[[], llm.LLM]] = [ - lambda: openai.LLM(), + pytest.param(lambda: openai.LLM(), id="openai"), # lambda: openai.beta.AssistantLLM( # assistant_opts=openai.beta.AssistantOptions( # create_options=openai.beta.AssistantCreateOptions( @@ -89,8 +90,8 @@ def test_hashable_typeinfo(): # ) # ) # ), - lambda: anthropic.LLM(), - lambda: openai.LLM.with_vertex(), + pytest.param(lambda: anthropic.LLM(), id="anthropic"), + pytest.param(lambda: openai.LLM.with_vertex(), id="openai.with_vertex"), ] @@ -205,7 +206,7 @@ async def test_calls_arrays(llm_factory: Callable[[], llm.LLM]): stream = await _request_fnc_call( input_llm, - "Can you select all currencies in Europe at once from given choices?", + "Can you select all currencies in Europe at once from given choices using function call `select_currencies`?", fnc_ctx, temperature=0.2, ) @@ -237,7 +238,7 @@ def change_volume( ) -> None: ... 
if not input_llm.capabilities.supports_choices_on_int: - with pytest.raises(ValueError, match="which is not supported by this model"): + with pytest.raises(APIConnectionError): stream = await _request_fnc_call(input_llm, "Set the volume to 30", fnc_ctx) else: stream = await _request_fnc_call(input_llm, "Set the volume to 30", fnc_ctx) @@ -277,11 +278,82 @@ async def test_optional_args(llm_factory: Callable[[], llm.LLM]): assert address is None, "update_user_info should have been called with address None" +test_tool_choice_cases = [ + pytest.param( + "Default tool_choice (auto)", + "Get the weather for New York and play some music from the artist 'The Beatles'.", + None, + {"get_weather", "play_music"}, + id="Default tool_choice (auto)", + ), + pytest.param( + "Tool_choice set to 'required'", + "Get the weather for Chicago and play some music from the artist 'Eminem'.", + "required", + {"get_weather", "play_music"}, + id="Tool_choice set to 'required'", + ), + pytest.param( + "Tool_choice set to a specific tool ('get_weather')", + "Get the weather for Miami.", + llm.ToolChoice(type="function", name="get_weather"), + {"get_weather"}, + id="Tool_choice set to a specific tool ('get_weather')", + ), + pytest.param( + "Tool_choice set to 'none'", + "Get the weather for Seattle and play some music from the artist 'Frank Sinatra'.", + "none", + set(), # No tool calls expected + id="Tool_choice set to 'none'", + ), +] + + +@pytest.mark.parametrize( + "description, user_request, tool_choice, expected_calls", test_tool_choice_cases +) +@pytest.mark.parametrize("llm_factory", LLMS) +async def test_tool_choice_options( + description: str, + user_request: str, + tool_choice: Union[dict, str, None], + expected_calls: set, + llm_factory: Callable[[], llm.LLM], +): + input_llm = llm_factory() + fnc_ctx = FncCtx() + + stream = await _request_fnc_call( + input_llm, + user_request, + fnc_ctx, + tool_choice=tool_choice, + parallel_tool_calls=True, + ) + + calls = stream.execute_functions() + await asyncio.gather(*[f.task for f in calls], return_exceptions=True) + await stream.aclose() + print(calls) + + call_names = {call.call_info.function_info.name for call in calls} + if tool_choice == "none" and isinstance(input_llm, anthropic.LLM): + assert True + else: + assert ( + call_names == expected_calls + ), f"Test '{description}' failed: Expected calls {expected_calls}, but got {call_names}" + + async def _request_fnc_call( model: llm.LLM, request: str, fnc_ctx: FncCtx, temperature: float | None = None, + parallel_tool_calls: bool | None = None, + tool_choice: Union[llm.ToolChoice, Literal["auto", "required", "none"]] + | None = None, ) -> llm.LLMStream: stream = model.chat( chat_ctx=ChatContext() @@ -292,9 +364,90 @@ async def _request_fnc_call( .append(text=request, role="user"), fnc_ctx=fnc_ctx, temperature=temperature, + tool_choice=tool_choice, + parallel_tool_calls=parallel_tool_calls, ) async for _ in stream: pass return stream + + +_HEARTS_RGBA_PATH = Path(__file__).parent / "hearts.rgba" +with open(_HEARTS_RGBA_PATH, "rb") as f: + image_data = f.read() + + _HEARTS_IMAGE_VIDEO_FRAME = VideoFrame( + width=512, height=512, type=VideoBufferType.RGBA, data=image_data + ) + +_HEARTS_JPEG_PATH = Path(__file__).parent / "hearts.jpg" +with open(_HEARTS_JPEG_PATH, "rb") as f: + _HEARTS_IMAGE_DATA_URL = ( + f"data:image/jpeg;base64,{base64.b64encode(f.read()).decode()}" + ) + + +@pytest.mark.parametrize("llm_factory", LLMS) +async def test_chat_with_image_data_url(llm_factory: Callable[[], llm.LLM]): + input_llm = 
llm_factory() + + chat_ctx = ( + ChatContext() + .append( + text="You are an AI assistant that describes images in detail upon request.", + role="system", + ) + .append( + text="Describe this image", + images=[ + llm.ChatImage(image=_HEARTS_IMAGE_DATA_URL, inference_detail="low") + ], + role="user", + ) + ) + + stream = input_llm.chat(chat_ctx=chat_ctx) + text = "" + async for chunk in stream: + if not chunk.choices: + continue + + content = chunk.choices[0].delta.content + if content: + text += content + + assert "heart" in text.lower() + + +@pytest.mark.parametrize("llm_factory", LLMS) +async def test_chat_with_image_frame(llm_factory: Callable[[], llm.LLM]): + input_llm = llm_factory() + + chat_ctx = ( + ChatContext() + .append( + text="You are an AI assistant that describes images in detail upon request.", + role="system", + ) + .append( + text="Describe this image", + images=[ + llm.ChatImage(image=_HEARTS_IMAGE_VIDEO_FRAME, inference_detail="low") + ], + role="user", + ) + ) + + stream = input_llm.chat(chat_ctx=chat_ctx) + text = "" + async for chunk in stream: + if not chunk.choices: + continue + + content = chunk.choices[0].delta.content + if content: + text += content + + assert "heart" in text.lower() diff --git a/tests/test_stt.py b/tests/test_stt.py index a876942bd..d1f340b1e 100644 --- a/tests/test_stt.py +++ b/tests/test_stt.py @@ -4,34 +4,37 @@ import asyncio import time -from itertools import product +from typing import Callable import pytest from livekit import agents +from livekit.agents import stt from livekit.plugins import assemblyai, azure, deepgram, fal, google, openai, silero from .utils import make_test_speech, wer SAMPLE_RATES = [24000, 44100] # test multiple input sample rates WER_THRESHOLD = 0.2 -RECOGNIZE_STT = [ - lambda: deepgram.STT(), - lambda: google.STT(), - lambda: google.STT( - languages=["en-AU"], - model="chirp_2", - spoken_punctuation=False, - location="us-central1", +RECOGNIZE_STT: list[Callable[[], stt.STT]] = [ + pytest.param(lambda: deepgram.STT(), id="deepgram"), + pytest.param(lambda: google.STT(), id="google"), + pytest.param( + lambda: google.STT( + languages=["en-AU"], + model="chirp_2", + spoken_punctuation=False, + location="us-central1", + ), + id="google.chirp_2", ), - lambda: openai.STT(), - lambda: fal.WizperSTT(), + pytest.param(lambda: openai.STT(), id="openai"), + pytest.param(lambda: fal.WizperSTT(), id="fal"), ] @pytest.mark.usefixtures("job_process") -@pytest.mark.parametrize( - "stt_factory, sample_rate", product(RECOGNIZE_STT, SAMPLE_RATES) -) +@pytest.mark.parametrize("stt_factory", RECOGNIZE_STT) +@pytest.mark.parametrize("sample_rate", SAMPLE_RATES) async def test_recognize(stt_factory, sample_rate): async with stt_factory() as stt: frames, transcript = make_test_speech(sample_rate=sample_rate) @@ -47,23 +50,34 @@ async def test_recognize(stt_factory, sample_rate): STREAM_VAD = silero.VAD.load(min_silence_duration=0.75) -STREAM_STT = [ - lambda: assemblyai.STT(), - lambda: deepgram.STT(), - lambda: google.STT(), - lambda: agents.stt.StreamAdapter(stt=openai.STT(), vad=STREAM_VAD), - lambda: google.STT( - languages=["en-AU"], - model="chirp_2", - spoken_punctuation=False, - location="us-central1", +STREAM_STT: list[Callable[[], stt.STT]] = [ + pytest.param(lambda: assemblyai.STT(), id="assemblyai"), + pytest.param(lambda: deepgram.STT(), id="deepgram"), + pytest.param(lambda: google.STT(), id="google"), + pytest.param( + lambda: agents.stt.StreamAdapter(stt=openai.STT(), vad=STREAM_VAD), + id="openai.stream", ), - lambda: 
azure.STT(), + pytest.param( + lambda: agents.stt.StreamAdapter(stt=openai.STT.with_groq(), vad=STREAM_VAD), + id="openai.with_groq.stream", + ), + pytest.param( + lambda: google.STT( + languages=["en-AU"], + model="chirp_2", + spoken_punctuation=False, + location="us-central1", + ), + id="google.chirp_2", + ), + pytest.param(lambda: azure.STT(), id="azure"), ] @pytest.mark.usefixtures("job_process") -@pytest.mark.parametrize("stt_factory, sample_rate", product(STREAM_STT, SAMPLE_RATES)) +@pytest.mark.parametrize("stt_factory", STREAM_STT) +@pytest.mark.parametrize("sample_rate", SAMPLE_RATES) async def test_stream(stt_factory, sample_rate): stt = stt_factory() @@ -94,7 +108,13 @@ async def _stream_output(): continue if event.type == agents.stt.SpeechEventType.FINAL_TRANSCRIPT: + if text != "": + text += " " text += event.alternatives[0].text + # ensure STT is tagging languages correctly + language = event.alternatives[0].language + assert language is not None + assert language.lower().startswith("en") if event.type == agents.stt.SpeechEventType.END_OF_SPEECH: recv_start = False diff --git a/tests/test_stt_fallback.py b/tests/test_stt_fallback.py new file mode 100644 index 000000000..2f6ec8a74 --- /dev/null +++ b/tests/test_stt_fallback.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import asyncio + +import pytest +from livekit.agents import APIConnectionError, utils +from livekit.agents.stt import STT, AvailabilityChangedEvent, FallbackAdapter +from livekit.agents.utils.aio.channel import ChanEmpty + +from .fake_stt import FakeSTT + + +class FallbackAdapterTester(FallbackAdapter): + def __init__( + self, + stt: list[STT], + *, + attempt_timeout: float = 10.0, + max_retry_per_stt: int = 1, + retry_interval: float = 5, + ) -> None: + super().__init__( + stt, + attempt_timeout=attempt_timeout, + max_retry_per_stt=max_retry_per_stt, + retry_interval=retry_interval, + ) + + self.on("stt_availability_changed", self._on_stt_availability_changed) + + self._availability_changed_ch: dict[ + int, utils.aio.Chan[AvailabilityChangedEvent] + ] = {id(t): utils.aio.Chan[AvailabilityChangedEvent]() for t in stt} + + def _on_stt_availability_changed(self, ev: AvailabilityChangedEvent) -> None: + self._availability_changed_ch[id(ev.stt)].send_nowait(ev) + + def availability_changed_ch( + self, + tts: STT, + ) -> utils.aio.ChanReceiver[AvailabilityChangedEvent]: + return self._availability_changed_ch[id(tts)] + + +async def test_stt_fallback() -> None: + fake1 = FakeSTT(fake_exception=APIConnectionError("fake1 failed")) + fake2 = FakeSTT(fake_transcript="hello world") + + fallback_adapter = FallbackAdapterTester([fake1, fake2]) + ev = await fallback_adapter.recognize([]) + assert ev.alternatives[0].text == "hello world" + + assert fake1.recognize_ch.recv_nowait() + assert fake2.recognize_ch.recv_nowait() + + assert not fallback_adapter.availability_changed_ch(fake1).recv_nowait().available + + fake2.update_options(fake_exception=APIConnectionError("fake2 failed")) + + with pytest.raises(APIConnectionError): + await fallback_adapter.recognize([]) + + assert not fallback_adapter.availability_changed_ch(fake2).recv_nowait().available + + await fallback_adapter.aclose() + + # stream + fake1 = FakeSTT(fake_exception=APIConnectionError("fake1 failed")) + fake2 = FakeSTT(fake_transcript="hello world") + + fallback_adapter = FallbackAdapterTester([fake1, fake2]) + + async with fallback_adapter.stream() as stream: + stream.end_input() + + last_alt = "" + + async for ev in stream: + last_alt = 
ev.alternatives[0].text + + assert last_alt == "hello world" + + await fallback_adapter.aclose() + + +async def test_stt_stream_fallback() -> None: + fake1 = FakeSTT(fake_exception=APIConnectionError("fake1 failed")) + fake2 = FakeSTT(fake_transcript="hello world") + + fallback_adapter = FallbackAdapterTester([fake1, fake2]) + + async with fallback_adapter.stream() as stream: + stream.end_input() + + async for _ in stream: + pass + + assert fake1.stream_ch.recv_nowait() + assert fake2.stream_ch.recv_nowait() + + assert not fallback_adapter.availability_changed_ch(fake1).recv_nowait().available + + await fallback_adapter.aclose() + + +async def test_stt_recover() -> None: + fake1 = FakeSTT(fake_exception=APIConnectionError("fake1 failed")) + fake2 = FakeSTT(fake_exception=APIConnectionError("fake2 failed"), fake_timeout=0.5) + + fallback_adapter = FallbackAdapterTester([fake1, fake2]) + + with pytest.raises(APIConnectionError): + await fallback_adapter.recognize([]) + + fake2.update_options(fake_exception=None, fake_transcript="hello world") + + assert not fallback_adapter.availability_changed_ch(fake1).recv_nowait().available + assert not fallback_adapter.availability_changed_ch(fake2).recv_nowait().available + + assert ( + await asyncio.wait_for( + fallback_adapter.availability_changed_ch(fake2).recv(), 1.0 + ) + ).available, "fake2 should have recovered" + + await fallback_adapter.recognize([]) + + assert fake1.recognize_ch.recv_nowait() + assert fake2.recognize_ch.recv_nowait() + + with pytest.raises(ChanEmpty): + fallback_adapter.availability_changed_ch(fake1).recv_nowait() + + with pytest.raises(ChanEmpty): + fallback_adapter.availability_changed_ch(fake2).recv_nowait() + + await fallback_adapter.aclose() diff --git a/tests/test_tts.py b/tests/test_tts.py index 378c87ba7..91f8035b5 100644 --- a/tests/test_tts.py +++ b/tests/test_tts.py @@ -4,12 +4,21 @@ """ import dataclasses +from typing import Callable import pytest from livekit import agents -from livekit.agents import APIConnectionError, tokenize +from livekit.agents import APIConnectionError, tokenize, tts from livekit.agents.utils import AudioBuffer, merge_frames -from livekit.plugins import azure, cartesia, elevenlabs, google, openai +from livekit.plugins import ( + azure, + cartesia, + deepgram, + elevenlabs, + google, + openai, + playai, +) from .conftest import TEST_CONNECT_OPTIONS from .fake_tts import FakeTTS @@ -33,13 +42,17 @@ async def _assert_valid_synthesized_audio( ), "num channels should be the same" -SYNTHESIZE_TTS = [ - lambda: elevenlabs.TTS(), - lambda: elevenlabs.TTS(encoding="pcm_44100"), - lambda: openai.TTS(), - lambda: google.TTS(), - lambda: azure.TTS(), - lambda: cartesia.TTS(), +SYNTHESIZE_TTS: list[Callable[[], tts.TTS]] = [ + pytest.param(lambda: elevenlabs.TTS(), id="elevenlabs"), + pytest.param( + lambda: elevenlabs.TTS(encoding="pcm_44100"), id="elevenlabs.pcm_44100" + ), + pytest.param(lambda: openai.TTS(), id="openai"), + pytest.param(lambda: google.TTS(), id="google"), + pytest.param(lambda: azure.TTS(), id="azure"), + pytest.param(lambda: cartesia.TTS(), id="cartesia"), + pytest.param(lambda: deepgram.TTS(), id="deepgram"), + pytest.param(lambda: playai.TTS(), id="playai"), ] @@ -60,26 +73,39 @@ async def test_synthesize(tts_factory): STREAM_SENT_TOKENIZER = tokenize.basic.SentenceTokenizer(min_sentence_len=20) -STREAM_TTS = [ - lambda: elevenlabs.TTS(), - lambda: elevenlabs.TTS(encoding="pcm_44100"), - lambda: cartesia.TTS(), - lambda: agents.tts.StreamAdapter( - tts=openai.TTS(), 
sentence_tokenizer=STREAM_SENT_TOKENIZER +STREAM_TTS: list[Callable[[], tts.TTS]] = [ + pytest.param(lambda: elevenlabs.TTS(), id="elevenlabs"), + pytest.param( + lambda: elevenlabs.TTS(encoding="pcm_44100"), id="elevenlabs.pcm_44100" + ), + pytest.param(lambda: cartesia.TTS(), id="cartesia"), + pytest.param( + lambda: agents.tts.StreamAdapter( + tts=openai.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER + ), + id="openai.stream", ), - lambda: agents.tts.StreamAdapter( - tts=google.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER + pytest.param( + lambda: agents.tts.StreamAdapter( + tts=google.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER + ), + id="google.stream", ), - lambda: agents.tts.StreamAdapter( - tts=azure.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER + pytest.param( + lambda: agents.tts.StreamAdapter( + tts=azure.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER + ), + id="azure.stream", ), + pytest.param(lambda: deepgram.TTS(), id="deepgram"), + pytest.param(lambda: playai.TTS(), id="playai"), ] @pytest.mark.usefixtures("job_process") @pytest.mark.parametrize("tts_factory", STREAM_TTS) async def test_stream(tts_factory): - tts = tts_factory() + tts: agents.tts.TTS = tts_factory() synthesize_transcript = make_test_synthesize() @@ -96,21 +122,31 @@ async def test_stream(tts_factory): stream = tts.stream() + segments = set() + # for i in range(2): # TODO(theomonnom): we should test 2 segments for chunk in chunks: stream.push_text(chunk) stream.flush() + # if i == 1: stream.end_input() frames = [] + is_final = False async for audio in stream: + is_final = audio.is_final + segments.add(audio.segment_id) frames.append(audio.frame) - await stream.aclose() + assert is_final, "final audio should be marked as final" + await _assert_valid_synthesized_audio( frames, tts, synthesize_transcript, WER_THRESHOLD ) + # assert len(segments) == 2 + await stream.aclose() + async def test_retry(): fake_tts = FakeTTS(fake_exception=APIConnectionError("fake exception")) diff --git a/tests/test_fallback.py b/tests/test_tts_fallback.py similarity index 70% rename from tests/test_fallback.py rename to tests/test_tts_fallback.py index 76248f9f5..de7cf5c26 100644 --- a/tests/test_fallback.py +++ b/tests/test_tts_fallback.py @@ -1,11 +1,13 @@ from __future__ import annotations import asyncio +import contextlib import pytest from livekit import rtc from livekit.agents import APIConnectionError, utils from livekit.agents.tts import TTS, AvailabilityChangedEvent, FallbackAdapter +from livekit.agents.tts.tts import SynthesizeStream from livekit.agents.utils.aio.channel import ChanEmpty from .fake_tts import FakeTTS @@ -75,6 +77,46 @@ async def test_tts_fallback() -> None: await fallback_adapter.aclose() +async def test_no_audio() -> None: + fake1 = FakeTTS(fake_audio_duration=0.0) + + fallback_adapter = FallbackAdapterTester([fake1]) + + with pytest.raises(APIConnectionError): + async with fallback_adapter.synthesize("hello test") as stream: + async for _ in stream: + pass + + # stream + fake1.update_options(fake_audio_duration=5.0) + + async def _input_task(stream: SynthesizeStream): + with contextlib.suppress(RuntimeError): + stream.push_text("hello test") + stream.flush() + await asyncio.sleep(1.0) + + fake1.update_options(fake_timeout=0.5, fake_audio_duration=None) + + stream.push_text("hello test") + stream.end_input() + + with pytest.raises(APIConnectionError): + async with fallback_adapter.stream() as stream: + input_task = asyncio.create_task(_input_task(stream)) + + segments = set() + try: + async for 
ev in stream: + segments.add(ev.segment_id) + finally: + await input_task + + assert len(segments) == 1 + + await fallback_adapter.aclose() + + async def test_tts_stream_fallback() -> None: fake1 = FakeTTS(fake_exception=APIConnectionError("fake1 failed")) fake2 = FakeTTS(fake_audio_duration=5.0) @@ -169,6 +211,8 @@ async def test_audio_resampled() -> None: async for data in stream: frames.append(data.frame) + print(frames) + assert fake2.stream_ch.recv_nowait() combined_frame = rtc.combine_audio_frames(frames) @@ -197,6 +241,7 @@ async def test_timeout(): assert await asyncio.wait_for(fake1.synthesize_ch.recv(), 1.0) assert await asyncio.wait_for(fake2.synthesize_ch.recv(), 1.0) + # stream with pytest.raises(APIConnectionError): async with fallback_adapter.stream() as stream: stream.end_input() @@ -210,3 +255,55 @@ async def test_timeout(): assert await asyncio.wait_for(fake2.stream_ch.recv(), 1.0) await fallback_adapter.aclose() + + # consecutive push must not timeout + fake1.update_options(fake_timeout=None, fake_audio_duration=5.0) + fallback_adapter = FallbackAdapterTester([fake1], attempt_timeout=0.25) + + async def _input_task1(stream: SynthesizeStream): + stream.push_text("hello world") + stream.flush() + await asyncio.sleep(1.0) + + stream.push_text("bye world") + stream.end_input() + + async with fallback_adapter.stream() as stream: + input_task = asyncio.create_task(_input_task1(stream)) + + segments = set() + final_count = 0 + async for ev in stream: + segments.add(ev.segment_id) + if ev.is_final: + final_count += 1 + + assert len(segments) == 2 + assert final_count == 2 + await input_task + + async def _input_task2(stream: SynthesizeStream): + with contextlib.suppress(RuntimeError): + stream.push_text("hello test") + stream.flush() + await asyncio.sleep(1.0) + + fake1.update_options(fake_timeout=0.5, fake_audio_duration=None) + + stream.push_text("hello test") + stream.flush() + await asyncio.sleep(1.0) + + stream.end_input() + + with pytest.raises(APIConnectionError): + async with fallback_adapter.stream() as stream: + input_task = asyncio.create_task(_input_task2(stream)) + + try: + async for ev in stream: + pass + finally: + await input_task + + await fallback_adapter.aclose()