diff --git a/.changeset/smart-files-fly.md b/.changeset/smart-files-fly.md new file mode 100644 index 000000000..91e85ba1e --- /dev/null +++ b/.changeset/smart-files-fly.md @@ -0,0 +1,5 @@ +--- +"livekit-agents": minor +--- + +Update pipeline agent to unlink human input on participant disconnect diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index e6f65e772..7379261b3 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -20,7 +20,9 @@ from .. import metrics, stt, tokenize, tts, utils, vad from ..llm import LLM, ChatContext, ChatMessage, FunctionContext, LLMStream +from ..stt import SpeechEvent from ..types import ATTRIBUTE_AGENT_STATE, AgentState +from ..vad import VADEvent from .agent_output import AgentOutput, SpeechSource, SynthesisHandle from .agent_playout import AgentPlayout from .human_input import HumanInput @@ -394,6 +396,7 @@ def _on_vad_metrics(vad_metrics: vad.VADMetrics) -> None: ) room.on("participant_connected", self._on_participant_connected) + room.on("participant_disconnected", self._on_participant_disconnected) self._room, self._participant = room, participant if participant is not None: @@ -537,6 +540,7 @@ async def aclose(self) -> None: return self._room.off("participant_connected", self._on_participant_connected) + self._room.off("participant_disconnected", self._on_participant_disconnected) await self._deferred_validation.aclose() def _on_participant_connected(self, participant: rtc.RemoteParticipant): @@ -545,98 +549,110 @@ def _on_participant_connected(self, participant: rtc.RemoteParticipant): self._link_participant(participant.identity) - def _link_participant(self, identity: str) -> None: - participant = self._room.remote_participants.get(identity) - if participant is None: - logger.error("_link_participant must be called with a valid identity") + def _on_participant_disconnected(self, participant: rtc.RemoteParticipant): + if self._human_input is None: return - self._human_input = HumanInput( - room=self._room, - vad=self._vad, - stt=self._stt, - participant=participant, - transcription=self._opts.transcription.user_transcription, - ) + self._unlink_participant() - def _on_start_of_speech(ev: vad.VADEvent) -> None: - self._plotter.plot_event("user_started_speaking") - self.emit("user_started_speaking") - self._deferred_validation.on_human_start_of_speech(ev) + def _on_start_of_speech(self, ev: VADEvent) -> None: + self._plotter.plot_event("user_started_speaking") + self.emit("user_started_speaking") + self._deferred_validation.on_human_start_of_speech(ev) - def _on_vad_inference_done(ev: vad.VADEvent) -> None: - if not self._track_published_fut.done(): - return + def _on_vad_inference_done(self, ev: VADEvent) -> None: + if not self._track_published_fut.done(): + return - assert self._agent_output is not None + assert self._agent_output is not None - tv = 1.0 - if self._opts.allow_interruptions: - tv = max(0.0, 1.0 - ev.probability) - self._agent_output.playout.target_volume = tv + tv = 1.0 + if self._opts.allow_interruptions: + tv = max(0.0, 1.0 - ev.probability) + self._agent_output.playout.target_volume = tv - smoothed_tv = self._agent_output.playout.smoothed_volume + smoothed_tv = self._agent_output.playout.smoothed_volume - self._plotter.plot_value("raw_vol", tv) - self._plotter.plot_value("smoothed_vol", smoothed_tv) - self._plotter.plot_value("vad_probability", ev.probability) + self._plotter.plot_value("raw_vol", tv) + self._plotter.plot_value("smoothed_vol", smoothed_tv) + self._plotter.plot_value("vad_probability", ev.probability) - if ev.speech_duration >= self._opts.int_speech_duration: - self._interrupt_if_possible() + if ev.speech_duration >= self._opts.int_speech_duration: + self._interrupt_if_possible() - if ev.raw_accumulated_speech > 0.0: - self._last_speech_time = ( - time.perf_counter() - ev.raw_accumulated_silence - ) + if ev.raw_accumulated_speech > 0.0: + self._last_speech_time = time.perf_counter() - ev.raw_accumulated_silence - def _on_end_of_speech(ev: vad.VADEvent) -> None: - self._plotter.plot_event("user_stopped_speaking") - self.emit("user_stopped_speaking") - self._deferred_validation.on_human_end_of_speech(ev) + def _on_end_of_speech(self, ev: VADEvent) -> None: + self._plotter.plot_event("user_stopped_speaking") + self.emit("user_stopped_speaking") + self._deferred_validation.on_human_end_of_speech(ev) - def _on_interim_transcript(ev: stt.SpeechEvent) -> None: - self._transcribed_interim_text = ev.alternatives[0].text + def _on_interim_transcript(self, ev: SpeechEvent) -> None: + self._transcribed_interim_text = ev.alternatives[0].text - def _on_final_transcript(ev: stt.SpeechEvent) -> None: - new_transcript = ev.alternatives[0].text - if not new_transcript: - return + def _on_final_transcript(self, ev: SpeechEvent) -> None: + new_transcript = ev.alternatives[0].text + if not new_transcript: + return - logger.debug( - "received user transcript", - extra={"user_transcript": new_transcript}, - ) + logger.debug( + "received user transcript", + extra={"user_transcript": new_transcript}, + ) - self._last_final_transcript_time = time.perf_counter() + self._last_final_transcript_time = time.perf_counter() - self._transcribed_text += ( - " " if self._transcribed_text else "" - ) + new_transcript + self._transcribed_text += ( + " " if self._transcribed_text else "" + ) + new_transcript - if self._opts.preemptive_synthesis: - if ( - self._playing_speech is None - or self._playing_speech.allow_interruptions - ): - self._synthesize_agent_reply() + if self._opts.preemptive_synthesis: + if self._playing_speech is None or self._playing_speech.allow_interruptions: + self._synthesize_agent_reply() self._deferred_validation.on_human_final_transcript( new_transcript, ev.alternatives[0].language ) - words = self._opts.transcription.word_tokenizer.tokenize( - text=new_transcript - ) - if len(words) >= 3: - # VAD can sometimes not detect that the human is speaking - # to make the interruption more reliable, we also interrupt on the final transcript. - self._interrupt_if_possible() + words = self._opts.transcription.word_tokenizer.tokenize(text=new_transcript) + if len(words) >= 3: + # VAD can sometimes not detect that the human is speaking + # to make the interruption more reliable, we also interrupt on the final transcript. + self._interrupt_if_possible() + + def _link_participant(self, identity: str) -> None: + participant = self._room.remote_participants.get(identity) + if participant is None: + logger.error("_link_participant must be called with a valid identity") + return + + self._human_input = HumanInput( + room=self._room, + vad=self._vad, + stt=self._stt, + participant=participant, + transcription=self._opts.transcription.user_transcription, + ) + + # Register event handlers using class methods + self._human_input.on("start_of_speech", self._on_start_of_speech) + self._human_input.on("vad_inference_done", self._on_vad_inference_done) + self._human_input.on("end_of_speech", self._on_end_of_speech) + self._human_input.on("interim_transcript", self._on_interim_transcript) + self._human_input.on("final_transcript", self._on_final_transcript) + + def _unlink_participant(self) -> None: + if self._human_input is None: + return - self._human_input.on("start_of_speech", _on_start_of_speech) - self._human_input.on("vad_inference_done", _on_vad_inference_done) - self._human_input.on("end_of_speech", _on_end_of_speech) - self._human_input.on("interim_transcript", _on_interim_transcript) - self._human_input.on("final_transcript", _on_final_transcript) + # Remove all event listeners using class methods + self._human_input.off("start_of_speech", self._on_start_of_speech) + self._human_input.off("vad_inference_done", self._on_vad_inference_done) + self._human_input.off("end_of_speech", self._on_end_of_speech) + self._human_input.off("interim_transcript", self._on_interim_transcript) + self._human_input.off("final_transcript", self._on_final_transcript) + self._human_input = None @utils.log_exceptions(logger=logger) async def _main_task(self) -> None: @@ -1265,14 +1281,14 @@ def on_human_final_transcript(self, transcript: str, language: str | None) -> No if delay is not None: self._run(delay) - def on_human_start_of_speech(self, ev: vad.VADEvent) -> None: + def on_human_start_of_speech(self, ev: VADEvent) -> None: self._speaking = True self._last_recv_start_of_speech_time = time.perf_counter() if self.validating: assert self._validating_task is not None self._validating_task.cancel() - def on_human_end_of_speech(self, ev: vad.VADEvent) -> None: + def on_human_end_of_speech(self, ev: VADEvent) -> None: self._speaking = False self._last_recv_end_of_speech_time = time.perf_counter()