Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Add participant_disconnect and unlink methods, to reconnect human input #1125

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/smart-files-fly.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"livekit-agents": minor
---

Update pipeline agent to unlink human input on participant disconnect
160 changes: 88 additions & 72 deletions livekit-agents/livekit/agents/pipeline/pipeline_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

from .. import metrics, stt, tokenize, tts, utils, vad
from ..llm import LLM, ChatContext, ChatMessage, FunctionContext, LLMStream
from ..stt import SpeechEvent
from ..types import ATTRIBUTE_AGENT_STATE, AgentState
from ..vad import VADEvent
from .agent_output import AgentOutput, SpeechSource, SynthesisHandle
from .agent_playout import AgentPlayout
from .human_input import HumanInput
Expand Down Expand Up @@ -365,6 +367,7 @@ def _on_vad_metrics(vad_metrics: vad.VADMetrics) -> None:
)

room.on("participant_connected", self._on_participant_connected)
room.on("participant_disconnected", self._on_participant_disconnected)
self._room, self._participant = room, participant

if participant is not None:
Expand Down Expand Up @@ -446,6 +449,7 @@ async def aclose(self) -> None:
return

self._room.off("participant_connected", self._on_participant_connected)
self._room.off("participant_disconnected", self._on_participant_disconnected)
await self._deferred_validation.aclose()

def _on_participant_connected(self, participant: rtc.RemoteParticipant):
Expand All @@ -454,96 +458,108 @@ def _on_participant_connected(self, participant: rtc.RemoteParticipant):

self._link_participant(participant.identity)

def _link_participant(self, identity: str) -> None:
participant = self._room.remote_participants.get(identity)
if participant is None:
logger.error("_link_participant must be called with a valid identity")
def _on_participant_disconnected(self, participant: rtc.RemoteParticipant):
if self._human_input is None:
return

self._human_input = HumanInput(
room=self._room,
vad=self._vad,
stt=self._stt,
participant=participant,
transcription=self._opts.transcription.user_transcription,
)
self._unlink_participant()

def _on_start_of_speech(ev: vad.VADEvent) -> None:
self._plotter.plot_event("user_started_speaking")
self.emit("user_started_speaking")
self._deferred_validation.on_human_start_of_speech(ev)
def _on_start_of_speech(self, ev: VADEvent) -> None:
self._plotter.plot_event("user_started_speaking")
self.emit("user_started_speaking")
self._deferred_validation.on_human_start_of_speech(ev)

def _on_vad_inference_done(ev: vad.VADEvent) -> None:
if not self._track_published_fut.done():
return
def _on_vad_inference_done(self, ev: VADEvent) -> None:
if not self._track_published_fut.done():
return

assert self._agent_output is not None
assert self._agent_output is not None

tv = 1.0
if self._opts.allow_interruptions:
tv = max(0.0, 1.0 - ev.probability)
self._agent_output.playout.target_volume = tv
tv = 1.0
if self._opts.allow_interruptions:
tv = max(0.0, 1.0 - ev.probability)
self._agent_output.playout.target_volume = tv

smoothed_tv = self._agent_output.playout.smoothed_volume
smoothed_tv = self._agent_output.playout.smoothed_volume

self._plotter.plot_value("raw_vol", tv)
self._plotter.plot_value("smoothed_vol", smoothed_tv)
self._plotter.plot_value("vad_probability", ev.probability)
self._plotter.plot_value("raw_vol", tv)
self._plotter.plot_value("smoothed_vol", smoothed_tv)
self._plotter.plot_value("vad_probability", ev.probability)

if ev.speech_duration >= self._opts.int_speech_duration:
self._interrupt_if_possible()
if ev.speech_duration >= self._opts.int_speech_duration:
self._interrupt_if_possible()

if ev.raw_accumulated_speech > 0.0:
self._last_speech_time = (
time.perf_counter() - ev.raw_accumulated_silence
)
if ev.raw_accumulated_speech > 0.0:
self._last_speech_time = time.perf_counter() - ev.raw_accumulated_silence

def _on_end_of_speech(ev: vad.VADEvent) -> None:
self._plotter.plot_event("user_stopped_speaking")
self.emit("user_stopped_speaking")
self._deferred_validation.on_human_end_of_speech(ev)
def _on_end_of_speech(self, ev: VADEvent) -> None:
self._plotter.plot_event("user_stopped_speaking")
self.emit("user_stopped_speaking")
self._deferred_validation.on_human_end_of_speech(ev)

def _on_interim_transcript(ev: stt.SpeechEvent) -> None:
self._transcribed_interim_text = ev.alternatives[0].text
def _on_interim_transcript(self, ev: SpeechEvent) -> None:
self._transcribed_interim_text = ev.alternatives[0].text

def _on_final_transcript(ev: stt.SpeechEvent) -> None:
new_transcript = ev.alternatives[0].text
if not new_transcript:
return
def _on_final_transcript(self, ev: SpeechEvent) -> None:
new_transcript = ev.alternatives[0].text
if not new_transcript:
return

logger.debug(
"received user transcript",
extra={"user_transcript": new_transcript},
)
logger.debug(
"received user transcript",
extra={"user_transcript": new_transcript},
)

self._last_final_transcript_time = time.perf_counter()
self._last_final_transcript_time = time.perf_counter()

self._transcribed_text += (
" " if self._transcribed_text else ""
) + new_transcript
self._transcribed_text += (
" " if self._transcribed_text else ""
) + new_transcript

if self._opts.preemptive_synthesis:
if (
self._playing_speech is None
or self._playing_speech.allow_interruptions
):
self._synthesize_agent_reply()
if self._opts.preemptive_synthesis:
if self._playing_speech is None or self._playing_speech.allow_interruptions:
self._synthesize_agent_reply()

self._deferred_validation.on_human_final_transcript(new_transcript)
self._deferred_validation.on_human_final_transcript(new_transcript)

words = self._opts.transcription.word_tokenizer.tokenize(
text=new_transcript
)
if len(words) >= 3:
# VAD can sometimes not detect that the human is speaking
# to make the interruption more reliable, we also interrupt on the final transcript.
self._interrupt_if_possible()
words = self._opts.transcription.word_tokenizer.tokenize(text=new_transcript)
if len(words) >= 3:
# VAD can sometimes not detect that the human is speaking
# to make the interruption more reliable, we also interrupt on the final transcript.
self._interrupt_if_possible()

def _link_participant(self, identity: str) -> None:
participant = self._room.remote_participants.get(identity)
if participant is None:
logger.error("_link_participant must be called with a valid identity")
return

self._human_input = HumanInput(
room=self._room,
vad=self._vad,
stt=self._stt,
participant=participant,
transcription=self._opts.transcription.user_transcription,
)

# Register event handlers using class methods
self._human_input.on("start_of_speech", self._on_start_of_speech)
self._human_input.on("vad_inference_done", self._on_vad_inference_done)
self._human_input.on("end_of_speech", self._on_end_of_speech)
self._human_input.on("interim_transcript", self._on_interim_transcript)
self._human_input.on("final_transcript", self._on_final_transcript)

def _unlink_participant(self) -> None:
if self._human_input is None:
return

self._human_input.on("start_of_speech", _on_start_of_speech)
self._human_input.on("vad_inference_done", _on_vad_inference_done)
self._human_input.on("end_of_speech", _on_end_of_speech)
self._human_input.on("interim_transcript", _on_interim_transcript)
self._human_input.on("final_transcript", _on_final_transcript)
# Remove all event listeners using class methods
self._human_input.off("start_of_speech", self._on_start_of_speech)
self._human_input.off("vad_inference_done", self._on_vad_inference_done)
self._human_input.off("end_of_speech", self._on_end_of_speech)
self._human_input.off("interim_transcript", self._on_interim_transcript)
self._human_input.off("final_transcript", self._on_final_transcript)
self._human_input = None

@utils.log_exceptions(logger=logger)
async def _main_task(self) -> None:
Expand Down Expand Up @@ -1046,13 +1062,13 @@ def on_human_final_transcript(self, transcript: str) -> None:

self._run(delay)

def on_human_start_of_speech(self, ev: vad.VADEvent) -> None:
def on_human_start_of_speech(self, ev: VADEvent) -> None:
self._speaking = True
if self.validating:
assert self._validating_task is not None
self._validating_task.cancel()

def on_human_end_of_speech(self, ev: vad.VADEvent) -> None:
def on_human_end_of_speech(self, ev: VADEvent) -> None:
self._speaking = False
self._last_recv_end_of_speech_time = time.time()

Expand Down
Loading