From ecf4db4506d29f40729908855c15e00c867119ef Mon Sep 17 00:00:00 2001
From: Julian Bright
Date: Fri, 22 Nov 2024 13:25:49 +1100
Subject: [PATCH 1/5] Add participant_disconnect and unlink methods for human input

---
 .../livekit/agents/pipeline/pipeline_agent.py | 164 ++++++++++--------
 1 file changed, 93 insertions(+), 71 deletions(-)

diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py
index 6f647063b..97dada718 100644
--- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py
+++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py
@@ -365,6 +365,7 @@ def _on_vad_metrics(vad_metrics: vad.VADMetrics) -> None:
         )
 
         room.on("participant_connected", self._on_participant_connected)
+        room.on("participant_disconnected", self._on_participant_disconnected)
         self._room, self._participant = room, participant
 
         if participant is not None:
@@ -446,6 +447,7 @@ async def aclose(self) -> None:
             return
 
         self._room.off("participant_connected", self._on_participant_connected)
+        self._room.off("participant_disconnected", self._on_participant_disconnected)
         await self._deferred_validation.aclose()
 
     def _on_participant_connected(self, participant: rtc.RemoteParticipant):
@@ -454,96 +456,116 @@ def _on_participant_connected(self, participant: rtc.RemoteParticipant):
 
         self._link_participant(participant.identity)
 
-    def _link_participant(self, identity: str) -> None:
-        participant = self._room.remote_participants.get(identity)
-        if participant is None:
-            logger.error("_link_participant must be called with a valid identity")
+    def _on_participant_disconnected(self, participant: rtc.RemoteParticipant):
+        if self._human_input is None:
             return
 
-        self._human_input = HumanInput(
-            room=self._room,
-            vad=self._vad,
-            stt=self._stt,
-            participant=participant,
-            transcription=self._opts.transcription.user_transcription,
-        )
+        self._unlink_participant()
 
-        def _on_start_of_speech(ev: vad.VADEvent) -> None:
-            self._plotter.plot_event("user_started_speaking")
-            self.emit("user_started_speaking")
-            self._deferred_validation.on_human_start_of_speech(ev)
+    def _on_start_of_speech(self, ev: vad.VADEvent) -> None:
+        self._plotter.plot_event("user_started_speaking")
+        self.emit("user_started_speaking")
+        self._deferred_validation.on_human_start_of_speech(ev)
 
-        def _on_vad_inference_done(ev: vad.VADEvent) -> None:
-            if not self._track_published_fut.done():
-                return
+    def _on_vad_inference_done(self, ev: vad.VADEvent) -> None:
+        if not self._track_published_fut.done():
+            return
 
-            assert self._agent_output is not None
+        assert self._agent_output is not None
 
-            tv = 1.0
-            if self._opts.allow_interruptions:
-                tv = max(0.0, 1.0 - ev.probability)
-            self._agent_output.playout.target_volume = tv
+        tv = 1.0
+        if self._opts.allow_interruptions:
+            tv = max(0.0, 1.0 - ev.probability)
+        self._agent_output.playout.target_volume = tv
 
-            smoothed_tv = self._agent_output.playout.smoothed_volume
+        smoothed_tv = self._agent_output.playout.smoothed_volume
 
-            self._plotter.plot_value("raw_vol", tv)
-            self._plotter.plot_value("smoothed_vol", smoothed_tv)
-            self._plotter.plot_value("vad_probability", ev.probability)
+        self._plotter.plot_value("raw_vol", tv)
+        self._plotter.plot_value("smoothed_vol", smoothed_tv)
+        self._plotter.plot_value("vad_probability", ev.probability)
 
-            if ev.speech_duration >= self._opts.int_speech_duration:
-                self._interrupt_if_possible()
+        if ev.speech_duration >= self._opts.int_speech_duration:
+            self._interrupt_if_possible()
 
-            if ev.raw_accumulated_speech > 0.0:
-                self._last_speech_time = (
-                    time.perf_counter() - ev.raw_accumulated_silence
-                )
+        if ev.raw_accumulated_speech > 0.0:
+            self._last_speech_time = (
+                time.perf_counter() - ev.raw_accumulated_silence
+            )
 
-        def _on_end_of_speech(ev: vad.VADEvent) -> None:
-            self._plotter.plot_event("user_stopped_speaking")
-            self.emit("user_stopped_speaking")
-            self._deferred_validation.on_human_end_of_speech(ev)
+    def _on_end_of_speech(self, ev: vad.VADEvent) -> None:
+        self._plotter.plot_event("user_stopped_speaking")
+        self.emit("user_stopped_speaking")
+        self._deferred_validation.on_human_end_of_speech(ev)
 
-        def _on_interim_transcript(ev: stt.SpeechEvent) -> None:
-            self._transcribed_interim_text = ev.alternatives[0].text
+    def _on_interim_transcript(self, ev: stt.SpeechEvent) -> None:
+        self._transcribed_interim_text = ev.alternatives[0].text
 
-        def _on_final_transcript(ev: stt.SpeechEvent) -> None:
-            new_transcript = ev.alternatives[0].text
-            if not new_transcript:
-                return
+    def _on_final_transcript(self, ev: stt.SpeechEvent) -> None:
+        new_transcript = ev.alternatives[0].text
+        if not new_transcript:
+            return
 
-            logger.debug(
-                "received user transcript",
-                extra={"user_transcript": new_transcript},
-            )
+        logger.debug(
+            "received user transcript",
+            extra={"user_transcript": new_transcript},
+        )
 
-            self._last_final_transcript_time = time.perf_counter()
+        self._last_final_transcript_time = time.perf_counter()
 
-            self._transcribed_text += (
-                " " if self._transcribed_text else ""
-            ) + new_transcript
+        self._transcribed_text += (
+            " " if self._transcribed_text else ""
+        ) + new_transcript
 
-            if self._opts.preemptive_synthesis:
-                if (
-                    self._playing_speech is None
-                    or self._playing_speech.allow_interruptions
-                ):
-                    self._synthesize_agent_reply()
+        if self._opts.preemptive_synthesis:
+            if (
+                self._playing_speech is None
+                or self._playing_speech.allow_interruptions
+            ):
+                self._synthesize_agent_reply()
 
-            self._deferred_validation.on_human_final_transcript(new_transcript)
+        self._deferred_validation.on_human_final_transcript(new_transcript)
 
-            words = self._opts.transcription.word_tokenizer.tokenize(
-                text=new_transcript
-            )
-            if len(words) >= 3:
-                # VAD can sometimes not detect that the human is speaking
-                # to make the interruption more reliable, we also interrupt on the final transcript.
-                self._interrupt_if_possible()
-
-        self._human_input.on("start_of_speech", _on_start_of_speech)
-        self._human_input.on("vad_inference_done", _on_vad_inference_done)
-        self._human_input.on("end_of_speech", _on_end_of_speech)
-        self._human_input.on("interim_transcript", _on_interim_transcript)
-        self._human_input.on("final_transcript", _on_final_transcript)
+        words = self._opts.transcription.word_tokenizer.tokenize(
+            text=new_transcript
+        )
+        if len(words) >= 3:
+            # VAD can sometimes not detect that the human is speaking
+            # to make the interruption more reliable, we also interrupt on the final transcript.
+            self._interrupt_if_possible()
+
+    def _link_participant(self, identity: str) -> None:
+        participant = self._room.remote_participants.get(identity)
+        if participant is None:
+            logger.error("_link_participant must be called with a valid identity")
+            return
+
+        self._human_input = HumanInput(
+            room=self._room,
+            vad=self._vad,
+            stt=self._stt,
+            participant=participant,
+            transcription=self._opts.transcription.user_transcription,
+        )
+
+        # Register event handlers using class methods
+        self._human_input.on("start_of_speech", self._on_start_of_speech)
+        self._human_input.on("vad_inference_done", self._on_vad_inference_done)
+        self._human_input.on("end_of_speech", self._on_end_of_speech)
+        self._human_input.on("interim_transcript", self._on_interim_transcript)
+        self._human_input.on("final_transcript", self._on_final_transcript)
+
+    def _unlink_participant(self) -> None:
+        """Clean up participant-related resources"""
+        if self._human_input is None:
+            return
+
+        # Remove all event listeners using class methods
+        self._human_input.off("start_of_speech", self._on_start_of_speech)
+        self._human_input.off("vad_inference_done", self._on_vad_inference_done)
+        self._human_input.off("end_of_speech", self._on_end_of_speech)
+        self._human_input.off("interim_transcript", self._on_interim_transcript)
+        self._human_input.off("final_transcript", self._on_final_transcript)
+        self._human_input = None
 
     @utils.log_exceptions(logger=logger)
     async def _main_task(self) -> None:

From 1f8eee181972b02bcdd20f614feb2f7695539cfb Mon Sep 17 00:00:00 2001
From: Julian Bright
Date: Fri, 22 Nov 2024 13:28:36 +1100
Subject: [PATCH 2/5] Adding changeset

---
 .changeset/smart-files-fly.md | 5 +++++
 1 file changed, 5 insertions(+)
 create mode 100644 .changeset/smart-files-fly.md

diff --git a/.changeset/smart-files-fly.md b/.changeset/smart-files-fly.md
new file mode 100644
index 000000000..91e85ba1e
--- /dev/null
+++ b/.changeset/smart-files-fly.md
@@ -0,0 +1,5 @@
+---
+"livekit-agents": minor
+---
+
+Update pipeline agent to unlink human input on participant disconnect

From bdc2bf03e9e1d8ae591d3bf841e8efde4e28c9e0 Mon Sep 17 00:00:00 2001
From: Julian Bright
Date: Fri, 22 Nov 2024 13:29:41 +1100
Subject: [PATCH 3/5] Reformat methods moved to class level

---
 .../livekit/agents/pipeline/pipeline_agent.py | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py
index 97dada718..b0143994c 100644
--- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py
+++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py
@@ -488,9 +488,7 @@ def _on_vad_inference_done(self, ev: vad.VADEvent) -> None:
             self._interrupt_if_possible()
 
         if ev.raw_accumulated_speech > 0.0:
-            self._last_speech_time = (
-                time.perf_counter() - ev.raw_accumulated_silence
-            )
+            self._last_speech_time = time.perf_counter() - ev.raw_accumulated_silence
 
     def _on_end_of_speech(self, ev: vad.VADEvent) -> None:
         self._plotter.plot_event("user_stopped_speaking")
         self.emit("user_stopped_speaking")
         self._deferred_validation.on_human_end_of_speech(ev)
@@ -517,17 +515,12 @@ def _on_final_transcript(self, ev: stt.SpeechEvent) -> None:
         ) + new_transcript
 
         if self._opts.preemptive_synthesis:
-            if (
-                self._playing_speech is None
-                or self._playing_speech.allow_interruptions
-            ):
+            if self._playing_speech is None or self._playing_speech.allow_interruptions:
                 self._synthesize_agent_reply()
 
         self._deferred_validation.on_human_final_transcript(new_transcript)
 
-        words = self._opts.transcription.word_tokenizer.tokenize(
-            text=new_transcript
-        )
+        words = self._opts.transcription.word_tokenizer.tokenize(text=new_transcript)
         if len(words) >= 3:
             # VAD can sometimes not detect that the human is speaking
             # to make the interruption more reliable, we also interrupt on the final transcript.

From db84a85aea53fcb5e4a99a9e71bdeec243ea9d4e Mon Sep 17 00:00:00 2001
From: Julian Bright
Date: Fri, 22 Nov 2024 14:16:44 +1100
Subject: [PATCH 4/5] Making typing explicit for dataclasses

---
 .../livekit/agents/pipeline/pipeline_agent.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py
index b0143994c..db992331b 100644
--- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py
+++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py
@@ -18,6 +18,8 @@
 from livekit import rtc
 
 from .. import metrics, stt, tokenize, tts, utils, vad
+from ..stt import SpeechEvent
+from ..vad import VADEvent
 from ..llm import LLM, ChatContext, ChatMessage, FunctionContext, LLMStream
 from ..types import ATTRIBUTE_AGENT_STATE, AgentState
 from .agent_output import AgentOutput, SpeechSource, SynthesisHandle
@@ -462,12 +464,12 @@ def _on_participant_disconnected(self, participant: rtc.RemoteParticipant):
 
         self._unlink_participant()
 
-    def _on_start_of_speech(self, ev: vad.VADEvent) -> None:
+    def _on_start_of_speech(self, ev: VADEvent) -> None:
         self._plotter.plot_event("user_started_speaking")
         self.emit("user_started_speaking")
         self._deferred_validation.on_human_start_of_speech(ev)
 
-    def _on_vad_inference_done(self, ev: vad.VADEvent) -> None:
+    def _on_vad_inference_done(self, ev: VADEvent) -> None:
         if not self._track_published_fut.done():
             return
@@ -490,15 +492,15 @@ def _on_vad_inference_done(self, ev: vad.VADEvent) -> None:
         if ev.raw_accumulated_speech > 0.0:
             self._last_speech_time = time.perf_counter() - ev.raw_accumulated_silence
 
-    def _on_end_of_speech(self, ev: vad.VADEvent) -> None:
+    def _on_end_of_speech(self, ev: VADEvent) -> None:
         self._plotter.plot_event("user_stopped_speaking")
         self.emit("user_stopped_speaking")
         self._deferred_validation.on_human_end_of_speech(ev)
 
-    def _on_interim_transcript(self, ev: stt.SpeechEvent) -> None:
+    def _on_interim_transcript(self, ev: SpeechEvent) -> None:
         self._transcribed_interim_text = ev.alternatives[0].text
 
-    def _on_final_transcript(self, ev: stt.SpeechEvent) -> None:
+    def _on_final_transcript(self, ev: SpeechEvent) -> None:
         new_transcript = ev.alternatives[0].text
         if not new_transcript:
             return
@@ -1061,13 +1063,13 @@ def on_human_final_transcript(self, transcript: str) -> None:
 
         self._run(delay)
 
-    def on_human_start_of_speech(self, ev: vad.VADEvent) -> None:
+    def on_human_start_of_speech(self, ev: VADEvent) -> None:
         self._speaking = True
         if self.validating:
             assert self._validating_task is not None
             self._validating_task.cancel()
 
-    def on_human_end_of_speech(self, ev: vad.VADEvent) -> None:
+    def on_human_end_of_speech(self, ev: VADEvent) -> None:
         self._speaking = False
         self._last_recv_end_of_speech_time = time.time()

From fcca1676f574b9c77d10304914f823a343e3260f Mon Sep 17 00:00:00 2001
From: Julian Bright
Date: Fri, 22 Nov 2024 14:21:16 +1100
Subject: [PATCH 5/5] Sort import blocks, remove redundant comment

---
 livekit-agents/livekit/agents/pipeline/pipeline_agent.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py
index db992331b..5ec3b4456 100644
--- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py
+++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py
@@ -18,10 +18,10 @@
 from livekit import rtc
 
 from .. import metrics, stt, tokenize, tts, utils, vad
-from ..stt import SpeechEvent
-from ..vad import VADEvent
 from ..llm import LLM, ChatContext, ChatMessage, FunctionContext, LLMStream
+from ..stt import SpeechEvent
 from ..types import ATTRIBUTE_AGENT_STATE, AgentState
+from ..vad import VADEvent
 from .agent_output import AgentOutput, SpeechSource, SynthesisHandle
 from .agent_playout import AgentPlayout
 from .human_input import HumanInput
@@ -550,7 +550,6 @@ def _link_participant(self, identity: str) -> None:
         self._human_input.on("final_transcript", self._on_final_transcript)
 
     def _unlink_participant(self) -> None:
-        """Clean up participant-related resources"""
         if self._human_input is None:
             return