Skip to content

Commit

Permalink
add silero.VAD.load (#495)
Browse files (browse the repository at this point in the history)
  • Loading branch information
theomonnom authored Jul 24, 2024
1 parent f953d47 commit ef76aa8
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
4 changes: 2 additions & 2 deletions livekit-agents/livekit/agents/vad.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class VADCapabilities:


class VAD(ABC):
def __init__(self, *, capatiilities: VADCapabilities) -> None:
self._capabilities = capatiilities
def __init__(self, *, capabilities: VADCapabilities) -> None:
self._capabilities = capabilities

@property
def capabilities(self) -> VADCapabilities:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dataclasses import dataclass

import numpy as np
import onnxruntime # type: ignore
from livekit import agents, rtc
from livekit.agents import utils

Expand All @@ -37,8 +38,9 @@ class _VADOptions:


class VAD(agents.vad.VAD):
def __init__(
self,
@classmethod
def load(
cls,
*,
min_speech_duration: float = 0.05,
min_silence_duration: float = 0.1,
Expand All @@ -47,7 +49,7 @@ def __init__(
activation_threshold: float = 0.25,
sample_rate: int = 16000,
force_cpu: bool = True,
) -> None:
) -> "VAD":
"""
Initialize the Silero VAD with the given options.
The options are already set to strong defaults.
Expand All @@ -61,19 +63,29 @@ def __init__(
sample_rate: sample rate for the inference (only 8KHz and 16KHz are supported)
force_cpu: force to use CPU for inference
"""

if sample_rate not in onnx_model.SUPPORTED_SAMPLE_RATES:
raise ValueError("Silero VAD only supports 8KHz and 16KHz sample rates")

self._onnx_session = onnx_model.new_inference_session(force_cpu)
self._opts = _VADOptions(
session = onnx_model.new_inference_session(force_cpu)
opts = _VADOptions(
min_speech_duration=min_speech_duration,
min_silence_duration=min_silence_duration,
padding_duration=padding_duration,
max_buffered_speech=max_buffered_speech,
activation_threshold=activation_threshold,
sample_rate=sample_rate,
)
return cls(session=session, opts=opts)

def __init__(
self,
*,
session: onnxruntime.InferenceSession,
opts: _VADOptions,
) -> None:
super().__init__(capabilities=agents.vad.VADCapabilities(update_interval=0.032))
self._onnx_session = session
self._opts = opts

def stream(self) -> "VADStream":
return VADStream(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def test_recognize(stt: agents.stt.STT):
assert event.type == agents.stt.SpeechEventType.FINAL_TRANSCRIPT


STREAM_VAD = silero.VAD()
STREAM_VAD = silero.VAD.load()
STREAM_STT = [
deepgram.STT(),
google.STT(),
Expand Down

0 comments on commit ef76aa8

Please sign in to comment.