Skip to content

Commit

Permalink
fix tr_fwd closed errors
Browse files Browse the repository at this point in the history
  • Loading branch information
theomonnom committed Jul 23, 2024
1 parent 7692b9d commit 37019b5
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 72 deletions.
6 changes: 1 addition & 5 deletions examples/voice-assistant/minimal_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from livekit.plugins import deepgram, openai, silero



async def entrypoint(ctx: JobContext):
initial_ctx = ChatContext().append(
role="system",
Expand All @@ -32,8 +31,5 @@ async def entrypoint(ctx: JobContext):
await assistant.say("Hey, how can I help you today?", allow_interruptions=True)



if __name__ == "__main__":
cli.run_app(
WorkerOptions(entrypoint_fnc=entrypoint)
)
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
11 changes: 6 additions & 5 deletions livekit-agents/livekit/agents/ipc/proc_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,9 @@ def main(args: proto.ProcStartArgs) -> None:
# (this signal can be sent by watchfiles on dev mode)
loop.run_until_complete(main_task)
finally:
try:
loop.run_until_complete(loop.shutdown_default_executor())
#loop.run_until_complete(cch.aclose())
finally:
loop.close()
# try:
loop.run_until_complete(loop.shutdown_default_executor())
loop.run_until_complete(cch.aclose())
# finally:
# loop.close()
# pass
4 changes: 2 additions & 2 deletions livekit-agents/livekit/agents/ipc/proc_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def __init__(
) -> None:
super().__init__()

#if sys.platform.startswith("linux"):
# if sys.platform.startswith("linux"):
# self._mp_ctx = mp.get_context("forkserver")
#else:
# else:
self._mp_ctx = mp.get_context("spawn")

self._initialize_process_fnc = initialize_process_fnc
Expand Down
4 changes: 4 additions & 0 deletions livekit-agents/livekit/agents/transcription/tts_forwarder.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ def mark_text_segment_end(self) -> None:
self._forming_segments.q.append(new_seg)
self._seg_queue.put_nowait(new_seg)

@property
def closed(self) -> bool:
    """Return this object's ``_closed`` flag.

    This is the same flag ``aclose`` checks in order to return early.
    """
    return self._closed

async def aclose(self) -> None:
if self._closed:
return
Expand Down
70 changes: 26 additions & 44 deletions livekit-agents/livekit/agents/voice_assistant/agent_output.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import time
import contextlib
from typing import Any, AsyncIterable, Union

Expand Down Expand Up @@ -122,48 +123,36 @@ async def _synthesize_task(self, handle: SynthesisHandle) -> None:
[synth, handle._interrupt_fut], return_when=asyncio.FIRST_COMPLETED
)
finally:
with contextlib.suppress(asyncio.CancelledError):
synth.cancel()
await synth

try:
if handle.play_handle is not None:
await handle.play_handle
finally:
await handle._tr_fwd.aclose()
await utils.aio.gracefully_cancel(synth)


@utils.log_exceptions(logger=logger)
async def _str_synthesis_task(text: str, handle: SynthesisHandle) -> None:
"""synthesize speech from a string"""
if handle._tr_fwd is not None:
if handle._tr_fwd and not handle._tr_fwd.closed:
handle._tr_fwd.push_text(text)
handle._tr_fwd.mark_text_segment_end()

# start_time = time.time()
# first_frame = True
# audio_duration = 0.0
start_time = time.time()
first_frame = True
handle._collected_text = text

try:
async for audio in handle._tts.synthesize(text):
# if first_frame:
# first_frame = False
# dt = time.time() - start_time
# self._log_debug(f"tts first frame in {dt:.2f}s")
if first_frame:
first_frame = False
dt = time.time() - start_time
logger.debug(f"AgentOutput._str_synthesis_task: TTFB in {dt:.2f}s")

frame = audio.frame
# audio_duration += frame.samples_per_channel / frame.sample_rate

handle._buf_ch.send_nowait(frame)
if handle._tr_fwd is not None:
if handle._tr_fwd and not handle._tr_fwd.closed:
handle._tr_fwd.push_audio(frame)

finally:
if handle._tr_fwd is not None:
if handle._tr_fwd and not handle._tr_fwd.closed:
handle._tr_fwd.mark_audio_segment_end()
handle._buf_ch.close()
# self._log_debug(f"tts finished synthesising {audio_duration:.2f}s of audio")


@utils.log_exceptions(logger=logger)
Expand All @@ -174,48 +163,41 @@ async def _stream_synthesis_task(

@utils.log_exceptions(logger=logger)
async def _read_generated_audio_task():
# start_time = time.time()
# first_frame = True
# audio_duration = 0.0
start_time = time.time()
first_frame = True
async for audio in tts_stream:
# if first_frame:
# first_frame = False
# dt = time.time() - start_time
# self._log_debug(f"tts first frame in {dt:.2f}s (streamed)")
if first_frame:
first_frame = False
dt = time.time() - start_time
logger.debug(f"AgentOutput._stream_synthesis_task: TTFB in {dt:.2f}s")

# audio_duration += frame.samples_per_channel / frame.sample_rate
if handle._tr_fwd is not None:
if handle._tr_fwd and not handle._tr_fwd.closed:
handle._tr_fwd.push_audio(audio.frame)

handle._buf_ch.send_nowait(audio.frame)

# we're only flushing once, so we know we can break at the end of the first segment

# self._log_debug(
# f"tts finished synthesising {audio_duration:.2f}s audio (streamed)"
# )

# otherwise, stream the text to the TTS
tts_stream = handle._tts.stream()
read_atask = asyncio.create_task(_read_generated_audio_task())

try:
async for seg in streamed_text:
handle._collected_text += seg
if handle._tr_fwd is not None:

if handle._tr_fwd and not handle._tr_fwd.closed:
handle._tr_fwd.push_text(seg)

tts_stream.push_text(seg)

finally:
if handle._tr_fwd is not None:
tts_stream.end_input()

if handle._tr_fwd and not handle._tr_fwd.closed:
handle._tr_fwd.mark_text_segment_end()

tts_stream.end_input()
await read_atask
await tts_stream.aclose()

if handle._tr_fwd is not None:
if handle._tr_fwd and not handle._tr_fwd.closed:
# mark_audio_segment_end must be called *after* mark_text_segment_end
handle._tr_fwd.mark_audio_segment_end()


await handle._tr_fwd.aclose()
21 changes: 11 additions & 10 deletions livekit-agents/livekit/agents/voice_assistant/cancellable_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def interrupt(self) -> None:

self._interrupted = True

def __await__(self):
return self._done_fut.__await__()
def join(self) -> asyncio.Future:
return self._done_fut


class CancellableAudioSource(utils.EventEmitter[EventTypes]):
Expand Down Expand Up @@ -89,6 +89,7 @@ def play(
self._playout_atask = asyncio.create_task(
self._playout_task(self._playout_atask, handle)
)

return handle

@utils.log_exceptions(logger=logger)
Expand All @@ -104,9 +105,9 @@ def _should_break():

try:
if old_task is not None:
with contextlib.suppress(asyncio.CancelledError):
old_task.cancel()
await old_task
await utils.aio.gracefully_cancel(old_task)

logger.debug("CancellableAudioSource._playout_task: started")

async for frame in handle._playout_source:
if first_frame:
Expand All @@ -121,7 +122,7 @@ def _should_break():
break

# divide the frame into chunks of 20ms
ms20 = frame.sample_rate // 100
ms20 = frame.sample_rate // 50
i = 0
while i < len(frame.data):
if _should_break():
Expand All @@ -148,10 +149,10 @@ def _should_break():
handle._time_played += rem / frame.sample_rate
finally:
if not first_frame:
if handle._tr_fwd is not None:
if not cancelled:
handle._tr_fwd.segment_playout_finished()
if handle._tr_fwd is not None and not cancelled:
handle._tr_fwd.segment_playout_finished()

self.emit("playout_stopped", cancelled)

handle._done_fut.set_result(None)
handle._done_fut.set_result(None)
logger.debug("CancellableAudioSource._playout_task: ended")
17 changes: 11 additions & 6 deletions livekit-agents/livekit/agents/voice_assistant/voice_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def _on_vad_updated(ev: vad.VADEvent) -> None:

tv = 1.0
if self._opts.allow_interruptions:
tv = max(0, 1 - ev.probability)
tv = max(0.0, 1.0 - ev.probability)
self._agent_output.audio_source.target_volume = tv

smoothed_tv = self._agent_output.audio_source.smoothed_volume
Expand Down Expand Up @@ -471,6 +471,7 @@ async def _synthesize_answer_task(old_task: asyncio.Task[None]) -> None:
)

async def _play_speech(self, speech_info: _SpeechInfo) -> None:
logger.debug("VoiceAssistant._play_speech started")
MIN_TIME_PLAYED_FOR_COMMIT = 1.5

assert (
Expand All @@ -485,7 +486,7 @@ async def _play_speech(self, speech_info: _SpeechInfo) -> None:
user_speech_commited = False

play_handle = synthesis_handle.play()
play_handle_fut = asyncio.ensure_future(play_handle)
join_fut = play_handle.join()
self._playing_synthesis = synthesis_handle

def _commit_user_message_if_needed() -> None:
Expand All @@ -507,7 +508,7 @@ def _commit_user_message_if_needed() -> None:
# really quickly (barely audible), we don't want to mark this question as "answered".
if not is_using_tools and (
play_handle.time_played < MIN_TIME_PLAYED_FOR_COMMIT
and not play_handle_fut.done()
and not join_fut.done()
):
return

Expand All @@ -519,9 +520,9 @@ def _commit_user_message_if_needed() -> None:
user_speech_commited = True

# wait for the play_handle to finish and check every 1s if the user question should be committed
while not play_handle_fut.done():
while not join_fut.done():
await asyncio.wait(
[play_handle_fut], return_when=asyncio.FIRST_COMPLETED, timeout=1.0
[join_fut], return_when=asyncio.FIRST_COMPLETED, timeout=1.0
)

_commit_user_message_if_needed()
Expand Down Expand Up @@ -579,7 +580,8 @@ def _commit_user_message_if_needed() -> None:
transcript=_llm_stream_to_str_iterable(answer_stream)
)
self._playing_synthesis = answer_synthesis
await answer_synthesis.play()
play_handle = answer_synthesis.play()
await play_handle.join()

collected_text = answer_synthesis.collected_text
interrupted = answer_synthesis.interrupted
Expand All @@ -595,6 +597,8 @@ def _commit_user_message_if_needed() -> None:
else:
self.emit("agent_speech_committed", msg)

logger.debug("VoiceAssistant._play_speech ended")


async def _llm_stream_to_str_iterable(stream: LLMStream) -> AsyncIterable[str]:
async for chunk in stream:
Expand Down Expand Up @@ -677,6 +681,7 @@ async def _run_task(self, delay: float) -> None:
self._last_final_transcript = ""
self._received_end_of_speech = False
self._validate_fnc()
logger.debug("_DeferredAnswerValidation speech validated")

def _run(self, delay: float) -> None:
if self._validating_task is not None:
Expand Down
2 changes: 2 additions & 0 deletions livekit-agents/livekit/agents/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@
def _default_initialize_process_fnc(proc: JobProcess) -> Any:
    """No-op: ignore ``proc`` and return ``None``."""
    return


async def _default_shutdown_fnc(proc: JobContext) -> None:
    """No-op coroutine: ignore the argument and return ``None``."""
    # NOTE(review): the parameter is named ``proc`` but annotated as
    # JobContext — confirm which one is intended.
    return


async def _default_request_fnc(ctx: JobRequest) -> None:
    """Accept the request by awaiting ``ctx.accept()``."""
    await ctx.accept()

Expand Down

0 comments on commit 37019b5

Please sign in to comment.