stt: reduce bandwidth usage by reducing sample_rate to 16khz #920

Merged: 6 commits, Oct 15, 2024
.changeset/friendly-cycles-double.md (6 additions, 0 deletions)
@@ -0,0 +1,6 @@
---
"livekit-agents": patch
"livekit-plugins-deepgram": patch
---

stt: reduce bandwidth usage by reducing sample_rate to 16khz
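For context on the bandwidth claim (not part of the changeset): for raw 16-bit mono PCM, dropping the sample rate from 48 kHz to 16 kHz cuts the audio streamed to the STT provider to a third. A rough back-of-the-envelope check:

```python
# Raw PCM bandwidth, 16-bit mono, ignoring codec/transport overhead (illustrative only).
BYTES_PER_SAMPLE = 2  # 16-bit samples

for sample_rate in (48_000, 16_000):
    kbps = sample_rate * BYTES_PER_SAMPLE * 8 / 1000
    print(f"{sample_rate} Hz -> {kbps:.0f} kbit/s")  # 768 kbit/s vs 256 kbit/s
```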
livekit-agents/livekit/agents/pipeline/human_input.py (3 additions, 1 deletion)
@@ -89,7 +89,9 @@ def _subscribe_to_microphone(self, *args, **kwargs) -> None:
                 self._recognize_atask.cancel()

             self._recognize_atask = asyncio.create_task(
-                self._recognize_task(rtc.AudioStream(self._subscribed_track))  # type: ignore
+                self._recognize_task(
+                    rtc.AudioStream(self._subscribed_track, sample_rate=16000)
+                )  # type: ignore
             )
             break

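For context (not part of the diff): passing `sample_rate=16000` to `rtc.AudioStream` asks the LiveKit SDK to deliver frames already resampled to 16 kHz, so everything downstream of `HumanInput` works with 16 kHz audio. A minimal consumer sketch, assuming the SDK yields frame events when iterated (the `read_mic` helper and track wiring are hypothetical):

```python
from livekit import rtc


async def read_mic(track: rtc.Track) -> None:
    # sample_rate=16000 mirrors the change above: frames are expected to
    # arrive from the stream already resampled to 16 kHz.
    stream = rtc.AudioStream(track, sample_rate=16000)
    async for event in stream:
        frame = event.frame
        assert frame.sample_rate == 16000  # downstream STT receives 16 kHz input
```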
livekit-agents/livekit/agents/stt/stt.py (38 additions, 3 deletions)
@@ -76,25 +76,60 @@ class SpeechStream(ABC):
     class _FlushSentinel:
         pass

-    def __init__(self):
+    def __init__(self, *, sample_rate: int | None = None):
+        """
+        Args:
+            sample_rate : int or None, optional
+                The desired sample rate for the audio input.
+                If specified, the audio input will be automatically resampled to match
+                the given sample rate before being processed for Speech-to-Text.
+                If not provided (None), the input will retain its original sample rate.
+        """
         self._input_ch = aio.Chan[Union[rtc.AudioFrame, SpeechStream._FlushSentinel]]()
         self._event_ch = aio.Chan[SpeechEvent]()
         self._task = asyncio.create_task(self._main_task())
         self._task.add_done_callback(lambda _: self._event_ch.close())

+        self._needed_sr = sample_rate
+        self._pushed_sr = 0
+        self._resampler: rtc.AudioResampler | None = None
+
     @abstractmethod
-    def _main_task(self) -> None: ...
+    async def _main_task(self) -> None: ...

     def push_frame(self, frame: rtc.AudioFrame) -> None:
         """Push audio to be recognized"""
         self._check_input_not_ended()
         self._check_not_closed()
-        self._input_ch.send_nowait(frame)
+
+        if self._pushed_sr and self._pushed_sr != frame.sample_rate:
+            raise ValueError("the sample rate of the input frames must be consistent")
+
+        self._pushed_sr = frame.sample_rate
+
+        if self._needed_sr and self._needed_sr != frame.sample_rate:
+            if not self._resampler:
+                self._resampler = rtc.AudioResampler(
+                    frame.sample_rate,
+                    self._needed_sr,
+                    quality=rtc.AudioResamplerQuality.HIGH,
+                )
+
+        if self._resampler:
+            for frame in self._resampler.push(frame):
+                self._input_ch.send_nowait(frame)
+        else:
+            self._input_ch.send_nowait(frame)

     def flush(self) -> None:
         """Mark the end of the current segment"""
         self._check_input_not_ended()
         self._check_not_closed()
+
+        if self._resampler:
+            for frame in self._resampler.flush():
+                self._input_ch.send_nowait(frame)
+
         self._input_ch.send_nowait(self._FlushSentinel())

     def end_input(self) -> None:
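Taken together, a plugin's SpeechStream can now simply declare the rate it needs and let the base class resample. A minimal sketch: the `MySpeechStream` subclass and the use of `rtc.AudioFrame.create` for a test frame are assumptions for illustration; only the `sample_rate` keyword and the transparent resampling in `push_frame`/`flush` come from this diff.

```python
import asyncio

from livekit import rtc
from livekit.agents import stt


class MySpeechStream(stt.SpeechStream):
    """Hypothetical plugin stream used only to illustrate the new base class."""

    def __init__(self) -> None:
        # Ask the base class to resample all pushed audio to 16 kHz.
        super().__init__(sample_rate=16000)

    async def _main_task(self) -> None:
        # A real implementation would read 16 kHz frames from self._input_ch
        # and emit SpeechEvents on self._event_ch.
        ...


async def main() -> None:
    stream = MySpeechStream()  # must be created inside a running event loop
    frame = rtc.AudioFrame.create(
        sample_rate=48000, num_channels=1, samples_per_channel=480
    )
    stream.push_frame(frame)  # resampled 48 kHz -> 16 kHz before reaching _main_task
    stream.flush()            # drains the resampler, then marks the segment end


asyncio.run(main())
```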
@@ -17,9 +17,10 @@
 from dataclasses import dataclass
 from typing import Literal

-import azure.cognitiveservices.speech as speechsdk  # type: ignore
 from livekit.agents import tts, utils

+import azure.cognitiveservices.speech as speechsdk  # type: ignore
+
 AZURE_SAMPLE_RATE: int = 16000
 AZURE_BITS_PER_SAMPLE: int = 16
 AZURE_NUM_CHANNELS: int = 1
@@ -63,6 +63,7 @@ def __init__(
         interim_results: bool = True,
         punctuate: bool = True,
         smart_format: bool = True,
+        sample_rate: int = 16000,
         no_delay: bool = True,
         endpointing_ms: int = 25,
         filler_words: bool = False,
@@ -116,7 +117,7 @@ def __init__(
             no_delay=no_delay,
             endpointing_ms=endpointing_ms,
             filler_words=filler_words,
-            sample_rate=48000,
+            sample_rate=sample_rate,
             num_channels=1,
             keywords=keywords,
             profanity_filter=profanity_filter,
@@ -195,7 +196,7 @@ def __init__(
         http_session: aiohttp.ClientSession,
         max_retry: int = 32,
     ) -> None:
-        super().__init__()
+        super().__init__(sample_rate=opts.sample_rate)

         if opts.detect_language and opts.language is None:
             raise ValueError("language detection is not supported in streaming mode")
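Downstream, the Deepgram plugin now defaults to 16 kHz instead of the previously hard-coded 48 kHz. A hedged usage sketch (assumes `DEEPGRAM_API_KEY` is set in the environment; only the `sample_rate` keyword comes from this diff):

```python
from livekit.plugins import deepgram

# New default: 16 kHz, matching the resampled AudioStream / SpeechStream input.
stt_impl = deepgram.STT()

# Callers who want the previous 48 kHz behavior can opt back in explicitly.
stt_48k = deepgram.STT(sample_rate=48000)
```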