From efa39fd2494e4c2941a9b59c715ca002a96d4132 Mon Sep 17 00:00:00 2001 From: Neil Dwyer Date: Tue, 23 Jul 2024 13:06:19 -0700 Subject: [PATCH] Bump python sdk versions + fix types (#490) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Théo Monnom --- examples/voice-assistant/minimal_assistant.py | 9 ++----- livekit-agents/livekit/agents/cli/watcher.py | 2 +- livekit-agents/livekit/agents/ipc/channel.py | 2 +- .../livekit/agents/ipc/proc_main.py | 2 +- .../livekit/agents/ipc/proc_pool.py | 5 ++-- .../livekit/agents/ipc/supervised_proc.py | 24 ++++++++++++------- livekit-agents/livekit/agents/job.py | 8 +++---- .../livekit/agents/transcription/_utils.py | 4 ++-- .../agents/transcription/stt_forwarder.py | 8 ++++++- .../agents/transcription/tts_forwarder.py | 14 +++++++++-- .../agents/voice_assistant/agent_output.py | 3 ++- .../voice_assistant/cancellable_source.py | 2 +- .../agents/voice_assistant/human_input.py | 2 +- .../agents/voice_assistant/voice_assistant.py | 4 ++-- livekit-agents/livekit/agents/worker.py | 6 +++-- livekit-agents/setup.py | 6 ++--- .../livekit/plugins/elevenlabs/tts.py | 4 ++-- .../livekit/plugins/google/stt.py | 2 +- .../livekit/plugins/silero/vad.py | 10 ++++++-- 19 files changed, 71 insertions(+), 46 deletions(-) diff --git a/examples/voice-assistant/minimal_assistant.py b/examples/voice-assistant/minimal_assistant.py index 338830856..d8158d4d0 100644 --- a/examples/voice-assistant/minimal_assistant.py +++ b/examples/voice-assistant/minimal_assistant.py @@ -1,13 +1,11 @@ import asyncio -import logging -from livekit.agents import JobContext, JobProcess, JobRequest, WorkerOptions, cli +from livekit.agents import JobContext, WorkerOptions, cli from livekit.agents.llm import ChatContext from livekit.agents.voice_assistant import VoiceAssistant from livekit.plugins import deepgram, openai, silero - async def entrypoint(ctx: JobContext): initial_ctx = ChatContext().append( role="system", @@ -32,8 +30,5 @@ async def entrypoint(ctx: JobContext): await assistant.say("Hey, how can I help you today?", allow_interruptions=True) - if __name__ == "__main__": - cli.run_app( - WorkerOptions(entrypoint_fnc=entrypoint) - ) + cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint)) diff --git a/livekit-agents/livekit/agents/cli/watcher.py b/livekit-agents/livekit/agents/cli/watcher.py index 073c67c37..b80bfc294 100644 --- a/livekit-agents/livekit/agents/cli/watcher.py +++ b/livekit-agents/livekit/agents/cli/watcher.py @@ -75,7 +75,7 @@ def __init__( self._main_file = main_file self._loop = loop - self._recv_jobs_fut = asyncio.Future() + self._recv_jobs_fut = asyncio.Future[None]() self._reloading_jobs = False async def run(self) -> None: diff --git a/livekit-agents/livekit/agents/ipc/channel.py b/livekit-agents/livekit/agents/ipc/channel.py index d44830925..dc577b483 100644 --- a/livekit-agents/livekit/agents/ipc/channel.py +++ b/livekit-agents/livekit/agents/ipc/channel.py @@ -105,7 +105,7 @@ def __init__( self._read_q = asyncio.Queue[Optional[Message]]() self._write_q = queue.Queue[Optional[Message]]() - self._exit_fut = asyncio.Future() + self._exit_fut = asyncio.Future[None]() self._read_t = threading.Thread( target=self._read_thread, daemon=True, name="proc_channel_read" diff --git a/livekit-agents/livekit/agents/ipc/proc_main.py b/livekit-agents/livekit/agents/ipc/proc_main.py index 1eabb0fba..80aed18cc 100644 --- a/livekit-agents/livekit/agents/ipc/proc_main.py +++ b/livekit-agents/livekit/agents/ipc/proc_main.py @@ -212,6 +212,6 @@ def main(args: proto.ProcStartArgs) -> None: finally: try: loop.run_until_complete(loop.shutdown_default_executor()) - #loop.run_until_complete(cch.aclose()) + # loop.run_until_complete(cch.aclose()) finally: loop.close() diff --git a/livekit-agents/livekit/agents/ipc/proc_pool.py b/livekit-agents/livekit/agents/ipc/proc_pool.py index 1d835924f..0ab6b61a8 100644 --- a/livekit-agents/livekit/agents/ipc/proc_pool.py +++ b/livekit-agents/livekit/agents/ipc/proc_pool.py @@ -2,7 +2,6 @@ import asyncio import multiprocessing as mp -import sys from typing import Any, Callable, Coroutine, Literal from .. import utils @@ -30,9 +29,9 @@ def __init__( ) -> None: super().__init__() - #if sys.platform.startswith("linux"): + # if sys.platform.startswith("linux"): # self._mp_ctx = mp.get_context("forkserver") - #else: + # else: self._mp_ctx = mp.get_context("spawn") self._initialize_process_fnc = initialize_process_fnc diff --git a/livekit-agents/livekit/agents/ipc/supervised_proc.py b/livekit-agents/livekit/agents/ipc/supervised_proc.py index 993fed3e6..6c3da8555 100644 --- a/livekit-agents/livekit/agents/ipc/supervised_proc.py +++ b/livekit-agents/livekit/agents/ipc/supervised_proc.py @@ -32,6 +32,8 @@ def start(self) -> None: t.start() def stop(self) -> None: + if self._thread is None: + return self._q.put_nowait(self._sentinel) self._thread.join() self._thread = None @@ -67,7 +69,7 @@ def __init__( loop: asyncio.AbstractEventLoop, ) -> None: self._loop = loop - log_q = mp.Queue() + log_q = mp.Queue[logging.LogRecord]() log_q.cancel_join_thread() mp_pch, mp_cch = mp_ctx.Pipe(duplex=True) @@ -97,7 +99,7 @@ def __init__( self._main_atask: asyncio.Task[None] | None = None self._closing = False self._kill_sent = False - self._initialize_fut = asyncio.Future() + self._initialize_fut = asyncio.Future[None]() @property def exitcode(self) -> int | None: @@ -145,7 +147,7 @@ def _add_proc_ctx_log(record: logging.LogRecord) -> None: self._proc.start() self._pid = self._proc.pid - self._join_fut = asyncio.Future() + self._join_fut = asyncio.Future[None]() def _sync_run(): self._proc.join() @@ -161,7 +163,8 @@ async def join(self) -> None: if not self.started: raise RuntimeError("process not started") - await asyncio.shield(self._main_atask) + if self._main_atask: + await asyncio.shield(self._main_atask) async def initialize(self) -> None: """initialize the job process, this is calling the user provided initialize_process_fnc @@ -198,16 +201,18 @@ async def aclose(self) -> None: await self._pch.asend(proto.ShutdownRequest()) try: - await asyncio.wait_for( - asyncio.shield(self._main_atask), timeout=self._close_timeout - ) + if self._main_atask: + await asyncio.wait_for( + asyncio.shield(self._main_atask), timeout=self._close_timeout + ) except asyncio.TimeoutError: logger.error( "process did not exit in time, killing job", extra=self.logging_extra() ) self._send_kill_signal() - await asyncio.shield(self._main_atask) + if self._main_atask: + await asyncio.shield(self._main_atask) async def kill(self) -> None: """forcefully kill the job process""" @@ -216,7 +221,8 @@ async def kill(self) -> None: self._closing = True self._send_kill_signal() - await asyncio.shield(self._main_atask) + if self._main_atask: + await asyncio.shield(self._main_atask) async def launch_job(self, info: RunningJobInfo) -> None: """start/assign a job to the process""" diff --git a/livekit-agents/livekit/agents/job.py b/livekit-agents/livekit/agents/job.py index e7bf5a863..7b285805b 100644 --- a/livekit-agents/livekit/agents/job.py +++ b/livekit-agents/livekit/agents/job.py @@ -114,8 +114,8 @@ def _subscribe_if_needed(pub: rtc.RemoteTrackPublication): ): pub.set_subscribed(True) - for p in room.participants.values(): - for pub in p.tracks.values(): + for p in room.remote_participants.values(): + for pub in p.track_publications.values(): _subscribe_if_needed(pub) @room.on("track_published") @@ -128,11 +128,11 @@ async def on_track_published( class JobProcess: def __init__(self, *, start_arguments: Any | None = None) -> None: self._mp_proc = mp.current_process() - self._userdata = {} + self._userdata: dict[str, Any] = {} self._start_arguments = start_arguments @property - def pid(self) -> int: + def pid(self) -> int | None: return self._mp_proc.pid @property diff --git a/livekit-agents/livekit/agents/transcription/_utils.py b/livekit-agents/livekit/agents/transcription/_utils.py index 86db31a84..dc839f2e6 100644 --- a/livekit-agents/livekit/agents/transcription/_utils.py +++ b/livekit-agents/livekit/agents/transcription/_utils.py @@ -7,7 +7,7 @@ def find_micro_track_id(room: rtc.Room, identity: str) -> str: p: rtc.RemoteParticipant | rtc.LocalParticipant | None = ( - room.participants_by_identity.get(identity) + room.remote_participants.get(identity) ) if identity == room.local_participant.identity: p = room.local_participant @@ -17,7 +17,7 @@ def find_micro_track_id(room: rtc.Room, identity: str) -> str: # find first micro track track_id = None - for track in p.tracks.values(): + for track in p.track_publications.values(): if track.source == rtc.TrackSource.SOURCE_MICROPHONE: track_id = track.sid break diff --git a/livekit-agents/livekit/agents/transcription/stt_forwarder.py b/livekit-agents/livekit/agents/transcription/stt_forwarder.py index 1c4227eb6..2d8b566c2 100644 --- a/livekit-agents/livekit/agents/transcription/stt_forwarder.py +++ b/livekit-agents/livekit/agents/transcription/stt_forwarder.py @@ -66,13 +66,19 @@ def update(self, ev: stt.SpeechEvent): start_time=0, end_time=0, final=False, + language="", # TODO ) ) elif ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT: text = ev.alternatives[0].text self._queue.put_nowait( rtc.TranscriptionSegment( - id=self._current_id, text=text, start_time=0, end_time=0, final=True + id=self._current_id, + text=text, + start_time=0, + end_time=0, + final=True, + language="", # TODO ) ) diff --git a/livekit-agents/livekit/agents/transcription/tts_forwarder.py b/livekit-agents/livekit/agents/transcription/tts_forwarder.py index 76be4622c..9ea5956de 100644 --- a/livekit-agents/livekit/agents/transcription/tts_forwarder.py +++ b/livekit-agents/livekit/agents/transcription/tts_forwarder.py @@ -294,7 +294,12 @@ async def _sync_sentence_co( await self._sleep_if_not_closed(first_delay) rtc_seg_q.put_nowait( rtc.TranscriptionSegment( - id=seg_id, text=text, start_time=0, end_time=0, final=False + id=seg_id, + text=text, + start_time=0, + end_time=0, + final=False, + language=self._opts.language, ) ) await self._sleep_if_not_closed(delay - first_delay) @@ -302,7 +307,12 @@ async def _sync_sentence_co( rtc_seg_q.put_nowait( rtc.TranscriptionSegment( - id=seg_id, text=tokenized_sentence, start_time=0, end_time=0, final=True + id=seg_id, + text=tokenized_sentence, + start_time=0, + end_time=0, + final=True, + language=self._opts.language, ) ) diff --git a/livekit-agents/livekit/agents/voice_assistant/agent_output.py b/livekit-agents/livekit/agents/voice_assistant/agent_output.py index 4bc005180..ba7c65773 100644 --- a/livekit-agents/livekit/agents/voice_assistant/agent_output.py +++ b/livekit-agents/livekit/agents/voice_assistant/agent_output.py @@ -130,7 +130,8 @@ async def _synthesize_task(self, handle: SynthesisHandle) -> None: if handle.play_handle is not None: await handle.play_handle finally: - await handle._tr_fwd.aclose() + if handle._tr_fwd: + await handle._tr_fwd.aclose() @utils.log_exceptions(logger=logger) diff --git a/livekit-agents/livekit/agents/voice_assistant/cancellable_source.py b/livekit-agents/livekit/agents/voice_assistant/cancellable_source.py index ad6bb3989..b9f2c3fd4 100644 --- a/livekit-agents/livekit/agents/voice_assistant/cancellable_source.py +++ b/livekit-agents/livekit/agents/voice_assistant/cancellable_source.py @@ -154,4 +154,4 @@ def _should_break(): self.emit("playout_stopped", cancelled) - handle._done_fut.set_result(None) \ No newline at end of file + handle._done_fut.set_result(None) diff --git a/livekit-agents/livekit/agents/voice_assistant/human_input.py b/livekit-agents/livekit/agents/voice_assistant/human_input.py index 5ffc1286b..3e36b50d4 100644 --- a/livekit-agents/livekit/agents/voice_assistant/human_input.py +++ b/livekit-agents/livekit/agents/voice_assistant/human_input.py @@ -75,7 +75,7 @@ def _subscribe_to_microphone(self, *args, **kwargs) -> None: Subscribe to the participant microphone if found and not already subscribed. Do nothing if no track is found. """ - for publication in self._participant.tracks.values(): + for publication in self._participant.track_publications.values(): if publication.source != rtc.TrackSource.SOURCE_MICROPHONE: continue diff --git a/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py b/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py index 085745938..eb07a53f3 100644 --- a/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py +++ b/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py @@ -216,7 +216,7 @@ def start( self._link_participant(participant) else: # no participant provided, try to find the first in the room - for participant in self._room.participants.values(): + for participant in self._room.remote_participants.values(): self._link_participant(participant.identity) break @@ -290,7 +290,7 @@ def _on_participant_connected(self, participant: rtc.RemoteParticipant): self._link_participant(participant.identity) def _link_participant(self, identity: str) -> None: - participant = self._room.participants_by_identity.get(identity) + participant = self._room.remote_participants.get(identity) if participant is None: logger.error("_link_participant must be called with a valid identity") return diff --git a/livekit-agents/livekit/agents/worker.py b/livekit-agents/livekit/agents/worker.py index a9f17435a..96455e9c9 100644 --- a/livekit-agents/livekit/agents/worker.py +++ b/livekit-agents/livekit/agents/worker.py @@ -35,7 +35,7 @@ from .log import DEV_LEVEL, logger from .version import __version__ -MAX_RECONNECT_ATTEMPTS = 3.0 +MAX_RECONNECT_ATTEMPTS = 3 ASSIGNMENT_TIMEOUT = 7.5 UPDATE_LOAD_INTERVAL = 10.0 @@ -43,9 +43,11 @@ def _default_initialize_process_fnc(proc: JobProcess) -> Any: return + async def _default_shutdown_fnc(proc: JobContext) -> None: return + async def _default_request_fnc(ctx: JobRequest) -> None: await ctx.accept() @@ -240,7 +242,7 @@ async def _queue_msg(self, msg: agent.WorkerMessage) -> None: """_queue_msg raises aio.ChanClosed when the worker is closing/closed""" if self._connecting: which = msg.WhichOneof("message") - if which == "update_worker" and not msg.update_worker.metadata: + if which == "update_worker": return elif which == "ping": return diff --git a/livekit-agents/setup.py b/livekit-agents/setup.py index d3146d431..6f608a409 100644 --- a/livekit-agents/setup.py +++ b/livekit-agents/setup.py @@ -48,9 +48,9 @@ python_requires=">=3.9.0", install_requires=[ "click~=8.1", - "livekit~=0.11", - "livekit-api~=0.4", - "livekit-protocol~=0.4", + "livekit~=0.12.0.dev0", + "livekit-api~=0.6.0", + "livekit-protocol~=0.6.0", "protobuf>=3", "pyjwt>=2.0.0", "types-protobuf>=4,<5", diff --git a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/tts.py b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/tts.py index b1cf7ed9d..9aaa658b4 100644 --- a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/tts.py +++ b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/tts.py @@ -183,8 +183,8 @@ async def _main_task(self) -> None: headers={AUTHORIZATION_HEADER: self._opts.api_key}, json=data, ) as resp: - async for data, _ in resp.content.iter_chunks(): - for frame in bstream.write(data): + async for bytes_data, _ in resp.content.iter_chunks(): + for frame in bstream.write(bytes_data): self._event_ch.send_nowait( tts.SynthesizedAudio( request_id=request_id, segment_id=segment_id, frame=frame diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py index 47e047e32..0efc3664b 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py @@ -131,8 +131,8 @@ def _sanitize_options(self, *, language: str | None = None) -> STTOptions: async def recognize( self, - *, buffer: utils.AudioBuffer, + *, language: SpeechLanguages | str | None = None, ) -> stt.SpeechEvent: config = self._sanitize_options(language=language) diff --git a/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py b/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py index b4cdf3daa..100ea6fc9 100644 --- a/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py +++ b/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py @@ -15,6 +15,7 @@ from __future__ import annotations import asyncio +import math import time from dataclasses import dataclass @@ -177,8 +178,12 @@ async def _run_inference(self, window_rx: utils.aio.ChanReceiver[_WindowData]): may_start_at_sample = -1 may_end_at_sample = -1 - min_speech_samples = self._opts.min_speech_duration * self._opts.sample_rate - min_silence_samples = self._opts.min_silence_duration * self._opts.sample_rate + min_speech_samples = int( + self._opts.min_speech_duration * self._opts.sample_rate + ) + min_silence_samples = int( + self._opts.min_silence_duration * self._opts.sample_rate + ) current_sample = 0 @@ -204,6 +209,7 @@ async def _run_inference(self, window_rx: utils.aio.ChanReceiver[_WindowData]): else: max_data_s += self._opts.max_buffered_speech + assert self._original_sample_rate is not None cl = int(max_data_s) * self._original_sample_rate if len(pub_speech_buf) > cl: pub_speech_buf = pub_speech_buf[-cl:]