Skip to content

Commit

Permalink
update the plugins to use the new API & fix tests (#475)
Browse files Browse the repository at this point in the history
  • Loading branch information
theomonnom authored Jul 23, 2024
1 parent af2fa0c commit b6c486d
Show file tree
Hide file tree
Showing 47 changed files with 1,110 additions and 1,330 deletions.
16 changes: 11 additions & 5 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ on:
jobs:
tests:
name: Run tests
runs-on: ubuntu-latest
runs-on: namespace-profile-4vcpu-cache
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
submodules: true
lfs: true
Expand All @@ -33,9 +33,15 @@ jobs:
run: |
pip3 install pytest pytest-asyncio pytest-timeout './livekit-agents[codecs]' psutil
pip3 install -r ./tests/test-requirements.txt
for dir in livekit-plugins/*; do
pip3 install $dir
done
pip3 install ./livekit-agents \
./livekit-plugins/livekit-plugins-openai \
./livekit-plugins/livekit-plugins-deepgram \
./livekit-plugins/livekit-plugins-google \
./livekit-plugins/livekit-plugins-nltk \
./livekit-plugins/livekit-plugins-silero \
./livekit-plugins/livekit-plugins-elevenlabs \
./livekit-plugins/livekit-plugins-cartesia \
./livekit-plugins/livekit-plugins-azure
- name: Run tests
env:
Expand Down
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/

node_modules

credentials.json
pyrightconfig.json
9 changes: 7 additions & 2 deletions examples/voice-assistant/minimal_assistant.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import asyncio
import logging

from livekit.agents import JobContext, JobRequest, WorkerOptions, cli
from livekit.agents import JobContext, JobProcess, JobRequest, WorkerOptions, cli
from livekit.agents.llm import ChatContext
from livekit.agents.voice_assistant import VoiceAssistant
from livekit.plugins import deepgram, openai, silero


def initialize(proc: JobProcess):
proc.userdata["silero"] = silero.VAD.load()


async def entrypoint(ctx: JobContext):
silero_vad: silero.VAD = ctx.proc.userdata["silero"]
initial_ctx = ChatContext().append(
role="system",
text=(
Expand All @@ -19,7 +24,7 @@ async def entrypoint(ctx: JobContext):
await ctx.connect()

assistant = VoiceAssistant(
vad=silero.VAD(),
vad=silero_vad,
stt=deepgram.STT(),
llm=openai.LLM(),
tts=openai.TTS(),
Expand Down
2 changes: 1 addition & 1 deletion livekit-agents/livekit/agents/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import click
from livekit.protocol import models

from . import proto
from .. import utils
from ..log import logger
from ..plugin import Plugin
from ..worker import Worker, WorkerOptions
from . import proto
from .log import setup_logging


Expand Down
2 changes: 1 addition & 1 deletion livekit-agents/livekit/agents/cli/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from datetime import date, datetime, time, timezone
from inspect import istraceback
from typing import Any, Dict, Tuple

from ..log import logger

# skip default LogRecord attributes
Expand Down Expand Up @@ -198,4 +199,3 @@ def setup_logging(log_level: str, production: bool = True) -> None:
root.setLevel(logging.WARN)

logger.setLevel(log_level)

3 changes: 1 addition & 2 deletions livekit-agents/livekit/agents/cli/proto.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from __future__ import annotations

import io
import pickle
from dataclasses import dataclass, field
from typing import ClassVar

from livekit.protocol import agent

from ..job import RunningJobInfo, JobAcceptArguments
from ..ipc import channel
from ..job import JobAcceptArguments, RunningJobInfo
from ..worker import WorkerOptions


Expand Down
17 changes: 9 additions & 8 deletions livekit-agents/livekit/agents/cli/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from .. import utils
from ..ipc import channel
from ..log import logger, DEV_LEVEL
from ..log import DEV_LEVEL, logger
from ..plugin import Plugin
from ..worker import Worker
from . import proto
Expand Down Expand Up @@ -67,7 +67,7 @@ def __init__(
loop: asyncio.AbstractEventLoop,
) -> None:
mp_pch, cli_args.mp_cch = multiprocessing.Pipe(duplex=True)
self._pch = channel.ProcChannel(
self._pch = channel.AsyncProcChannel(
conn=mp_pch, loop=loop, messages=proto.IPC_MESSAGES
)
self._cli_args = cli_args
Expand Down Expand Up @@ -103,9 +103,8 @@ async def _on_reload(self, _: Set[watchfiles.main.FileChange]) -> None:

self._recv_jobs_fut = asyncio.Future()
with contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(
self._recv_jobs_fut, timeout=1.5
) # wait max 1.5s to get the active jobs
# wait max 1.5s to get the active jobs
await asyncio.wait_for(self._recv_jobs_fut, timeout=1.5)

@utils.log_exceptions(logger=logger)
async def _read_ipc_task(self) -> None:
Expand All @@ -131,7 +130,7 @@ def __init__(
) -> None:
self._loop = loop or asyncio.get_event_loop()
self._worker = worker
self._cch = channel.ProcChannel(
self._cch = channel.AsyncProcChannel(
conn=mp_cch, loop=self._loop, messages=proto.IPC_MESSAGES
)

Expand All @@ -142,7 +141,10 @@ def start(self) -> None:
async def _run(self) -> None:
await self._cch.asend(proto.ReloadJobsRequest())
while True:
msg = await self._cch.arecv()
try:
msg = await self._cch.arecv()
except channel.ChannelClosed:
break

if isinstance(msg, proto.ActiveJobsRequest):
jobs = self._worker.active_jobs
Expand All @@ -152,7 +154,6 @@ async def _run(self) -> None:
await self._worker._reload_jobs(msg.jobs)
await self._cch.asend(proto.Reloaded())


async def aclose(self) -> None:
if not self._main_task:
return
Expand Down
139 changes: 85 additions & 54 deletions livekit-agents/livekit/agents/ipc/channel.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations

import asyncio
import threading
import queue
import contextlib
import io
import queue
import struct
import contextlib
from typing import ClassVar, Protocol, Type, runtime_checkable, Optional
import threading
from typing import ClassVar, Optional, Protocol, runtime_checkable


class ProcessConn(Protocol):
Expand Down Expand Up @@ -46,14 +46,63 @@ def __init__(
self,
*,
conn: ProcessConn,
loop: asyncio.AbstractEventLoop,
messages: dict[int, type[Message]],
) -> None:
self._loop = loop
self._conn = conn
self._messages = messages
self._closed = False

def recv(self) -> Message:
if self._closed:
raise ChannelClosed()

try:
b = io.BytesIO(self._conn.recv_bytes())
except (OSError, EOFError, ValueError):
raise ChannelClosed()

msg_id = read_int(b)
msg = self._messages[msg_id]()

if isinstance(msg, DataMessage):
msg.read(b)

return msg

def send(self, msg: Message) -> None:
if self._closed:
raise ChannelClosed()

b = io.BytesIO()
write_int(b, msg.MSG_ID)

if isinstance(msg, DataMessage):
msg.write(b)

try:
self._conn.send_bytes(b.getvalue())
except (OSError, ValueError):
raise ChannelClosed()

def close(self) -> None:
if self._closed:
return

self._closed = True
self._conn.close()


class AsyncProcChannel(ProcChannel):
def __init__(
self,
*,
conn: ProcessConn,
messages: dict[int, type[Message]],
loop: asyncio.AbstractEventLoop,
):
super().__init__(conn=conn, messages=messages)
self._loop = loop

self._read_q = asyncio.Queue[Optional[Message]]()
self._write_q = queue.Queue[Optional[Message]]()
self._exit_fut = asyncio.Future()
Expand All @@ -66,23 +115,38 @@ def __init__(
)
self._read_t.start()
self._write_t.start()
self._closed = False

def _read_thread(self) -> None:
while True:
try:
b = io.BytesIO(self._conn.recv_bytes())
except (OSError, EOFError):
break
async def arecv(self) -> Message:
if self._closed:
raise ChannelClosed()

msg = await self._read_q.get()
if msg is self._sentinel:
raise ChannelClosed()

return msg

async def asend(self, msg: Message) -> None:
if self._closed:
raise ChannelClosed()

msg_id = read_int(b)
msg = self._messages[msg_id]()
self._write_q.put_nowait(msg)

if isinstance(msg, DataMessage):
msg.read(b)
async def aclose(self) -> None:
self.close()
await self._exit_fut

def _read_thread(self) -> None:
while True:
try:
self._loop.call_soon_threadsafe(self._read_q.put_nowait, msg)
except RuntimeError:
if self._conn.poll(1.0):
msg = self.recv()
try:
self._loop.call_soon_threadsafe(self._read_q.put_nowait, msg)
except RuntimeError:
break
except ChannelClosed:
break

with contextlib.suppress(RuntimeError):
Expand All @@ -91,7 +155,7 @@ def _close():
self._exit_fut.set_result(None)
self._read_q.put_nowait(self._sentinel)
self._write_q.put_nowait(self._sentinel)
self._do_close()
self.close()

self._loop.call_soon_threadsafe(_close)

Expand All @@ -101,44 +165,11 @@ def _write_thread(self) -> None:
if msg is self._sentinel:
break

b = io.BytesIO()
write_int(b, msg.MSG_ID)

if isinstance(msg, DataMessage):
msg.write(b)

try:
self._conn.send_bytes(b.getvalue())
except (OSError, ValueError):
self.send(msg)
except ChannelClosed:
break

async def arecv(self) -> Message:
if self._closed:
raise ChannelClosed()

msg = await self._read_q.get()
if msg is self._sentinel:
raise ChannelClosed()

return msg

async def asend(self, msg: Message) -> None:
if self._closed:
raise ChannelClosed()

self._write_q.put_nowait(msg)

async def aclose(self) -> None:
self._do_close()
await self._exit_fut

def _do_close(self) -> None:
if self._closed:
return

self._closed = True
self._conn.close()


def write_bytes(b: io.BytesIO, buf: bytes) -> None:
b.write(len(buf).to_bytes(4, "big"))
Expand Down
Loading

0 comments on commit b6c486d

Please sign in to comment.