[Not for merge] Diarization workflow with SpeechBrain #1031

Open · wants to merge 6 commits into master
51 changes: 49 additions & 2 deletions lhotse/bin/modes/workflows.py
@@ -172,6 +172,55 @@ def align_with_torchaudio(
writer.write(cut, flush=True)


@workflows.command()
@click.argument(
"in_cuts", type=click.Path(exists=True, dir_okay=False, allow_dash=True)
)
@click.argument("out_cuts", type=click.Path(allow_dash=True))
@click.option(
"-d", "--device", default="cpu", help="Device on which to run the inference."
)
@click.option(
"--num-speakers",
type=int,
default=None,
help="Number of clusters to use for speaker diarization. Will use threshold if not provided.",
)
@click.option(
"--threshold",
type=float,
default=None,
help="Threshold for speaker diarization. Will use num-speakers if not provided.",
)
def diarize_segments_with_speechbrain(
in_cuts: str,
out_cuts: str,
device: str = "cpu",
num_speakers: Optional[int] = None,
threshold: Optional[float] = None,
):
"""
    This workflow uses SpeechBrain's pretrained speaker embedding model to compute a speaker
    embedding for each cut in the CutSet. Cuts belonging to the same recording are then clustered
    using agglomerative hierarchical clustering, and the resulting cluster indices are used to
    create new cuts with speaker labels.

Please refer to https://huggingface.co/speechbrain/spkrec-xvect-voxceleb for more details
about the speaker embedding extractor.
"""
from lhotse.workflows import diarize_segments_with_speechbrain

assert exactly_one_not_null(
num_speakers, threshold
), "Exactly one of --num-speakers and --threshold must be provided."

cuts = load_manifest_lazy_or_eager(in_cuts)
cuts_with_spk_id = diarize_segments_with_speechbrain(
cuts, device=device, num_speakers=num_speakers, threshold=threshold
)
cuts_with_spk_id.to_file(out_cuts)
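
A usage note: assuming click's default command naming (underscores become dashes) and the existing ``lhotse workflows`` CLI group, this would be invoked roughly as ``lhotse workflows diarize-segments-with-speechbrain in_cuts.jsonl.gz out_cuts.jsonl.gz --num-speakers 4`` (the manifest paths here are illustrative).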


@workflows.command()
@click.argument(
"in_cuts", type=click.Path(exists=True, dir_okay=False, allow_dash=True)
@@ -203,8 +252,6 @@ def align_with_torchaudio(
show_default=True,
)
# Options used with the "conversational" method


@click.option(
"--same-spk-pause",
type=float,
1 change: 1 addition & 0 deletions lhotse/workflows/__init__.py
@@ -1,3 +1,4 @@
from .diarization import diarize_segments_with_speechbrain
from .forced_alignment import align_with_torchaudio
from .meeting_simulation import *
from .whisper import annotate_with_whisper
117 changes: 117 additions & 0 deletions lhotse/workflows/diarization.py
@@ -0,0 +1,117 @@
import logging
import shutil
import tempfile
from typing import Optional

import numpy as np
import torch
from tqdm import tqdm

from lhotse import CutSet, Recording
from lhotse.utils import fastcopy, is_module_available

logging.basicConfig(
format="%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)


def diarize_segments_with_speechbrain(
cuts: CutSet,
device: str = "cpu",
    num_speakers: Optional[int] = None,
    threshold: Optional[float] = 0.5,
) -> CutSet:
"""
    This workflow uses SpeechBrain's pretrained speaker embedding model to compute a speaker
    embedding for each cut in the CutSet. Cuts belonging to the same recording are then clustered
    using agglomerative hierarchical clustering, and the resulting cluster indices are used to
    create new cuts with speaker labels.

Please refer to https://huggingface.co/speechbrain/spkrec-xvect-voxceleb for more details
about the speaker embedding extractor.

    :param cuts: a ``CutSet`` object.
:param device: Where to run the inference (cpu, cuda, etc.).
    :param num_speakers: Number of speakers to cluster the cuts into. If not specified,
        the ``threshold`` parameter is used to determine the number of clusters.
    :param threshold: The distance threshold for agglomerative clustering, used only
        when ``num_speakers`` is not provided.
:return: a new ``CutSet`` with speaker labels.
"""
assert is_module_available("speechbrain"), (
"This function expects SpeechBrain to be installed. "
"You can install it via 'pip install speechbrain' "
)

assert is_module_available("sklearn"), (
"This function expects scikit-learn to be installed. "
"You can install it via 'pip install scikit-learn' "
)

    from sklearn.cluster import AgglomerativeClustering
    # NOTE: newer SpeechBrain releases (>= 1.0) expose this interface as
    # ``speechbrain.inference.EncoderClassifier``.
    from speechbrain.pretrained import EncoderClassifier

    # An explicit number of speakers takes precedence over the distance threshold.
    threshold = None if num_speakers is not None else threshold
dirpath = tempfile.mkdtemp()

recordings, _, _ = cuts.decompose(dirpath, verbose=True)
recordings = recordings.to_eager()
recording_ids = frozenset(recordings.ids)

logging.info("Saving cut recordings temporarily to disk...")
cuts_ = []
for cut in tqdm(cuts):
save_path = f"{dirpath}/{cut.recording_id}.wav"
_ = cut.save_audio(save_path)
cuts_.append(fastcopy(cut, recording=Recording.from_file(save_path)))

    # Split into one cut per supervision segment; overlapping regions are dropped
    # so that each segment can be assigned a single speaker.
    cuts_ = CutSet.from_cuts(cuts_).trim_to_supervisions(keep_overlapping=False)

    # Load the pretrained x-vector speaker embedding model (fetched from HuggingFace on first use)
model = EncoderClassifier.from_hparams(
source="speechbrain/spkrec-xvect-voxceleb",
savedir="pretrained_models/spkrec-xvect-voxceleb",
run_opts={"device": device},
)

out_cuts = []

for recording_id in tqdm(recording_ids, total=len(recording_ids)):
logging.info(f"Processing recording {recording_id}...")
embeddings = []
reco_cuts = cuts_.filter(lambda c: c.recording_id == recording_id)
num_cuts = len(frozenset(reco_cuts.ids))
if num_cuts == 0:
continue
        # Compute one x-vector embedding per segment.
        for cut in tqdm(reco_cuts, total=num_cuts):
audio = torch.from_numpy(cut.load_audio())
embedding = model.encode_batch(audio).cpu().numpy()
embeddings.append(embedding.squeeze())

        embeddings = np.vstack(embeddings)
        # Cluster with either a fixed number of clusters (num_speakers) or a
        # distance threshold; after the precedence check above, at most one of
        # the two is non-None. NOTE: scikit-learn >= 1.2 renames ``affinity``
        # to ``metric``.
        clusterer = AgglomerativeClustering(
            n_clusters=num_speakers,
            affinity="euclidean",
            linkage="ward",
            distance_threshold=threshold,
        )
clusterer.fit(embeddings)

# Assign the cluster indices to the cuts
for cut, cluster_idx in zip(reco_cuts, clusterer.labels_):
sup = fastcopy(cut.supervisions[0], speaker=f"spk{cluster_idx}")
out_cuts.append(
fastcopy(
cut,
recording=recordings[cut.recording_id],
supervisions=[sup],
)
)

# Remove the temporary directory
shutil.rmtree(dirpath)

return CutSet.from_cuts(out_cuts)
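
A minimal end-to-end sketch of the new Python API, assuming ``speechbrain`` and ``scikit-learn`` are installed (the manifest paths are hypothetical):

```python
from lhotse import CutSet
from lhotse.workflows import diarize_segments_with_speechbrain

# Hypothetical input: cuts whose supervisions mark speech segments but lack
# reliable speaker labels.
cuts = CutSet.from_file("data/cuts.jsonl.gz")

# Either fix the number of speakers per recording...
cuts_fixed = diarize_segments_with_speechbrain(cuts, device="cuda", num_speakers=2)

# ...or let the clustering distance threshold decide the number of speakers.
cuts_auto = diarize_segments_with_speechbrain(cuts, device="cuda", threshold=0.5)

cuts_fixed.to_file("data/cuts_diarized.jsonl.gz")
```

Fixing ``num_speakers`` is the safer choice when the speaker count is known in advance (e.g. two-party telephone calls); the threshold variant is useful when it varies across recordings.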
5 changes: 3 additions & 2 deletions lhotse/workflows/whisper.py
@@ -2,6 +2,7 @@
from typing import Any, Generator, List, Optional, Union

import torch
from tqdm import tqdm

from lhotse import (
CutSet,
@@ -87,7 +88,7 @@ def _annotate_recordings(

model = whisper.load_model(model_name, device=device, download_root=download_root)

for recording in recordings:
for recording in tqdm(recordings):
if recording.num_channels > 1:
logging.warning(
f"Skipping recording '{recording.id}'. It has {recording.num_channels} channels, "
@@ -141,7 +142,7 @@ def _annotate_cuts(

model = whisper.load_model(model_name, device=device, download_root=download_root)

for cut in cuts:
for cut in tqdm(cuts):
if cut.num_channels > 1:
logging.warning(
f"Skipping cut '{cut.id}'. It has {cut.num_channels} channels, "