Skip to content

Commit

Permalink
Merge branch 'dev' of https://github.com/livekit/agents into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
theomonnom committed Jul 23, 2024
2 parents 37019b5 + efa39fd commit c6810bd
Show file tree
Hide file tree
Showing 17 changed files with 63 additions and 38 deletions.
3 changes: 1 addition & 2 deletions examples/voice-assistant/minimal_assistant.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import logging

from livekit.agents import JobContext, JobProcess, JobRequest, WorkerOptions, cli
from livekit.agents import JobContext, WorkerOptions, cli
from livekit.agents.llm import ChatContext
from livekit.agents.voice_assistant import VoiceAssistant
from livekit.plugins import deepgram, openai, silero
Expand Down
2 changes: 1 addition & 1 deletion livekit-agents/livekit/agents/cli/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
self._main_file = main_file
self._loop = loop

self._recv_jobs_fut = asyncio.Future()
self._recv_jobs_fut = asyncio.Future[None]()
self._reloading_jobs = False

async def run(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion livekit-agents/livekit/agents/ipc/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(

self._read_q = asyncio.Queue[Optional[Message]]()
self._write_q = queue.Queue[Optional[Message]]()
self._exit_fut = asyncio.Future()
self._exit_fut = asyncio.Future[None]()

self._read_t = threading.Thread(
target=self._read_thread, daemon=True, name="proc_channel_read"
Expand Down
1 change: 0 additions & 1 deletion livekit-agents/livekit/agents/ipc/proc_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import asyncio
import multiprocessing as mp
import sys
from typing import Any, Callable, Coroutine, Literal

from .. import utils
Expand Down
24 changes: 15 additions & 9 deletions livekit-agents/livekit/agents/ipc/supervised_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def start(self) -> None:
t.start()

def stop(self) -> None:
if self._thread is None:
return
self._q.put_nowait(self._sentinel)
self._thread.join()
self._thread = None
Expand Down Expand Up @@ -67,7 +69,7 @@ def __init__(
loop: asyncio.AbstractEventLoop,
) -> None:
self._loop = loop
log_q = mp.Queue()
log_q = mp.Queue[logging.LogRecord]()
log_q.cancel_join_thread()
mp_pch, mp_cch = mp_ctx.Pipe(duplex=True)

Expand Down Expand Up @@ -97,7 +99,7 @@ def __init__(
self._main_atask: asyncio.Task[None] | None = None
self._closing = False
self._kill_sent = False
self._initialize_fut = asyncio.Future()
self._initialize_fut = asyncio.Future[None]()

@property
def exitcode(self) -> int | None:
Expand Down Expand Up @@ -145,7 +147,7 @@ def _add_proc_ctx_log(record: logging.LogRecord) -> None:

self._proc.start()
self._pid = self._proc.pid
self._join_fut = asyncio.Future()
self._join_fut = asyncio.Future[None]()

def _sync_run():
self._proc.join()
Expand All @@ -161,7 +163,8 @@ async def join(self) -> None:
if not self.started:
raise RuntimeError("process not started")

await asyncio.shield(self._main_atask)
if self._main_atask:
await asyncio.shield(self._main_atask)

async def initialize(self) -> None:
"""initialize the job process, this is calling the user provided initialize_process_fnc
Expand Down Expand Up @@ -198,16 +201,18 @@ async def aclose(self) -> None:
await self._pch.asend(proto.ShutdownRequest())

try:
await asyncio.wait_for(
asyncio.shield(self._main_atask), timeout=self._close_timeout
)
if self._main_atask:
await asyncio.wait_for(
asyncio.shield(self._main_atask), timeout=self._close_timeout
)
except asyncio.TimeoutError:
logger.error(
"process did not exit in time, killing job", extra=self.logging_extra()
)
self._send_kill_signal()

await asyncio.shield(self._main_atask)
if self._main_atask:
await asyncio.shield(self._main_atask)

async def kill(self) -> None:
"""forcefully kill the job process"""
Expand All @@ -216,7 +221,8 @@ async def kill(self) -> None:

self._closing = True
self._send_kill_signal()
await asyncio.shield(self._main_atask)
if self._main_atask:
await asyncio.shield(self._main_atask)

async def launch_job(self, info: RunningJobInfo) -> None:
"""start/assign a job to the process"""
Expand Down
8 changes: 4 additions & 4 deletions livekit-agents/livekit/agents/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def _subscribe_if_needed(pub: rtc.RemoteTrackPublication):
):
pub.set_subscribed(True)

for p in room.participants.values():
for pub in p.tracks.values():
for p in room.remote_participants.values():
for pub in p.track_publications.values():
_subscribe_if_needed(pub)

@room.on("track_published")
Expand All @@ -128,11 +128,11 @@ async def on_track_published(
class JobProcess:
def __init__(self, *, start_arguments: Any | None = None) -> None:
self._mp_proc = mp.current_process()
self._userdata = {}
self._userdata: dict[str, Any] = {}
self._start_arguments = start_arguments

@property
def pid(self) -> int:
def pid(self) -> int | None:
return self._mp_proc.pid

@property
Expand Down
4 changes: 2 additions & 2 deletions livekit-agents/livekit/agents/transcription/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

def find_micro_track_id(room: rtc.Room, identity: str) -> str:
p: rtc.RemoteParticipant | rtc.LocalParticipant | None = (
room.participants_by_identity.get(identity)
room.remote_participants.get(identity)
)
if identity == room.local_participant.identity:
p = room.local_participant
Expand All @@ -17,7 +17,7 @@ def find_micro_track_id(room: rtc.Room, identity: str) -> str:

# find first micro track
track_id = None
for track in p.tracks.values():
for track in p.track_publications.values():
if track.source == rtc.TrackSource.SOURCE_MICROPHONE:
track_id = track.sid
break
Expand Down
8 changes: 7 additions & 1 deletion livekit-agents/livekit/agents/transcription/stt_forwarder.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,19 @@ def update(self, ev: stt.SpeechEvent):
start_time=0,
end_time=0,
final=False,
language="", # TODO
)
)
elif ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT:
text = ev.alternatives[0].text
self._queue.put_nowait(
rtc.TranscriptionSegment(
id=self._current_id, text=text, start_time=0, end_time=0, final=True
id=self._current_id,
text=text,
start_time=0,
end_time=0,
final=True,
language="", # TODO
)
)

Expand Down
14 changes: 12 additions & 2 deletions livekit-agents/livekit/agents/transcription/tts_forwarder.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,15 +298,25 @@ async def _sync_sentence_co(
await self._sleep_if_not_closed(first_delay)
rtc_seg_q.put_nowait(
rtc.TranscriptionSegment(
id=seg_id, text=text, start_time=0, end_time=0, final=False
id=seg_id,
text=text,
start_time=0,
end_time=0,
final=False,
language=self._opts.language,
)
)
await self._sleep_if_not_closed(delay - first_delay)
seg.processed_hyphenes += word_hyphenes

rtc_seg_q.put_nowait(
rtc.TranscriptionSegment(
id=seg_id, text=tokenized_sentence, start_time=0, end_time=0, final=True
id=seg_id,
text=tokenized_sentence,
start_time=0,
end_time=0,
final=True,
language=self._opts.language,
)
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import asyncio
import contextlib
from typing import AsyncIterable, Literal

from livekit import rtc
Expand Down Expand Up @@ -155,4 +154,4 @@ def _should_break():
self.emit("playout_stopped", cancelled)

handle._done_fut.set_result(None)
logger.debug("CancellableAudioSource._playout_task: ended")
logger.debug("CancellableAudioSource._playout_task: ended")
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _subscribe_to_microphone(self, *args, **kwargs) -> None:
Subscribe to the participant microphone if found and not already subscribed.
Do nothing if no track is found.
"""
for publication in self._participant.tracks.values():
for publication in self._participant.track_publications.values():
if publication.source != rtc.TrackSource.SOURCE_MICROPHONE:
continue

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def start(
self._link_participant(participant)
else:
# no participant provided, try to find the first in the room
for participant in self._room.participants.values():
for participant in self._room.remote_participants.values():
self._link_participant(participant.identity)
break

Expand Down Expand Up @@ -290,7 +290,7 @@ def _on_participant_connected(self, participant: rtc.RemoteParticipant):
self._link_participant(participant.identity)

def _link_participant(self, identity: str) -> None:
participant = self._room.participants_by_identity.get(identity)
participant = self._room.remote_participants.get(identity)
if participant is None:
logger.error("_link_participant must be called with a valid identity")
return
Expand Down
4 changes: 2 additions & 2 deletions livekit-agents/livekit/agents/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from .log import DEV_LEVEL, logger
from .version import __version__

MAX_RECONNECT_ATTEMPTS = 3.0
MAX_RECONNECT_ATTEMPTS = 3
ASSIGNMENT_TIMEOUT = 7.5
UPDATE_LOAD_INTERVAL = 10.0

Expand Down Expand Up @@ -242,7 +242,7 @@ async def _queue_msg(self, msg: agent.WorkerMessage) -> None:
"""_queue_msg raises aio.ChanClosed when the worker is closing/closed"""
if self._connecting:
which = msg.WhichOneof("message")
if which == "update_worker" and not msg.update_worker.metadata:
if which == "update_worker":
return
elif which == "ping":
return
Expand Down
6 changes: 3 additions & 3 deletions livekit-agents/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@
python_requires=">=3.9.0",
install_requires=[
"click~=8.1",
"livekit~=0.11",
"livekit-api~=0.4",
"livekit-protocol~=0.4",
"livekit~=0.12.0.dev0",
"livekit-api~=0.6.0",
"livekit-protocol~=0.6.0",
"protobuf>=3",
"pyjwt>=2.0.0",
"types-protobuf>=4,<5",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ async def _main_task(self) -> None:
headers={AUTHORIZATION_HEADER: self._opts.api_key},
json=data,
) as resp:
async for data, _ in resp.content.iter_chunks():
for frame in bstream.write(data):
async for bytes_data, _ in resp.content.iter_chunks():
for frame in bstream.write(bytes_data):
self._event_ch.send_nowait(
tts.SynthesizedAudio(
request_id=request_id, segment_id=segment_id, frame=frame
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ def _sanitize_options(self, *, language: str | None = None) -> STTOptions:

async def recognize(
self,
*,
buffer: utils.AudioBuffer,
*,
language: SpeechLanguages | str | None = None,
) -> stt.SpeechEvent:
config = self._sanitize_options(language=language)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import asyncio
import math

Check failure on line 18 in livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py

View workflow job for this annotation

GitHub Actions / build

Ruff (F401)

livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py:18:8: F401 `math` imported but unused
import time
from dataclasses import dataclass

Expand Down Expand Up @@ -177,8 +178,12 @@ async def _run_inference(self, window_rx: utils.aio.ChanReceiver[_WindowData]):
may_start_at_sample = -1
may_end_at_sample = -1

min_speech_samples = self._opts.min_speech_duration * self._opts.sample_rate
min_silence_samples = self._opts.min_silence_duration * self._opts.sample_rate
min_speech_samples = int(
self._opts.min_speech_duration * self._opts.sample_rate
)
min_silence_samples = int(
self._opts.min_silence_duration * self._opts.sample_rate
)

current_sample = 0

Expand All @@ -204,6 +209,7 @@ async def _run_inference(self, window_rx: utils.aio.ChanReceiver[_WindowData]):
else:
max_data_s += self._opts.max_buffered_speech

assert self._original_sample_rate is not None
cl = int(max_data_s) * self._original_sample_rate
if len(pub_speech_buf) > cl:
pub_speech_buf = pub_speech_buf[-cl:]
Expand Down

0 comments on commit c6810bd

Please sign in to comment.