silero: add update_options #899

Merged · 13 commits · Oct 14, 2024
5 changes: 5 additions & 0 deletions .changeset/shy-ghosts-greet.md
@@ -0,0 +1,5 @@
---
"livekit-plugins-silero": patch
---

silero: add update_options
5 changes: 5 additions & 0 deletions .changeset/tricky-parrots-notice.md
@@ -0,0 +1,5 @@
---
"livekit-plugins-silero": patch
---

silero: fix speech_buffer for END_OF_SPEECH
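
For context, here is a minimal sketch of how the new runtime API is meant to be used, based on the signatures in the diff below; the parameter values are illustrative, not recommendations:

from livekit.plugins import silero

# Load the VAD with the defaults updated in this PR (0.55s silence, 0.5s padding).
vad = silero.VAD.load(
    min_silence_duration=0.55,
    prefix_padding_duration=0.5,
)
stream = vad.stream()

# Tune detection at runtime; omitted arguments keep their current values, and the
# new options are pushed to every stream previously created by this VAD instance.
vad.update_options(
    min_silence_duration=0.3,
    activation_threshold=0.6,
)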
178 changes: 136 additions & 42 deletions livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py
@@ -15,7 +15,6 @@
from __future__ import annotations, print_function

import asyncio
import math
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
@@ -54,8 +53,8 @@ def load(
cls,
*,
min_speech_duration: float = 0.05,
min_silence_duration: float = 0.25,
prefix_padding_duration: float = 0.1,
min_silence_duration: float = 0.55,
prefix_padding_duration: float = 0.5,
max_buffered_speech: float = 60.0,
activation_threshold: float = 0.5,
sample_rate: Literal[8000, 16000] = 16000,
@@ -135,19 +134,64 @@ def __init__(
self._onnx_session = session
self._opts = opts

self._streams: list[VADStream] = []

def stream(self) -> "VADStream":
"""
Create a new VADStream for processing audio data.

Returns:
VADStream: A stream object for processing audio input and detecting speech.
"""
return VADStream(
stream = VADStream(
self._opts,
onnx_model.OnnxModel(
onnx_session=self._onnx_session, sample_rate=self._opts.sample_rate
),
)
self._streams.append(stream)
return stream

def update_options(
self,
*,
min_speech_duration: float | None = None,
min_silence_duration: float | None = None,
prefix_padding_duration: float | None = None,
max_buffered_speech: float | None = None,
activation_threshold: float | None = None,
) -> None:
"""
Update the VAD options.

This method allows you to update the VAD options after the VAD object has been created.

Args:
min_speech_duration (float): Minimum duration of speech to start a new speech chunk.
min_silence_duration (float): At the end of each speech, wait this duration before ending the speech.
prefix_padding_duration (float): Duration of padding to add to the beginning of each speech chunk.
max_buffered_speech (float): Maximum duration of speech to keep in the buffer (in seconds).
activation_threshold (float): Threshold to consider a frame as speech.
"""
self._opts = _VADOptions(
min_speech_duration=min_speech_duration or self._opts.min_speech_duration,
min_silence_duration=min_silence_duration
or self._opts.min_silence_duration,
prefix_padding_duration=prefix_padding_duration
or self._opts.prefix_padding_duration,
max_buffered_speech=max_buffered_speech or self._opts.max_buffered_speech,
activation_threshold=activation_threshold
or self._opts.activation_threshold,
sample_rate=self._opts.sample_rate,
)
for stream in self._streams:
stream.update_options(
min_speech_duration=min_speech_duration,
min_silence_duration=min_silence_duration,
prefix_padding_duration=prefix_padding_duration,
max_buffered_speech=max_buffered_speech,
activation_threshold=activation_threshold,
)


class VADStream(agents.vad.VADStream):
@@ -160,15 +204,64 @@ def __init__(self, opts: _VADOptions, model: onnx_model.OnnxModel) -> None:
self._task.add_done_callback(lambda _: self._executor.shutdown(wait=False))
self._exp_filter = utils.ExpFilter(alpha=0.35)

self._extra_inference_time = 0.0
self._input_sample_rate = 0
self._speech_buffer: np.ndarray | None = None
self._speech_buffer_max_reached = False
self._prefix_padding_samples = 0 # (input_sample_rate)

def update_options(
self,
*,
min_speech_duration: float | None = None,
min_silence_duration: float | None = None,
prefix_padding_duration: float | None = None,
max_buffered_speech: float | None = None,
activation_threshold: float | None = None,
) -> None:
"""
Update the VAD options.

This method allows you to update the VAD options after the VAD object has been created.

Args:
min_speech_duration (float): Minimum duration of speech to start a new speech chunk.
min_silence_duration (float): At the end of each speech, wait this duration before ending the speech.
prefix_padding_duration (float): Duration of padding to add to the beginning of each speech chunk.
max_buffered_speech (float): Maximum duration of speech to keep in the buffer (in seconds).
activation_threshold (float): Threshold to consider a frame as speech.
"""
old_max_buffered_speech = self._opts.max_buffered_speech

self._opts = _VADOptions(
min_speech_duration=min_speech_duration or self._opts.min_speech_duration,
min_silence_duration=min_silence_duration
or self._opts.min_silence_duration,
prefix_padding_duration=prefix_padding_duration
or self._opts.prefix_padding_duration,
max_buffered_speech=max_buffered_speech or self._opts.max_buffered_speech,
activation_threshold=activation_threshold
or self._opts.activation_threshold,
sample_rate=self._opts.sample_rate,
)

if self._input_sample_rate:
assert self._speech_buffer is not None

self._prefix_padding_samples = int(
self._opts.prefix_padding_duration * self._input_sample_rate
)

self._speech_buffer.resize(
int(self._opts.max_buffered_speech * self._input_sample_rate)
+ self._prefix_padding_samples
)

if self._opts.max_buffered_speech > old_max_buffered_speech:
self._speech_buffer_max_reached = False

@agents.utils.log_exceptions(logger=logger)
async def _main_task(self):
inference_f32_data = np.empty(self._model.window_size_samples, dtype=np.float32)

# a copy is exposed to the user in END_OF_SPEECH
speech_buffer: np.ndarray | None = None
speech_buffer_max_reached = False
speech_buffer_index: int = 0

# "pub_" means public, these values are exposed to the users through events
@@ -178,9 +271,6 @@ async def _main_task(self):
pub_current_sample = 0
pub_timestamp = 0.0

pub_sample_rate = 0
pub_prefix_padding_samples = 0 # size in samples of padding data

speech_threshold_duration = 0.0
silence_threshold_duration = 0.0

@@ -191,37 +281,41 @@ async def _main_task(self):
# used to avoid drift when the sample_rate ratio is not an integer
input_copy_remaining_fract = 0.0

extra_inference_time = 0.0

async for input_frame in self._input_ch:
if not isinstance(input_frame, rtc.AudioFrame):
continue # ignore flush sentinel for now

if not pub_sample_rate or speech_buffer is None:
pub_sample_rate = input_frame.sample_rate
if not self._input_sample_rate:
self._input_sample_rate = input_frame.sample_rate

# alloc the buffers now that we know the input sample rate
pub_prefix_padding_samples = math.ceil(
self._opts.prefix_padding_duration * pub_sample_rate
self._prefix_padding_samples = int(
self._opts.prefix_padding_duration * self._input_sample_rate
)

speech_buffer = np.empty(
int(self._opts.max_buffered_speech * pub_sample_rate)
+ int(self._opts.prefix_padding_duration * pub_sample_rate),
self._speech_buffer = np.empty(
int(self._opts.max_buffered_speech * self._input_sample_rate)
+ self._prefix_padding_samples,
dtype=np.int16,
)

if pub_sample_rate != self._opts.sample_rate:
if self._input_sample_rate != self._opts.sample_rate:
# resampling needed: the input sample rate isn't the same as the model's
# sample rate used for inference
resampler = rtc.AudioResampler(
input_rate=pub_sample_rate,
input_rate=self._input_sample_rate,
output_rate=self._opts.sample_rate,
quality=rtc.AudioResamplerQuality.QUICK, # VAD doesn't need high quality
)

elif pub_sample_rate != input_frame.sample_rate:
elif self._input_sample_rate != input_frame.sample_rate:
logger.error("a frame with another sample rate was already pushed")
continue

assert self._speech_buffer is not None

input_frames.append(input_frame)
if resampler is not None:
# the resampler may have a bit of latency, but it is OK to ignore since it should be
@@ -263,7 +357,7 @@ async def _main_task(self):
pub_current_sample += self._model.window_size_samples
pub_timestamp += window_duration

resampling_ratio = pub_sample_rate / self._model.sample_rate
resampling_ratio = self._input_sample_rate / self._model.sample_rate
to_copy = (
self._model.window_size_samples * resampling_ratio
+ input_copy_remaining_fract
@@ -272,54 +366,54 @@ async def _main_task(self):
input_copy_remaining_fract = to_copy - to_copy_int

# copy the inference window to the speech buffer
available_space = len(speech_buffer) - speech_buffer_index
to_copy_buffer = min(self._model.window_size_samples, available_space)
available_space = len(self._speech_buffer) - speech_buffer_index
to_copy_buffer = min(to_copy_int, available_space)
if to_copy_buffer > 0:
speech_buffer[
self._speech_buffer[
speech_buffer_index : speech_buffer_index + to_copy_buffer
] = input_frame.data[:to_copy_buffer]
speech_buffer_index += to_copy_buffer
elif not speech_buffer_max_reached:
elif not self._speech_buffer_max_reached:
# reached self._opts.max_buffered_speech (padding is included)
speech_buffer_max_reached = True
logger.warning(
"max_buffered_speech reached, ignoring further data for the current speech input"
)

inference_duration = time.perf_counter() - start_time
self._extra_inference_time = max(
extra_inference_time = max(
0.0,
self._extra_inference_time + inference_duration - window_duration,
extra_inference_time + inference_duration - window_duration,
)
if inference_duration > SLOW_INFERENCE_THRESHOLD:
logger.warning(
"inference is slower than realtime",
extra={"delay": self._extra_inference_time},
extra={"delay": extra_inference_time},
)

def _reset_write_cursor():
nonlocal speech_buffer_index, speech_buffer_max_reached
assert speech_buffer is not None
assert self._speech_buffer is not None

if speech_buffer_index <= pub_prefix_padding_samples:
if speech_buffer_index <= self._prefix_padding_samples:
return

padding_data = speech_buffer[
padding_data = self._speech_buffer[
speech_buffer_index
- pub_prefix_padding_samples : speech_buffer_index
- self._prefix_padding_samples : speech_buffer_index
]

speech_buffer[:pub_prefix_padding_samples] = padding_data
speech_buffer_index = pub_prefix_padding_samples
speech_buffer_max_reached = False
self._speech_buffer_max_reached = False
self._speech_buffer[: self._prefix_padding_samples] = padding_data
speech_buffer_index = self._prefix_padding_samples

def _copy_speech_buffer() -> rtc.AudioFrame:
# copy the data from speech_buffer
assert speech_buffer is not None
speech_data = speech_buffer[:speech_buffer_index].tobytes()
assert self._speech_buffer is not None
speech_data = self._speech_buffer[:speech_buffer_index].tobytes()

return rtc.AudioFrame(
sample_rate=pub_sample_rate,
sample_rate=self._input_sample_rate,
num_channels=1,
samples_per_channel=speech_buffer_index,
data=speech_data,
@@ -342,7 +436,7 @@ def _copy_speech_buffer() -> rtc.AudioFrame:
frames=[
rtc.AudioFrame(
data=input_frame.data[:to_copy_int].tobytes(),
sample_rate=pub_sample_rate,
sample_rate=self._input_sample_rate,
num_channels=1,
samples_per_channel=to_copy_int,
)
@@ -413,7 +507,7 @@ def _copy_speech_buffer() -> rtc.AudioFrame:
input_frames.append(
rtc.AudioFrame(
data=data,
sample_rate=pub_sample_rate,
sample_rate=self._input_sample_rate,
num_channels=1,
samples_per_channel=len(data) // 2,
)
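The stream-side update is mostly buffer bookkeeping. Below is a self-contained sketch (outside the plugin) of what the resize in the diff above amounts to; the sample values are illustrative, not the defaults:

import numpy as np

input_sample_rate = 16000          # known once the first audio frame arrives
prefix_padding_duration = 0.5      # seconds of audio kept before the speech start
max_buffered_speech = 60.0         # seconds of speech retained for END_OF_SPEECH

prefix_padding_samples = int(prefix_padding_duration * input_sample_rate)
speech_buffer = np.empty(
    int(max_buffered_speech * input_sample_rate) + prefix_padding_samples,
    dtype=np.int16,
)

# After update_options(max_buffered_speech=90.0) the stream recomputes the padding,
# resizes the existing buffer in place, and clears the "buffer full" flag so that
# longer speech segments can be captured again.
max_buffered_speech = 90.0
prefix_padding_samples = int(prefix_padding_duration * input_sample_rate)
speech_buffer.resize(
    int(max_buffered_speech * input_sample_rate) + prefix_padding_samples
)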
3 changes: 2 additions & 1 deletion tests/test_vad.py
@@ -4,7 +4,8 @@
from . import utils

VAD = silero.VAD.load(
min_speech_duration=0.5, min_silence_duration=0.5, padding_duration=1.0
min_speech_duration=0.5,
min_silence_duration=0.6,
)
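
A hypothetical follow-up test (not part of this PR) could exercise the new call; it assumes the silero import already used by tests/test_vad.py and peeks at the private _opts only for illustration:

def test_update_options():
    vad = silero.VAD.load(min_speech_duration=0.5, min_silence_duration=0.6)

    # Only the options passed here change; omitted ones keep their current values.
    vad.update_options(min_silence_duration=0.3, activation_threshold=0.6)

    assert vad._opts.min_silence_duration == 0.3
    assert vad._opts.min_speech_duration == 0.5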

