Implement basic SDK for convai agents (#389)
* Implement basic SDK for convai agents

Early prototype, subject to change based on user feedback.

Takes care of the websocket session and message handling, exposing a
simplified audio interface to the client that can be hooked up to
the appropriate audio inputs/outputs based on the use case.

Also implements a basic speaker/microphone interface via an optional
dependency on pyaudio.

* Move to `conversational_ai/` and split default_audio_interface

* Review fixes
lacop11 authored Oct 25, 2024
1 parent 5a2c536 commit 9a45a3f
Showing 5 changed files with 344 additions and 1 deletion.
1 change: 1 addition & 0 deletions .fernignore
@@ -1,6 +1,7 @@
# Specify files that shouldn't be modified by Fern

src/elevenlabs/client.py
src/elevenlabs/conversation.py
src/elevenlabs/play.py
src/elevenlabs/realtime_tts.py

39 changes: 38 additions & 1 deletion poetry.lock


7 changes: 7 additions & 0 deletions pyproject.toml
@@ -40,14 +40,21 @@ requests = ">=2.20"
typing_extensions = ">= 4.0.0"
websockets = ">=11.0"

# Optional extras.
pyaudio = { version = ">=0.2.14", optional = true }

[tool.poetry.dev-dependencies]
mypy = "1.0.1"
pytest = "^7.4.0"
pytest-asyncio = "^0.23.5"
python-dateutil = "^2.9.0"
types-pyaudio = "^0.2.16.20240516"
types-python-dateutil = "^2.9.0.20240316"
ruff = "^0.5.6"

[tool.poetry.extras]
pyaudio = ["pyaudio"]

[tool.pytest.ini_options]
testpaths = [ "tests" ]
asyncio_mode = "auto"
215 changes: 215 additions & 0 deletions src/elevenlabs/conversational_ai/conversation.py
@@ -0,0 +1,215 @@
from abc import ABC, abstractmethod
import base64
import json
import threading
from typing import Callable, Optional

from websockets.sync.client import connect

from ..base_client import BaseElevenLabs


class AudioInterface(ABC):
"""AudioInterface provides an abstraction for handling audio input and output."""

@abstractmethod
def start(self, input_callback: Callable[[bytes], None]):
"""Starts the audio interface.
Called one time before the conversation starts.
The `input_callback` should be called regularly with input audio chunks from
the user. The audio should be in 16-bit PCM mono format at 16kHz. Recommended
chunk size is 4000 samples (250 milliseconds).
"""
pass

@abstractmethod
def stop(self):
"""Stops the audio interface.
Called one time after the conversation ends. Should clean up any resources
used by the audio interface and stop any audio streams. Do not call the
`input_callback` from `start` after this method is called.
"""
pass

@abstractmethod
def output(self, audio: bytes):
"""Output audio to the user.
The `audio` input is in 16-bit PCM mono format at 16kHz. Implementations can
choose to do additional buffering. This method should return quickly and not
block the calling thread.
"""
pass

@abstractmethod
def interrupt(self):
"""Interruption signal to stop any audio output.
User has interrupted the agent and all previosly buffered audio output should
be stopped.
"""
        pass


class Conversation:
    client: BaseElevenLabs
    agent_id: str
    requires_auth: bool

    audio_interface: AudioInterface
    callback_agent_response: Optional[Callable[[str], None]]
    callback_agent_response_correction: Optional[Callable[[str, str], None]]
    callback_user_transcript: Optional[Callable[[str], None]]
    callback_latency_measurement: Optional[Callable[[int], None]]

    _thread: Optional[threading.Thread] = None
    _should_stop: threading.Event = threading.Event()
    _conversation_id: Optional[str] = None
    _last_interrupt_id: int = 0

    def __init__(
        self,
        client: BaseElevenLabs,
        agent_id: str,
        *,
        requires_auth: bool,
        audio_interface: AudioInterface,
        callback_agent_response: Optional[Callable[[str], None]] = None,
        callback_agent_response_correction: Optional[Callable[[str, str], None]] = None,
        callback_user_transcript: Optional[Callable[[str], None]] = None,
        callback_latency_measurement: Optional[Callable[[int], None]] = None,
    ):
        """Conversational AI session.
        BETA: This API is subject to change without regard to backwards compatibility.
        Args:
            client: The ElevenLabs client to use for the conversation.
            agent_id: The ID of the agent to converse with.
            requires_auth: Whether the agent requires authentication.
            audio_interface: The audio interface to use for input and output.
            callback_agent_response: Callback for agent responses.
            callback_agent_response_correction: Callback for agent response corrections.
                First argument is the original response (previously given to
                callback_agent_response), second argument is the corrected response.
            callback_user_transcript: Callback for user transcripts.
            callback_latency_measurement: Callback for latency measurements (in milliseconds).
        """

        self.client = client
        self.agent_id = agent_id
        self.requires_auth = requires_auth

        self.audio_interface = audio_interface
        self.callback_agent_response = callback_agent_response
        self.callback_agent_response_correction = callback_agent_response_correction
        self.callback_user_transcript = callback_user_transcript
        self.callback_latency_measurement = callback_latency_measurement

    def start_session(self):
        """Starts the conversation session.
        Will run in background thread until `end_session` is called.
        """
        ws_url = self._get_signed_url() if self.requires_auth else self._get_wss_url()
        self._thread = threading.Thread(target=self._run, args=(ws_url,))
        self._thread.start()

    def end_session(self):
        """Ends the conversation session."""
        self.audio_interface.stop()
        self._should_stop.set()

    def wait_for_session_end(self) -> Optional[str]:
        """Waits for the conversation session to end.
        You must call `end_session` before calling this method, otherwise it will block.
        Returns the conversation ID, if available.
        """
        if not self._thread:
            raise RuntimeError("Session not started.")
        self._thread.join()
        return self._conversation_id

    def _run(self, ws_url: str):
        with connect(ws_url) as ws:

            def input_callback(audio):
                ws.send(
                    json.dumps(
                        {
                            "user_audio_chunk": base64.b64encode(audio).decode(),
                        }
                    )
                )

            self.audio_interface.start(input_callback)
            while not self._should_stop.is_set():
                try:
                    message = json.loads(ws.recv(timeout=0.5))
                    if self._should_stop.is_set():
                        return
                    self._handle_message(message, ws)
                except TimeoutError:
                    pass

    def _handle_message(self, message, ws):
        if message["type"] == "conversation_initiation_metadata":
            event = message["conversation_initiation_metadata_event"]
            assert self._conversation_id is None
            self._conversation_id = event["conversation_id"]
        elif message["type"] == "audio":
            event = message["audio_event"]
            if int(event["event_id"]) <= self._last_interrupt_id:
                return
            audio = base64.b64decode(event["audio_base_64"])
            self.audio_interface.output(audio)
        elif message["type"] == "agent_response":
            if self.callback_agent_response:
                event = message["agent_response_event"]
                self.callback_agent_response(event["agent_response"].strip())
        elif message["type"] == "agent_response_correction":
            if self.callback_agent_response_correction:
                event = message["agent_response_correction_event"]
                self.callback_agent_response_correction(
                    event["original_agent_response"].strip(), event["corrected_agent_response"].strip()
                )
        elif message["type"] == "user_transcript":
            if self.callback_user_transcript:
                event = message["user_transcription_event"]
                self.callback_user_transcript(event["user_transcript"].strip())
        elif message["type"] == "interruption":
            event = message["interruption_event"]
            self._last_interrupt_id = int(event["event_id"])
            self.audio_interface.interrupt()
        elif message["type"] == "ping":
            event = message["ping_event"]
            ws.send(
                json.dumps(
                    {
                        "type": "pong",
                        "event_id": event["event_id"],
                    }
                )
            )
            if self.callback_latency_measurement and event["ping_ms"]:
                self.callback_latency_measurement(int(event["ping_ms"]))
        else:
            pass  # Ignore all other message types.

    def _get_wss_url(self):
        base_url = self.client._client_wrapper._base_url
        # Replace http(s) with ws(s).
        base_ws_url = base_url.replace("http", "ws", 1)  # First occurrence only.
        return f"{base_ws_url}/v1/convai/conversation?agent_id={self.agent_id}"

    def _get_signed_url(self):
        # TODO: Use generated SDK method once available.
        response = self.client._client_wrapper.httpx_client.request(
            f"v1/convai/conversation/get_signed_url?agent_id={self.agent_id}",
            method="GET",
        )
        return response.json()["signed_url"]
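
A minimal usage sketch of the new Conversation class (illustration only, not part of this commit's diff), paired with the DefaultAudioInterface added in the next file. The import paths follow the file layout in this commit; the `ElevenLabs` client constructor and its `api_key` argument are assumptions based on the existing SDK, and `AGENT_ID` is a hypothetical environment variable.

import os
import signal

from elevenlabs.client import ElevenLabs
from elevenlabs.conversational_ai.conversation import Conversation
from elevenlabs.conversational_ai.default_audio_interface import DefaultAudioInterface

# Hypothetical wiring; adjust client construction to your SDK version.
client = ElevenLabs(api_key=os.environ.get("ELEVENLABS_API_KEY"))

conversation = Conversation(
    client,
    agent_id=os.environ["AGENT_ID"],
    requires_auth=True,  # Private agents need a signed WebSocket URL.
    audio_interface=DefaultAudioInterface(),
    callback_agent_response=lambda text: print(f"Agent: {text}"),
    callback_user_transcript=lambda text: print(f"User: {text}"),
)

conversation.start_session()
# End the session cleanly on Ctrl+C, then wait for the background thread to finish.
signal.signal(signal.SIGINT, lambda sig, frame: conversation.end_session())
conversation_id = conversation.wait_for_session_end()
print(f"Conversation ID: {conversation_id}")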
83 changes: 83 additions & 0 deletions src/elevenlabs/conversational_ai/default_audio_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from typing import Callable
import queue
import threading

from .conversation import AudioInterface


class DefaultAudioInterface(AudioInterface):
    INPUT_FRAMES_PER_BUFFER = 4000  # 250ms @ 16kHz
    OUTPUT_FRAMES_PER_BUFFER = 1000  # 62.5ms @ 16kHz

    def __init__(self):
        try:
            import pyaudio
        except ImportError:
            raise ImportError("To use DefaultAudioInterface you must install pyaudio.")
        self.pyaudio = pyaudio

    def start(self, input_callback: Callable[[bytes], None]):
        # Audio input is using callbacks from pyaudio which we simply pass through.
        self.input_callback = input_callback

        # Audio output is buffered so we can handle interruptions.
        # Start a separate thread to handle writing to the output stream.
        self.output_queue: queue.Queue[bytes] = queue.Queue()
        self.should_stop = threading.Event()
        self.output_thread = threading.Thread(target=self._output_thread)

        self.p = self.pyaudio.PyAudio()
        self.in_stream = self.p.open(
            format=self.pyaudio.paInt16,
            channels=1,
            rate=16000,
            input=True,
            stream_callback=self._in_callback,
            frames_per_buffer=self.INPUT_FRAMES_PER_BUFFER,
            start=True,
        )
        self.out_stream = self.p.open(
            format=self.pyaudio.paInt16,
            channels=1,
            rate=16000,
            output=True,
            frames_per_buffer=self.OUTPUT_FRAMES_PER_BUFFER,
            start=True,
        )

        self.output_thread.start()

    def stop(self):
        self.should_stop.set()
        self.output_thread.join()
        self.in_stream.stop_stream()
        self.in_stream.close()
        self.out_stream.close()
        self.p.terminate()

    def output(self, audio: bytes):
        self.output_queue.put(audio)

    def interrupt(self):
        # Clear the output queue to stop any audio that is currently playing.
        # Note: We can't atomically clear the whole queue, but we are doing
        # it from the message handling thread so no new audio will be added
        # while we are clearing.
        try:
            while True:
                _ = self.output_queue.get(block=False)
        except queue.Empty:
            pass

    def _output_thread(self):
        while not self.should_stop.is_set():
            try:
                audio = self.output_queue.get(timeout=0.25)
                self.out_stream.write(audio)
            except queue.Empty:
                pass

    def _in_callback(self, in_data, frame_count, time_info, status):
        if self.input_callback:
            self.input_callback(in_data)
        return (None, self.pyaudio.paContinue)
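
DefaultAudioInterface assumes a local speaker and microphone; the same AudioInterface contract can be implemented for other environments. A hypothetical sketch (illustration only, not part of this commit) that streams a pre-recorded raw 16-bit PCM mono 16kHz file as user input and discards agent audio, e.g. for headless testing:

import threading
import time
from typing import Callable

from elevenlabs.conversational_ai.conversation import AudioInterface


class FileAudioInterface(AudioInterface):
    # Hypothetical example class, not part of the SDK.
    CHUNK_BYTES = 4000 * 2  # 4000 samples of 16-bit PCM, i.e. 250ms @ 16kHz.

    def __init__(self, pcm_path: str):
        self.pcm_path = pcm_path
        self.should_stop = threading.Event()

    def start(self, input_callback: Callable[[bytes], None]):
        def pump():
            with open(self.pcm_path, "rb") as f:
                while not self.should_stop.is_set():
                    chunk = f.read(self.CHUNK_BYTES)
                    if not chunk:
                        break
                    input_callback(chunk)
                    time.sleep(0.25)  # Pace chunks in roughly real time.

        self.thread = threading.Thread(target=pump)
        self.thread.start()

    def stop(self):
        self.should_stop.set()
        self.thread.join()

    def output(self, audio: bytes):
        pass  # Headless sketch: discard agent audio instead of playing it.

    def interrupt(self):
        pass  # Nothing is buffered, so there is nothing to cancel.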
