Skip to content

Commit

Permalink
Merge pull request #335 from claritychallenge/evaluation-CAD1-CPC2
Browse files Browse the repository at this point in the history
Code for evaluation in cad1 cpc2
  • Loading branch information
jonbarker68 authored Jun 29, 2023
2 parents ebc7654 + e9aa196 commit a6ce4a7
Show file tree
Hide file tree
Showing 5 changed files with 330 additions and 1 deletion.
5 changes: 5 additions & 0 deletions recipes/cad1/task1/baseline/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ path:
listeners_train_file: ${path.metadata_dir}/listeners.train.json
listeners_valid_file: ${path.metadata_dir}/listeners.valid.json
exp_folder: ./exp_${separator.model} # folder to store enhanced signals and final results
music_test_file: ${path.metadata_dir}/musdb18.test.json
music_segments_test_file: ${path.metadata_dir}/musdb18.segments.test.json
listeners_test_file: ${path.metadata_dir}/listeners.test.json

team_id: E001

sample_rate: 44100 # sample rate of the input mixture
stem_sample_rate: 24000 # sample rate output stems
Expand Down
247 changes: 247 additions & 0 deletions recipes/cad1/task1/baseline/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
""" Run the baseline enhancement. """
from __future__ import annotations

# pylint: disable=import-error
import json
import logging
import shutil
from pathlib import Path

import hydra
import numpy as np
import pandas as pd
import torch
from omegaconf import DictConfig
from scipy.io import wavfile
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB

from clarity.enhancer.compressor import Compressor
from clarity.enhancer.nalr import NALR
from recipes.cad1.task1.baseline.enhance import (
decompose_signal,
get_device,
process_stems_for_listener,
remix_signal,
save_flac_signal,
)
from recipes.cad1.task1.baseline.evaluate import make_song_listener_list

# pylint: disable=too-many-locals

# Module-level logger, named after this module; used by every function below.
logger = logging.getLogger(__name__)


def pack_submission(
    team_id: str,
    root_dir: str | Path,
    base_dir: str | Path = ".",
) -> None:
    """
    Bundle the submission files into a single zip archive.

    The archive is named ``submission_<team_id>.zip`` and is written
    relative to the current working directory.

    Args:
        team_id (str): Team ID, used to build the archive name.
        root_dir (str | Path): Directory that becomes the root of the archive.
        base_dir (str | Path): Directory, relative to ``root_dir``, whose
            content is archived. Defaults to ".".
    """
    logger.info(f"Packing submission files for team {team_id}...")

    # make_archive appends the ".zip" extension to the base name itself
    archive_base_name = f"submission_{team_id}"
    shutil.make_archive(archive_base_name, "zip", root_dir, base_dir)


def _load_separation_model(model_name: str) -> tuple:
    """
    Load the pretrained source-separation model selected in the config.

    Args:
        model_name (str): Separator name, either "demucs" or "openunmix".

    Returns:
        tuple: ``(model, model_sample_rate, sources_order, normalise)`` in the
            form expected by ``decompose_signal``.

    Raises:
        ValueError: If ``model_name`` is not a supported separator.
    """
    if model_name == "demucs":
        separation_model = HDEMUCS_HIGH_MUSDB.get_model()
        return (
            separation_model,
            HDEMUCS_HIGH_MUSDB.sample_rate,
            separation_model.sources,
            True,
        )
    if model_name == "openunmix":
        separation_model = torch.hub.load(
            "sigsep/open-unmix-pytorch", "umxhq", niter=0
        )
        return (
            separation_model,
            separation_model.sample_rate,
            ["vocals", "drums", "bass", "other"],
            False,
        )
    raise ValueError(f"Separator model {model_name} not supported.")


def _segment_slice(segment: dict, sample_rate: int) -> slice:
    """
    Convert a segment's "start"/"end" values into a sample-index slice.

    Args:
        segment (dict): Mapping with "start" and "end" entries
            (presumably in seconds - TODO confirm against the metadata).
        sample_rate (int): Sample rate of the signal that will be sliced.

    Returns:
        slice: ``slice(int(start * sample_rate), int(end * sample_rate))``.
    """
    return slice(
        int(segment["start"] * sample_rate),
        int(segment["end"] * sample_rate),
    )


@hydra.main(config_path="", config_name="config")
def enhance(config: DictConfig) -> None:
    """
    Run the music enhancement on the evaluation (test) set.

    The system decomposes the music into vocal, drums, bass, and other stems.
    Then, the NAL-R prescription procedure is applied to each stem.
    For every (song, listener) pair of the selected batch, the processed
    stems and the remixed signal are saved as FLAC files under
    ``enhanced_signals/evaluation/<listener>/<song>/`` and, once all pairs
    are processed, the folder is packed into ``submission_<team_id>.zip``.

    Args:
        config (dict): Dictionary of configuration options for enhancing music.

    Raises:
        ValueError: If the separator model is not supported, or if a mixture
            file's sample rate differs from ``config.sample_rate``.
    """
    # Validates the model name before anything is written to disk.
    (
        separation_model,
        model_sample_rate,
        sources_order,
        normalise,
    ) = _load_separation_model(config.separator.model)

    enhanced_folder = Path("enhanced_signals") / "evaluation"
    enhanced_folder.mkdir(parents=True, exist_ok=True)

    device, _ = get_device(config.separator.device)
    separation_model.to(device)

    # Processing Test (evaluation) Set
    # Load listener audiograms and songs
    with open(config.path.listeners_test_file, encoding="utf-8") as file:
        listener_audiograms = json.load(file)

    with open(config.path.music_test_file, encoding="utf-8") as file:
        song_data = json.load(file)
    songs_details = pd.DataFrame.from_dict(song_data)

    with open(config.path.music_segments_test_file, encoding="utf-8") as file:
        songs_segments = json.load(file)

    song_listener_pairs = make_song_listener_list(
        songs_details["Track Name"], listener_audiograms
    )
    # Select a batch to process: every `batch_size`-th pair, starting at `batch`
    song_listener_pairs = song_listener_pairs[
        config.evaluate.batch :: config.evaluate.batch_size
    ]

    # Create hearing aid objects
    enhancer = NALR(**config.nalr)
    compressor = Compressor(**config.compressor)

    # Decompose each song into left and right vocal, drums, bass, and other stems
    # and process each stem for the listener
    prev_song_name = None
    num_song_list_pair = len(song_listener_pairs)
    for idx, (song_name, listener_name) in enumerate(song_listener_pairs, 1):
        logger.info(
            f"[{idx:03d}/{num_song_list_pair:03d}] "
            f"Processing {song_name} for {listener_name}..."
        )
        # Get the listener's audiogram
        listener_info = listener_audiograms[listener_name]

        critical_frequencies = np.array(listener_info["audiogram_cfs"])
        audiogram_left = np.array(listener_info["audiogram_levels_l"])
        audiogram_right = np.array(listener_info["audiogram_levels_r"])

        # Baseline Steps
        # 1. Decompose the mixture signal into vocal, drums, bass, and other stems
        #    We validate if 2 consecutive signals are the same to avoid
        #    decomposing the same song multiple times
        if prev_song_name != song_name:
            # Decompose song only once
            prev_song_name = song_name

            # Find the music split directory (only needed when reading a song)
            song_split = songs_details.loc[
                songs_details["Track Name"] == song_name, "Split"
            ].iloc[0]
            split_directory = "test" if song_split == "test" else "train"

            sample_rate, mixture_signal = wavfile.read(
                Path(config.path.music_dir)
                / split_directory
                / song_name
                / "mixture.wav"
            )
            # Scale to float32 and transpose.
            # NOTE(review): the 32768 divisor presumes 16-bit PCM input - confirm.
            mixture_signal = (mixture_signal / 32768.0).astype(np.float32).T

            # Explicit check instead of `assert`, which is stripped under -O
            if sample_rate != config.sample_rate:
                raise ValueError(
                    f"Sample rate of {song_name} is {sample_rate}, "
                    f"expected {config.sample_rate}."
                )

            # Decompose mixture signal into stems
            stems = decompose_signal(
                separation_model,
                model_sample_rate,
                mixture_signal,
                sample_rate,
                device,
                sources_order,
                audiogram_left,
                audiogram_right,
                normalise,
            )

        # 2. Apply NAL-R prescription to each stem
        #    Baseline applies NALR prescription to each stem instead of using the
        #    listener's audiograms in the decomposition. This step can be skipped
        #    if the listener's audiograms are used in the decomposition
        processed_stems = process_stems_for_listener(
            stems,
            enhancer,
            compressor,
            audiogram_left,
            audiogram_right,
            critical_frequencies,
            config.apply_compressor,
        )

        # Segment boundaries depend only on the song, so resolve them once
        # per pair rather than once per stem.
        song_segments = songs_segments[song_name]
        stems_slice = _segment_slice(
            song_segments["objective_evaluation"], config.sample_rate
        )
        remix_slice = _segment_slice(
            song_segments["subjective_evaluation"], config.sample_rate
        )

        # 3. Save processed stems
        for stem_str, stem_signal in processed_stems.items():
            filename = (
                enhanced_folder
                / f"{listener_name}"
                / f"{song_name}"
                / f"{listener_name}_{song_name}_{stem_str}.flac"
            )
            filename.parent.mkdir(parents=True, exist_ok=True)
            save_flac_signal(
                signal=stem_signal[stems_slice],
                filename=filename,
                signal_sample_rate=config.sample_rate,
                output_sample_rate=config.stem_sample_rate,
                do_scale_signal=True,
            )

        # 4. Remix Signal
        enhanced = remix_signal(processed_stems)

        # 5. Save enhanced (remixed) signal
        # NOTE(review): this path uses listener_info['name'] while the stems
        # use listener_name; they are presumably identical - confirm.
        filename = (
            enhanced_folder
            / f"{listener_info['name']}"
            / f"{song_name}"
            / f"{listener_info['name']}_{song_name}_remix.flac"
        )
        # Do not rely on the stem loop having created this directory
        filename.parent.mkdir(parents=True, exist_ok=True)
        save_flac_signal(
            signal=enhanced[remix_slice],
            filename=filename,
            signal_sample_rate=config.sample_rate,
            output_sample_rate=config.remix_sample_rate,
            do_clip_signal=True,
            do_soft_clip=config.soft_clip,
        )

    pack_submission(
        team_id=config.team_id,
        root_dir=enhanced_folder.parent,
        base_dir=enhanced_folder.name,
    )

    logger.info("Evaluation complete.!!")
    logger.info(
        f"Please, submit the file submission_{config.team_id}.zip to the challenge "
        "using the link provided. Thank you.!!"
    )


# Script entry point: hydra's decorator supplies the `config` argument,
# hence the pylint suppression for the argument-less call.
# pylint: disable = no-value-for-parameter
if __name__ == "__main__":
    enhance()
7 changes: 7 additions & 0 deletions recipes/cad1/task2/baseline/baseline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ def load_listeners_and_scenes(config: DictConfig) -> tuple[dict, dict, dict]:
listener_audiograms = json.load(fp)
scenes = df_scenes[df_scenes["split"] == "valid"].to_dict("index")

elif config.evaluate.split == "test":
with open(config.path.listeners_test_file, encoding="utf-8") as fp:
listener_audiograms = json.load(fp)
scenes = df_scenes[df_scenes["split"] == "test"].to_dict("index")
else:
raise ValueError(f"Unknown split {config.evaluate.split}")

with open(config.path.scenes_listeners_file, encoding="utf-8") as fp:
scenes_listeners = json.load(fp)
scenes_listeners = {
Expand Down
5 changes: 4 additions & 1 deletion recipes/cad1/task2/baseline/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ path:
scenes_listeners_file: ${path.metadata_dir}/scenes_listeners.json
hrtf_file: ${path.metadata_dir}/eBrird_BRIR.json
exp_folder: ./exp # folder to store enhanced signals and final results
listeners_test_file: ${path.metadata_dir}/listeners.test.json

team_id: E001

sample_rate: 44100 # sample rate of the input signal
enhanced_sample_rate: 32000 # sample rate for the enhanced output signal
Expand All @@ -35,7 +38,7 @@ evaluate:
set_random_seed: True
small_test: False
save_intermediate_wavs: False
split: valid # train, valid
split: test # train, valid, test
batch_size: 1 # Number of batches
batch: 0 # Batch number to evaluate

Expand Down
67 changes: 67 additions & 0 deletions recipes/cad1/task2/baseline/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
""" Run the dummy enhancement. """
# pylint: disable=too-many-locals
# pylint: disable=import-error
from __future__ import annotations

import logging
import shutil
from pathlib import Path

import hydra
from omegaconf import DictConfig

from recipes.cad1.task2.baseline.enhance import enhance as enhance_set

# Module-level logger, named after this module; used by every function below.
logger = logging.getLogger(__name__)


def pack_submission(
    team_id: str,
    root_dir: str | Path,
    base_dir: str | Path = ".",
) -> None:
    """
    Bundle the submission files into a single zip archive.

    The archive is named ``submission_<team_id>.zip`` and is written
    relative to the current working directory.

    Args:
        team_id (str): Team ID, used to build the archive name.
        root_dir (str | Path): Directory that becomes the root of the archive.
        base_dir (str | Path): Directory, relative to ``root_dir``, whose
            content is archived. Defaults to ".".
    """
    logger.info(f"Packing submission files for team {team_id}...")

    # make_archive appends the ".zip" extension to the base name itself
    archive_base_name = f"submission_{team_id}"
    shutil.make_archive(archive_base_name, "zip", root_dir, base_dir)


@hydra.main(config_path="", config_name="config")
def enhance(config: DictConfig) -> None:
    """
    Enhance the configured split and pack the result for submission.

    Runs the baseline enhancement imported from
    ``recipes.cad1.task2.baseline.enhance`` (described by the original
    author as a dummy processor that returns the input signal), then zips
    the ``enhanced_signals/<split>`` folder into ``submission_<team_id>.zip``.

    Args:
        config (dict): Dictionary of configuration options for enhancing music.
    """
    # Step 1: produce the enhanced signals
    enhance_set(config)

    # Step 2: archive the folder of the evaluated split
    signals_root = Path("enhanced_signals")
    split_folder = config.evaluate.split
    pack_submission(config.team_id, signals_root, split_folder)

    # Step 3: tell the user what to do with the archive
    submit_message = (
        f"Please, submit the file submission_{config.team_id}.zip to the challenge "
        "using the link provided. Thank you.!!"
    )
    logger.info("Evaluation complete.!!")
    logger.info(submit_message)


# Script entry point: hydra's decorator supplies the `config` argument,
# hence the pylint suppression for the argument-less call.
# pylint: disable = no-value-for-parameter
if __name__ == "__main__":
    enhance()

0 comments on commit a6ce4a7

Please sign in to comment.