Add a colab notebook demo, and a separate script. Support from_pretrained for easier use. #668

Merged (15 commits) on Jun 1, 2023
37 changes: 24 additions & 13 deletions egs/wsj0-mix-var/Multi-Decoder-DPRNN/README.md
@@ -3,28 +3,39 @@

**Abstract**: We propose an end-to-end trainable approach to single-channel speech separation with an unknown number of speakers, **only training a single model for an arbitrary number of speakers**. Our approach extends the MulCat source separation backbone with additional output heads: a count-head to infer the number of speakers, and decoder-heads for reconstructing the original signals. Beyond the model, we also propose a metric for evaluating source separation with a variable number of speakers. Specifically, we clear up the issue of how to evaluate quality when the ground truth has more or fewer speakers than the model predicts. We evaluate our approach on the WSJ0-mix datasets, with mixtures of up to five speakers. **We demonstrate that our approach outperforms the state-of-the-art in counting the number of speakers and remains competitive in the quality of the reconstructed signals.**

paper arxiv link: https://arxiv.org/abs/2011.12022
paper link: http://www.isle.illinois.edu/speech_web_lg/pubs/2021/zhu2021multi.pdf
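
A rough sketch of this count-head / decoder-heads structure in PyTorch is shown below. This is only a hypothetical illustration of the idea described in the abstract, not the actual MultiDecoderDPRNN implementation; the GRU backbone, layer sizes, and pooling are placeholder assumptions (the real model uses the MulCat/DPRNN separator defined in model.py).

```python
import torch
import torch.nn as nn


class MultiHeadSeparatorSketch(nn.Module):
    """Toy sketch: shared backbone, a count-head, and one decoder-head per speaker count."""

    def __init__(self, feat_dim=64, max_spks=5):
        super().__init__()
        # Placeholder backbone standing in for the MulCat/DPRNN separator.
        self.backbone = nn.GRU(1, feat_dim, batch_first=True)
        # Count-head: classifies the number of speakers (1 .. max_spks).
        self.count_head = nn.Linear(feat_dim, max_spks)
        # Decoder-heads: head k reconstructs k+1 sources from the shared features.
        self.decoder_heads = nn.ModuleList([nn.Linear(feat_dim, k + 1) for k in range(max_spks)])

    def forward(self, mixture):  # mixture: (batch, T) waveform
        feats, _ = self.backbone(mixture.unsqueeze(-1))    # (batch, T, feat_dim)
        count_logits = self.count_head(feats.mean(dim=1))  # (batch, max_spks)
        n_spk = int(count_logits.argmax(dim=-1)[0]) + 1    # estimated speaker count (first item)
        sources = self.decoder_heads[n_spk - 1](feats)     # (batch, T, n_spk)
        return sources.transpose(1, 2), count_logits       # (batch, n_spk, T), count logits


# Quick shape check on random audio.
sketch = MultiHeadSeparatorSketch()
sources, count_logits = sketch(torch.randn(2, 8000))
print(sources.shape, count_logits.shape)
```

At inference time the count-head's argmax picks which decoder-head to use; the model.py docstring suggests an oracle speaker count can also be supplied to select the head.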

## Project Page & Demo
## Project Page & Examples
Project page & example output can be found [here](https://junzhejosephzhu.github.io/Multi-Decoder-DPRNN/)

## Getting Started
Install asteroid by running ```pip install -e .``` in the asteroid directory
To install the requirements, run ```pip install -r requirements.txt```

To run a pre-trained model on your own .wav mixture files, run ```python eval.py --wav_file {file_name.wav} --use_gpu {1/0}```. The script will automatically download a pre-trained model (link below).

You can use wildcard patterns for file names. For example, you can run ```python eval.py --wav_file local/*.wav --use_gpu 0```

The default output directory is ./output, but you can override it with the ```--output_dir``` option

If you want to use an alternative pre-trained model, create a folder, save the pretrained model as ```{folder_name}/checkpoints/best-model.ckpt```, then run ```python eval.py --wav_file {file_name.wav} --use_gpu {1/0} --exp_dir {folder_name}```
### Colab notebooks:
* Usage Example: [![Usage Example](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/11MGx3_sgOrQrB6k8edyAvg5mGIxqR5ED?usp=sharing)
### Run locally
To set up, run the following commands:
```
git clone https://github.com/asteroid-team/asteroid.git
cd asteroid/egs/wsj0-mix-var/Multi-Decoder-DPRNN
pip install -r requirements.txt
```
To run separation on a wav file, run:
```
python separate.py --wav_file ${mixture_file}
```
To load the model, run:
```
from model import MultiDecoderDPRNN
model = MultiDecoderDPRNN.from_pretrained("JunzheJosephZhu/MultiDecoderDPRNN").eval()
model.separate(input_tensor)
```
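
As a concrete illustration, `input_tensor` can be a mixture waveform loaded with torchaudio, mirroring what the separate.py script added in this PR does. A minimal sketch, assuming torchaudio is installed and `mixture.wav` (a placeholder name) is a mono mixture at the sample rate the model expects:

```python
import torch
import torchaudio
from model import MultiDecoderDPRNN

# Load a mixture as a (1, T) tensor; "mixture.wav" is a placeholder file name.
mixture, sample_rate = torchaudio.load("mixture.wav")

model = MultiDecoderDPRNN.from_pretrained("JunzheJosephZhu/MultiDecoderDPRNN").eval()
if torch.cuda.is_available():
    model.cuda()
    mixture = mixture.cuda()

# separate() returns one estimated waveform per detected speaker.
sources = model.separate(mixture).cpu()
for i, source in enumerate(sources):
    torchaudio.save(f"speaker_{i}.wav", source[None], sample_rate)
```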

## Train your own model
To train the model, edit the file paths in run.sh and execute ```./run.sh --stage 0```, then follow the instructions to generate the dataset and train the model.

After training the model, execute ```./run.sh --stage 4``` to evaluate it. Some examples will be saved in exp/tmp_uuid/examples.

Alternatively, the training and evaluation scripts can be found in train.py and eval.py.

## Kindly cite this paper
```
@INPROCEEDINGS{9414205,
  ...
}
```

## Resources
Pretrained mini model and config can be found at: https://huggingface.co/JunzheJosephZhu/MultiDecoderDPRNN

This is the refactored version of the code, with some hyperparameter changes. If you want to reproduce the paper results, the original experiment code & config can be found at https://github.com/JunzheJosephZhu/MultiDecoder-DPRNN

33 changes: 4 additions & 29 deletions egs/wsj0-mix-var/Multi-Decoder-DPRNN/eval.py
@@ -2,10 +2,10 @@
Author: Joseph(Junzhe) Zhu, 2021/5. Email: josefzhu@stanford.edu / junzhe.joseph.zhu@gmail.com
For the original code for the paper[1], please refer to https://github.com/JunzheJosephZhu/MultiDecoder-DPRNN
Demo Page: https://junzhejosephzhu.github.io/Multi-Decoder-DPRNN/
Multi-Decoder DPRNN is a method for source separation when the number of speakers is unknown.
Our contribution is using multiple output heads, with each head modelling a distinct number of source outputs.
In addition, we design a selector network which determines which output head to use, i.e. estimates the number of sources.
The "DPRNN" part of the architecture is orthogonal to our contribution, and can be replaced with any other separator, e.g. Conv/LSTM-TasNet.
References:
[1] "Multi-Decoder DPRNN: High Accuracy Source Counting and Separation",
Junzhe Zhu, Raymond Yeh, Mark Hasegawa-Johnson. https://arxiv.org/abs/2011.12022
@@ -38,12 +38,6 @@
    type=str,
    help="One of `enh_single`, `enh_both`, " "`sep_clean` or `sep_noisy`",
)
parser.add_argument(
    "--wav_file",
    type=str,
    default="",
    help="Path to the wav file to run model inference on. Could be a regular expression of {folder_name}/*.wav",
)
parser.add_argument(
    "--output_dir", type=str, default="output", help="Output folder for inference results"
)
@@ -85,25 +79,6 @@ def main(conf):
    test_dirs = [
        conf["test_dir"].format(n_src) for n_src in conf["train_conf"]["masknet"]["n_srcs"]
    ]
    if conf["wav_file"]:
        mix_files = glob.glob(conf["wav_file"])
        if not os.path.exists(conf["output_dir"]):
            os.makedirs(conf["output_dir"])
        for mix_file in mix_files:
            mix, _ = librosa.load(mix_file, sr=conf["sample_rate"])
            mix = tensors_to_device(torch.Tensor(mix), device=model_device)
            est_sources = model.separate(mix[None])
            est_sources = est_sources.cpu().numpy()
            for i, est_src in enumerate(est_sources):
                sf.write(
                    os.path.join(
                        conf["output_dir"],
                        os.path.basename(mix_file).replace(".wav", f"_spkr{i}.wav"),
                    ),
                    est_src,
                    conf["sample_rate"],
                )

    # evaluate metrics
    if conf["test_dir"]:
        test_set = Wsj0mixVariable(
10 changes: 5 additions & 5 deletions egs/wsj0-mix-var/Multi-Decoder-DPRNN/model.py
@@ -2,10 +2,10 @@
Author: Joseph(Junzhe) Zhu, 2021/5. Email: josefzhu@stanford.edu / junzhe.joseph.zhu@gmail.com
For the original code for the paper[1], please refer to https://github.com/JunzheJosephZhu/MultiDecoder-DPRNN
Demo Page: https://junzhejosephzhu.github.io/Multi-Decoder-DPRNN/
Multi-Decoder DPRNN is a method for source separation when the number of speakers is unknown.
Our contribution is using multiple output heads, with each head modelling a distinct number of source outputs.
In addition, we design a selector network which determines which output head to use, i.e. estimates the number of sources.
The "DPRNN" part of the architecture is orthogonal to our contribution, and can be replaced with any other separator, e.g. Conv/LSTM-TasNet.
References:
[1] "Multi-Decoder DPRNN: High Accuracy Source Counting and Separation",
Junzhe Zhu, Raymond Yeh, Mark Hasegawa-Johnson. https://arxiv.org/abs/2011.12022
@@ -138,7 +138,7 @@ def __init__(
"""
Args:
wav: 2D or 3D Tensor, Tensor of shape $(batch, T)$
ground_truth: oracle number of speakers, None or list of $(batch)$ ints
Return:
reconstructed: torch.Tensor, $(batch, num_stages, max_spks, T)$
where max_spks is the maximum possible number of speakers.
28 changes: 28 additions & 0 deletions egs/wsj0-mix-var/Multi-Decoder-DPRNN/separate.py
@@ -0,0 +1,28 @@
import torch, torchaudio
import argparse
import os
from model import MultiDecoderDPRNN

os.makedirs("outputs", exist_ok=True)
parser = argparse.ArgumentParser()
parser.add_argument(
    "--wav_file",
    type=str,
    default="",
    help="Path to the wav file to run model inference on.",
)
args = parser.parse_args()

mixture, sample_rate = torchaudio.load(args.wav_file)

model = MultiDecoderDPRNN.from_pretrained("JunzheJosephZhu/MultiDecoderDPRNN").eval()
if torch.cuda.is_available():
    model.cuda()
    mixture = mixture.cuda()
sources_est = model.separate(mixture).cpu()
for i, source in enumerate(sources_est):
    torchaudio.save(f"outputs/{i}.wav", source[None], sample_rate)

print(
    "Thank you for using Multi-Decoder-DPRNN to separate your mixture files. Please support our work by citing our paper: http://www.isle.illinois.edu/speech_web_lg/pubs/2021/zhu2021multi.pdf"
)

Collaborator comment on the `print` call above:

This line looks too long.
Can you cut the print in half.

Otherwise, the rest is good.
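
One possible way to address the comment, splitting the message across adjacent string literals so each source line stays short (a sketch of one option, not the committed fix):

```python
# Adjacent string literals are concatenated by Python, so the printed message is unchanged.
print(
    "Thank you for using Multi-Decoder-DPRNN to separate your mixture files. "
    "Please support our work by citing our paper: "
    "http://www.isle.illinois.edu/speech_web_lg/pubs/2021/zhu2021multi.pdf"
)
```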