Skip to content

Commit

Permalink
silero: add update_options (#899)
Browse files Browse the repository at this point in the history
Co-authored-by: Long Chen <longch1024@gmail.com>
  • Loading branch information
theomonnom and longcw authored Oct 14, 2024
1 parent 3eeb5dc commit cc2fe77
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 38 deletions.
5 changes: 5 additions & 0 deletions .changeset/shy-ghosts-greet.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"livekit-plugins-silero": patch
---

silero: add update_options
171 changes: 133 additions & 38 deletions livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,19 +134,64 @@ def __init__(
self._onnx_session = session
self._opts = opts

self._streams: list[VADStream] = []

def stream(self) -> "VADStream":
"""
Create a new VADStream for processing audio data.

The created stream is also appended to ``self._streams`` so that
``update_options`` can forward later option changes to it.

Returns:
VADStream: A stream object for processing audio input and detecting speech.
"""
# NOTE(review): diff residue -- the next line (`return VADStream(`) is the
# pre-change statement and the line after it (`stream = VADStream(`) is its
# replacement; only one of them belongs in the final file.
return VADStream(
stream = VADStream(
self._opts,
onnx_model.OnnxModel(
onnx_session=self._onnx_session, sample_rate=self._opts.sample_rate
),
)
# Keep a reference so update_options() can reach every stream created here.
# NOTE(review): nothing visible here removes closed streams from this list -- confirm
# whether it can grow without bound.
self._streams.append(stream)
return stream

def update_options(
    self,
    *,
    min_speech_duration: float | None = None,
    min_silence_duration: float | None = None,
    prefix_padding_duration: float | None = None,
    max_buffered_speech: float | None = None,
    activation_threshold: float | None = None,
) -> None:
    """
    Update the VAD options.

    This method allows you to update the VAD options after the VAD object has been created.
    Every argument is optional: passing ``None`` (the default) keeps the current value,
    any other value (including ``0.0``) replaces it. The new values are stored on this
    object and forwarded to every stream tracked in ``self._streams``.

    Args:
        min_speech_duration (float): Minimum duration of speech to start a new speech chunk.
        min_silence_duration (float): At the end of each speech, wait this duration before ending the speech.
        prefix_padding_duration (float): Duration of padding to add to the beginning of each speech chunk.
        max_buffered_speech (float): Maximum duration of speech to keep in the buffer (in seconds).
        activation_threshold (float): Threshold to consider a frame as speech.
    """

    def _pick(new: float | None, current: float) -> float:
        # Only ``None`` means "leave unchanged". Using ``new or current`` would
        # silently discard an explicit 0.0 because 0.0 is falsy.
        return current if new is None else new

    current = self._opts
    self._opts = _VADOptions(
        min_speech_duration=_pick(min_speech_duration, current.min_speech_duration),
        min_silence_duration=_pick(
            min_silence_duration, current.min_silence_duration
        ),
        prefix_padding_duration=_pick(
            prefix_padding_duration, current.prefix_padding_duration
        ),
        max_buffered_speech=_pick(max_buffered_speech, current.max_buffered_speech),
        activation_threshold=_pick(
            activation_threshold, current.activation_threshold
        ),
        # the model sample rate is not updatable; carry it over as-is
        sample_rate=current.sample_rate,
    )

    # Forward the raw arguments (not the resolved values) so that each stream
    # keeps its own current value for every option left as None.
    for stream in self._streams:
        stream.update_options(
            min_speech_duration=min_speech_duration,
            min_silence_duration=min_silence_duration,
            prefix_padding_duration=prefix_padding_duration,
            max_buffered_speech=max_buffered_speech,
            activation_threshold=activation_threshold,
        )


class VADStream(agents.vad.VADStream):
Expand All @@ -159,15 +204,64 @@ def __init__(self, opts: _VADOptions, model: onnx_model.OnnxModel) -> None:
self._task.add_done_callback(lambda _: self._executor.shutdown(wait=False))
self._exp_filter = utils.ExpFilter(alpha=0.35)

self._extra_inference_time = 0.0
self._input_sample_rate = 0
self._speech_buffer: np.ndarray | None = None
self._speech_buffer_max_reached = False
self._prefix_padding_samples = 0 # (input_sample_rate)

def update_options(
    self,
    *,
    min_speech_duration: float | None = None,
    min_silence_duration: float | None = None,
    prefix_padding_duration: float | None = None,
    max_buffered_speech: float | None = None,
    activation_threshold: float | None = None,
) -> None:
    """
    Update the VAD options.

    This method allows you to update the VAD options after the VAD object has been created.
    Every argument is optional: passing ``None`` (the default) keeps the current value,
    any other value (including ``0.0``) replaces it. If the input sample rate is already
    known, the prefix padding size is recomputed and the speech buffer is resized to
    match the new options.

    Args:
        min_speech_duration (float): Minimum duration of speech to start a new speech chunk.
        min_silence_duration (float): At the end of each speech, wait this duration before ending the speech.
        prefix_padding_duration (float): Duration of padding to add to the beginning of each speech chunk.
        max_buffered_speech (float): Maximum duration of speech to keep in the buffer (in seconds).
        activation_threshold (float): Threshold to consider a frame as speech.
    """

    def _pick(new: float | None, current: float) -> float:
        # Only ``None`` means "leave unchanged". Using ``new or current`` would
        # silently discard an explicit 0.0 because 0.0 is falsy.
        return current if new is None else new

    current = self._opts
    old_max_buffered_speech = current.max_buffered_speech

    self._opts = _VADOptions(
        min_speech_duration=_pick(min_speech_duration, current.min_speech_duration),
        min_silence_duration=_pick(
            min_silence_duration, current.min_silence_duration
        ),
        prefix_padding_duration=_pick(
            prefix_padding_duration, current.prefix_padding_duration
        ),
        max_buffered_speech=_pick(max_buffered_speech, current.max_buffered_speech),
        activation_threshold=_pick(
            activation_threshold, current.activation_threshold
        ),
        # the model sample rate is not updatable; carry it over as-is
        sample_rate=current.sample_rate,
    )

    # A zero input sample rate means no frame has been seen yet, so the buffer
    # has not been allocated and there is nothing to resize.
    if self._input_sample_rate:
        assert self._speech_buffer is not None

        self._prefix_padding_samples = int(
            self._opts.prefix_padding_duration * self._input_sample_rate
        )

        # In-place resize to max_buffered_speech plus the prefix padding, both
        # expressed in samples at the input sample rate.
        # NOTE(review): ndarray.resize raises ValueError when other references
        # or views of the array exist -- confirm none are held at this point.
        self._speech_buffer.resize(
            int(self._opts.max_buffered_speech * self._input_sample_rate)
            + self._prefix_padding_samples
        )

        # The buffer grew, so clear the "max reached" flag.
        if self._opts.max_buffered_speech > old_max_buffered_speech:
            self._speech_buffer_max_reached = False

@agents.utils.log_exceptions(logger=logger)
async def _main_task(self):
inference_f32_data = np.empty(self._model.window_size_samples, dtype=np.float32)

# a copy is exposed to the user in END_OF_SPEECH
speech_buffer: np.ndarray | None = None
speech_buffer_max_reached = False
speech_buffer_index: int = 0

# "pub_" means public, these values are exposed to the users through events
Expand All @@ -177,9 +271,6 @@ async def _main_task(self):
pub_current_sample = 0
pub_timestamp = 0.0

pub_sample_rate = 0
pub_prefix_padding_samples = 0 # size in samples of padding data

speech_threshold_duration = 0.0
silence_threshold_duration = 0.0

Expand All @@ -190,37 +281,41 @@ async def _main_task(self):
# used to avoid drift when the sample_rate ratio is not an integer
input_copy_remaining_fract = 0.0

extra_inference_time = 0.0

async for input_frame in self._input_ch:
if not isinstance(input_frame, rtc.AudioFrame):
continue # ignore flush sentinel for now

if not pub_sample_rate or speech_buffer is None:
pub_sample_rate = input_frame.sample_rate
if not self._input_sample_rate:
self._input_sample_rate = input_frame.sample_rate

# alloc the buffers now that we know the input sample rate
pub_prefix_padding_samples = int(
self._opts.prefix_padding_duration * pub_sample_rate
self._prefix_padding_samples = int(
self._opts.prefix_padding_duration * self._input_sample_rate
)

speech_buffer = np.empty(
int(self._opts.max_buffered_speech * pub_sample_rate)
+ pub_prefix_padding_samples,
self._speech_buffer = np.empty(
int(self._opts.max_buffered_speech * self._input_sample_rate)
+ self._prefix_padding_samples,
dtype=np.int16,
)

if pub_sample_rate != self._opts.sample_rate:
if self._input_sample_rate != self._opts.sample_rate:
# resampling needed: the input sample rate isn't the same as the model's
# sample rate used for inference
resampler = rtc.AudioResampler(
input_rate=pub_sample_rate,
input_rate=self._input_sample_rate,
output_rate=self._opts.sample_rate,
quality=rtc.AudioResamplerQuality.QUICK, # VAD doesn't need high quality
)

elif pub_sample_rate != input_frame.sample_rate:
elif self._input_sample_rate != input_frame.sample_rate:
logger.error("a frame with another sample rate was already pushed")
continue

assert self._speech_buffer is not None

input_frames.append(input_frame)
if resampler is not None:
# the resampler may have a bit of latency, but it is OK to ignore since it should be
Expand Down Expand Up @@ -262,7 +357,7 @@ async def _main_task(self):
pub_current_sample += self._model.window_size_samples
pub_timestamp += window_duration

resampling_ratio = pub_sample_rate / self._model.sample_rate
resampling_ratio = self._input_sample_rate / self._model.sample_rate
to_copy = (
self._model.window_size_samples * resampling_ratio
+ input_copy_remaining_fract
Expand All @@ -271,54 +366,54 @@ async def _main_task(self):
input_copy_remaining_fract = to_copy - to_copy_int

# copy the inference window to the speech buffer
available_space = len(speech_buffer) - speech_buffer_index
available_space = len(self._speech_buffer) - speech_buffer_index
to_copy_buffer = min(to_copy_int, available_space)
if to_copy_buffer > 0:
speech_buffer[
self._speech_buffer[
speech_buffer_index : speech_buffer_index + to_copy_buffer
] = input_frame.data[:to_copy_buffer]
speech_buffer_index += to_copy_buffer
elif not speech_buffer_max_reached:
elif not self._speech_buffer_max_reached:
# reached self._opts.max_buffered_speech (padding is included)
speech_buffer_max_reached = True
logger.warning(
"max_buffered_speech reached, ignoring further data for the current speech input"
)

inference_duration = time.perf_counter() - start_time
self._extra_inference_time = max(
extra_inference_time = max(
0.0,
self._extra_inference_time + inference_duration - window_duration,
extra_inference_time + inference_duration - window_duration,
)
if inference_duration > SLOW_INFERENCE_THRESHOLD:
logger.warning(
"inference is slower than realtime",
extra={"delay": self._extra_inference_time},
extra={"delay": extra_inference_time},
)

def _reset_write_cursor():
nonlocal speech_buffer_index, speech_buffer_max_reached
assert speech_buffer is not None
assert self._speech_buffer is not None

if speech_buffer_index <= pub_prefix_padding_samples:
if speech_buffer_index <= self._prefix_padding_samples:
return

padding_data = speech_buffer[
padding_data = self._speech_buffer[
speech_buffer_index
- pub_prefix_padding_samples : speech_buffer_index
- self._prefix_padding_samples : speech_buffer_index
]

speech_buffer[:pub_prefix_padding_samples] = padding_data
speech_buffer_index = pub_prefix_padding_samples
speech_buffer_max_reached = False
self._speech_buffer_max_reached = False
self._speech_buffer[: self._prefix_padding_samples] = padding_data
speech_buffer_index = self._prefix_padding_samples

def _copy_speech_buffer() -> rtc.AudioFrame:
# copy the data from speech_buffer
assert speech_buffer is not None
speech_data = speech_buffer[:speech_buffer_index].tobytes()
assert self._speech_buffer is not None
speech_data = self._speech_buffer[:speech_buffer_index].tobytes()

return rtc.AudioFrame(
sample_rate=pub_sample_rate,
sample_rate=self._input_sample_rate,
num_channels=1,
samples_per_channel=speech_buffer_index,
data=speech_data,
Expand All @@ -341,7 +436,7 @@ def _copy_speech_buffer() -> rtc.AudioFrame:
frames=[
rtc.AudioFrame(
data=input_frame.data[:to_copy_int].tobytes(),
sample_rate=pub_sample_rate,
sample_rate=self._input_sample_rate,
num_channels=1,
samples_per_channel=to_copy_int,
)
Expand Down Expand Up @@ -412,7 +507,7 @@ def _copy_speech_buffer() -> rtc.AudioFrame:
input_frames.append(
rtc.AudioFrame(
data=data,
sample_rate=pub_sample_rate,
sample_rate=self._input_sample_rate,
num_channels=1,
samples_per_channel=len(data) // 2,
)
Expand Down

0 comments on commit cc2fe77

Please sign in to comment.