diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 53046f142..9b8dfca58 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -17,7 +17,7 @@ jobs: strategy: matrix: os: ["ubuntu-latest"] - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v3 - name: Set up Python diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 615f38fc0..39d4249f1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 # Use the ref you want to point at + rev: v5.0.0 # Use the ref you want to point at hooks: - id: trailing-whitespace types: [file, text] @@ -20,25 +20,25 @@ repos: - id: check-toml - repo: https://github.com/asottile/pyupgrade - rev: v3.16.0 + rev: v3.18.0 hooks: - id: pyupgrade - args: [--py38-plus] + args: [--py39-plus] - repo: https://github.com/psf/black - rev: 24.4.2 + rev: 24.10.0 hooks: - id: black types: [python] additional_dependencies: ["click==8.0.4"] - repo: https://github.com/DavidAnson/markdownlint-cli2 - rev: v0.13.0 + rev: v0.14.0 hooks: - id: markdownlint-cli2 - repo: https://github.com/pycqa/flake8.git - rev: 7.1.0 + rev: 7.1.1 hooks: - id: flake8 additional_dependencies: [flake8-print, Flake8-pyproject] @@ -52,7 +52,7 @@ repos: args: ["--profile", "black"] - repo: https://github.com/nbQA-dev/nbQA - rev: 1.8.5 + rev: 1.8.7 hooks: - id: nbqa-black - id: nbqa-flake8 @@ -65,24 +65,23 @@ repos: - id: nbstripout - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.10.1 + rev: v1.11.2 hooks: - id: mypy args: - --explicit-package-bases additional_dependencies: - - 'types-PyYAML' - + - "types-PyYAML" - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: "v0.5.1" + rev: "v0.6.9" hooks: - id: ruff # Serious pylint errors that will be enforced by CI - repo: https://github.com/pycqa/pylint - rev: v3.2.5 + rev: v3.3.1 hooks: - id: pylint args: diff --git a/README.md b/README.md index 162780394..3a466fd66 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/claritychallenge/clarity/main.svg)](https://results.pre-commit.ci/latest/github/claritychallenge/clarity/main) [![Downloads](https://pepy.tech/badge/pyclarity)](https://pepy.tech/project/pyclarity) +[![PyPI](https://img.shields.io/static/v1?label=CAD2%20Challenge%20-%20pypi&message=v0.6.0&color=orange)](https://pypi.org/project/pyclarity/0.6.0/) [![PyPI](https://img.shields.io/static/v1?label=CEC3%20Challenge%20-%20pypi&message=v0.5.0&color=orange)](https://pypi.org/project/pyclarity/0.5.0/) [![PyPI](https://img.shields.io/static/v1?label=ICASSP%202024%20Cadenza%20Challenge%20-%20pypi&message=v0.4.1&color=orange)](https://pypi.org/project/pyclarity/0.4.1/) [![PyPI](https://img.shields.io/static/v1?label=CAD1%20and%20CPC2%20Challenges%20-%20pypi&message=v0.3.4&color=orange)](https://pypi.org/project/pyclarity/0.3.4/) @@ -37,14 +38,12 @@ In this repository, you will find code to support all Clarity and Cadenza Challe ## Current Events +- The 2nd Cadenza Challenge is now open :fire::fire: + - Visit the [cadenza website](https://cadenzachallenge.org/docs/cadenza2/intro) for more details. + - Join the [Cadenza Challenge Group](https://groups.google.com/g/cadenza-challenge) to keep up-to-date on developments. - The 3rd Clarity Enhancement Challenge is now open. :fire::fire: - Visit the [challenge website](https://claritychallenge.org/docs/cec3/cec3_intro) for more details. - Join the [Clarity Challenge Group](https://groups.google.com/g/clarity-challenge) to keep up-to-date on developments. -- The ICASSP 2024 Cadenza Challenge (CAD_ICASSP_2024) will be presented at ICASSP 2024. - - Join the [Cadenza Challenge Group](https://groups.google.com/g/cadenza-challenge) to keep up-to-date on developments. - - Visit the Cadenenza Challenge [website](https://cadenzachallenge.org/) for more details. -- The first Cadenza Challenge (CAD1) is closed. - - Subjective Evaluation is underway. :new: - The 2nd Clarity Prediction Challenge (CPC2) is now closed. - The 4th Clarity Workshop will be held as a satellite event of Interspeech 2023. For details visit the [workshop website](https://claritychallenge.org/clarity2023-workshop/). @@ -90,6 +89,7 @@ pip install -e git+https://github.com/claritychallenge/clarity.git@main Current challenge +- [The 2nd Cadenza Challege](./recipes/cad2) - [The 3rd Clarity Enhancement Challenge](./recipes/cec3) Previous challenges diff --git a/clarity/evaluator/msbg/cochlea.py b/clarity/evaluator/msbg/cochlea.py index 5d47bfbbc..10ff5cf67 100644 --- a/clarity/evaluator/msbg/cochlea.py +++ b/clarity/evaluator/msbg/cochlea.py @@ -224,7 +224,11 @@ class Cochlea: """ def __init__( - self, audiogram: Audiogram, catch_up_level: float = 105.0, fs: float = 44100.0 + self, + audiogram: Audiogram, + catch_up_level: float = 105.0, + fs: float = 44100.0, + verbose=True, ) -> None: """Cochlea constructor. @@ -233,6 +237,7 @@ def __init__( catch_up_level (float, optional): loudness catch-up level in dB Default is 105 dB fs (float, optional): sampling frequency + verbose (bool, optional): verbose mode. Default is True """ self.fs = fs @@ -254,7 +259,8 @@ def __init__( r_lower, r_upper = HL_PARAMS[severity_level]["smear_params"] self.smearer = Smearer(r_lower, r_upper, fs) - logging.info("Severity level - %s", severity_level) + if verbose: + logging.info("Severity level - %s", severity_level) def simulate(self, coch_sig: ndarray, equiv_0dB_file_SPL: float) -> ndarray: """Pass a signal through the cochlea. diff --git a/clarity/evaluator/msbg/msbg.py b/clarity/evaluator/msbg/msbg.py index 204144f09..fabd943fe 100644 --- a/clarity/evaluator/msbg/msbg.py +++ b/clarity/evaluator/msbg/msbg.py @@ -40,6 +40,7 @@ def __init__( sample_rate: float = 44100.0, equiv_0db_spl: float = 100.0, ahr: float = 20.0, + verbose: bool = True, ) -> None: """ Constructor for the Ear class. @@ -48,7 +49,9 @@ def __init__( sample_rate (float): sample frequency. equiv_0db_spl (): ??? ahr (): ??? + verbose (): ??? """ + self.verbose = verbose self.sample_rate = sample_rate self.src_correction = self.get_src_correction(src_pos) self.equiv_0db_spl = equiv_0db_spl @@ -62,7 +65,7 @@ def set_audiogram(self, audiogram: Audiogram) -> None: "Impairment too severe: Suggest you limit audiogram max to" "80-90 dB HL, otherwise things go wrong/unrealistic." ) - self.cochlea = Cochlea(audiogram=audiogram) + self.cochlea = Cochlea(audiogram=audiogram, verbose=self.verbose) @staticmethod def get_src_correction(src_pos: str) -> ndarray: @@ -92,6 +95,7 @@ def src_to_cochlea_filt( src_correction: ndarray, sample_rate: float, backward: bool = False, + verbose: bool = True, ) -> ndarray: """Simulate middle and outer ear transfer functions. @@ -109,12 +113,14 @@ def src_to_cochlea_filt( or ITU sample_rate (int): sampling frequency backward (bool, optional): if true then cochlea to src (default: False) + verbose (bool, optional): print verbose output (default: True) Returns: np.ndarray: the processed signal """ - logging.info("performing outer/middle ear corrections") + if verbose: + logging.info("performing outer/middle ear corrections") # make sure that response goes only up to sample_frequency/2 nyquist = int(sample_rate / 2.0) @@ -204,7 +210,8 @@ def process(self, signal: ndarray, add_calibration: bool = False) -> list[ndarra ) raise ValueError("Invalid sampling frequency, valid value is 44100") - logging.info("Processing {len(chans)} samples") + if self.verbose: + logging.info("Processing {len(chans)} samples") # Need to know file RMS, and then call that a certain level in SPL: # needs some form of pre-measuring. @@ -219,7 +226,7 @@ def process(self, signal: ndarray, add_calibration: bool = False) -> list[ndarra # Measure RMS where 3rd arg is dB_rel_rms (how far below) calculated_rms, idx, _rel_db_thresh, _active = measure_rms( - signal[0], sample_rate, -12 + signal[0], sample_rate, -12, verbose=self.verbose ) # Rescale input data and check level after rescaling @@ -229,11 +236,11 @@ def process(self, signal: ndarray, add_calibration: bool = False) -> list[ndarra new_rms_db = equiv_0db_spl + 10 * np.log10( np.mean(np.power(signal[0][idx], 2.0)) ) - logging.info( - "Rescaling: " - f"leveldBSPL was {level_db_spl:3.1f} dB SPL, now {new_rms_db:3.1f} dB SPL. " - f" Target SPL is {target_spl:3.1f} dB SPL." - ) + if self.verbose: + logging.info( + f"Rescaling: leveldBSPL was {level_db_spl:3.1f} dB SPL, now" + f" {new_rms_db:3.1f} dB SPL. Target SPL is {target_spl:3.1f} dB SPL." + ) # Add calibration signal at target SPL dB if add_calibration is True: @@ -247,11 +254,17 @@ def process(self, signal: ndarray, add_calibration: bool = False) -> list[ndarra signal = np.concatenate((pre_calibration, signal, post_calibration), axis=1) # Transform from src pos to cochlea, simulate cochlea, transform back to src pos - signal = Ear.src_to_cochlea_filt(signal, self.src_correction, sample_rate) + signal = Ear.src_to_cochlea_filt( + signal, self.src_correction, sample_rate, verbose=self.verbose + ) if self.cochlea is not None: signal = np.array([self.cochlea.simulate(x, equiv_0db_spl) for x in signal]) signal = Ear.src_to_cochlea_filt( - signal, self.src_correction, sample_rate, backward=True + signal, + self.src_correction, + sample_rate, + backward=True, + verbose=self.verbose, ) # Implement low-pass filter at top end of audio range: flat to Cutoff freq, diff --git a/clarity/evaluator/msbg/msbg_utils.py b/clarity/evaluator/msbg/msbg_utils.py index 935d920a5..0787133aa 100644 --- a/clarity/evaluator/msbg/msbg_utils.py +++ b/clarity/evaluator/msbg/msbg_utils.py @@ -10,8 +10,8 @@ import numpy as np import scipy -import scipy.signal from numpy import ndarray +from scipy.signal.windows import hamming, kaiser # measure rms parameters WIN_SECS: Final = 0.01 @@ -167,7 +167,7 @@ def firwin2( order = n_taps - 1 if window_type == "kaiser": - window_shape = scipy.signal.kaiser(n_taps, window_param) + window_shape = kaiser(n_taps, window_param) if window_shape is None: filter_coef, _ = fir2(order, frequencies, filter_gains) @@ -203,7 +203,7 @@ def fir2( filter_length += 1 if window_shape is None: - window_shape = scipy.signal.hamming(filter_length) + window_shape = hamming(filter_length) n_interpolate = ( 2 ** np.ceil(math.log(filter_length) / math.log(2.0)) @@ -358,6 +358,7 @@ def generate_key_percent( threshold_db: float, window_length: int, percent_to_track: float | None = None, + verbose: bool = False, ) -> tuple[ndarray, float]: """Generate key percent. Locates frames above some energy threshold or tracks a certain percentage @@ -370,6 +371,7 @@ def generate_key_percent( window_length (int): length of window in samples. percent_to_track (float, optional): Track a percentage of frames. Default is None + verbose (bool, optional): Print verbose output. Default is False. Raises: ValueError: percent_to_track is set too high. @@ -393,10 +395,11 @@ def generate_key_percent( expected = threshold_db # new Dec 2003. Possibly track percentage of frames rather than fixed threshold - if percent_to_track is not None: - logging.info("tracking %s percentage of frames", percent_to_track) - else: - logging.info("tracking fixed threshold") + if verbose: + if percent_to_track is not None: + logging.info("tracking %s percentage of frames", percent_to_track) + else: + logging.info("tracking fixed threshold") # put floor into histogram distribution non_zero = np.power(10, (expected - 30) / 10) @@ -466,6 +469,7 @@ def measure_rms( sample_rate: float, db_rel_rms: float, percent_to_track: float | None = None, + verbose: bool = False, ) -> tuple[float, ndarray, float, float]: """Measure Root Mean Square. @@ -481,6 +485,7 @@ def measure_rms( db_rel_rms (float): threshold for frames to track. percent_to_track (float, optional): track percentage of frames, rather than threshold (default: {None}) + verbose (bool, optional): Print verbose output. Default is False. Returns: (tuple): tuple containing - rms (float): overall calculated rms (linear) @@ -500,6 +505,7 @@ def measure_rms( key_thr_db, round(WIN_SECS * sample_rate), percent_to_track=percent_to_track, + verbose=verbose, ) idx = key.astype(int) # move into generate_key_percent diff --git a/docs/conf.py b/docs/conf.py index 34def907a..f1e365e2b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -7,7 +7,6 @@ import os import sys from importlib.metadata import version -from typing import Dict # -*- coding: utf-8 -*- # @@ -157,7 +156,7 @@ # -- Options for LaTeX output ------------------------------------------------ -latex_elements: Dict[str, str] = { +latex_elements: dict[str, str] = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', diff --git a/notebooks/01_Installing_clarity_tools_and_using_metadata.ipynb b/notebooks/01_Installing_clarity_tools_and_using_metadata.ipynb index 3edb3b2b0..1a3445860 100644 --- a/notebooks/01_Installing_clarity_tools_and_using_metadata.ipynb +++ b/notebooks/01_Installing_clarity_tools_and_using_metadata.ipynb @@ -63,12 +63,13 @@ }, "outputs": [], "source": [ + "import os\n", + "import sys\n", + "\n", "print(\"Changing directory...\")\n", "%cd clarity\n", "print(\"Installing Clarity tools\")\n", "%pip install -e .\n", - "import os\n", - "import sys\n", "\n", "sys.path.append(os.getcwd())\n", "print(\"Moving back to project root directory\")\n", @@ -414,11 +415,13 @@ "\n", "\n", "print(\n", - " f'\\nScene number {scene_no} (ID {scene[\"scene\"]}) has room dimensions of {room[\"dimensions\"]}'\n", + " f\"\\nScene number {scene_no} \"\n", + " f'(ID {scene[\"scene\"]}) has room dimensions of {room[\"dimensions\"]}'\n", ")\n", "\n", "print(\n", - " f'\\nSimulated listeners for scene {scene_no} have spatial attributes: \\n{room[\"listener\"]}'\n", + " f\"\\nSimulated listeners for scene {scene_no} \"\n", + " f'have spatial attributes: \\n{room[\"listener\"]}'\n", ")\n", "\n", "print(f'\\nAudiograms for listeners in Scene ID {scene[\"scene\"]}')\n", @@ -427,8 +430,8 @@ "fig, ax = plt.subplots(1, len(current_listeners))\n", "\n", "ax[0].set_ylabel(\"Hearing level (dB)\")\n", - "for i, l in enumerate(current_listeners):\n", - " listener_data = listeners[l]\n", + "for i, curr_listener in enumerate(current_listeners):\n", + " listener_data = listeners[curr_listener]\n", " (left_ag,) = ax[i].plot(\n", " listener_data[\"audiogram_cfs\"],\n", " -np.array(listener_data[\"audiogram_levels_l\"]),\n", @@ -439,7 +442,7 @@ " -np.array(listener_data[\"audiogram_levels_r\"]),\n", " label=\"right audiogram\",\n", " )\n", - " ax[i].set_title(f\"Listener {l}\")\n", + " ax[i].set_title(f\"Listener {curr_listener}\")\n", " ax[i].set_xlabel(\"Hz\")\n", " ax[i].set_ylim([-100, 10])\n", "\n", diff --git a/notebooks/02_Running_the_CEC2_baseline_from_commandline.ipynb b/notebooks/02_Running_the_CEC2_baseline_from_commandline.ipynb index 11d9a79ab..7003fd5e6 100644 --- a/notebooks/02_Running_the_CEC2_baseline_from_commandline.ipynb +++ b/notebooks/02_Running_the_CEC2_baseline_from_commandline.ipynb @@ -84,15 +84,17 @@ }, "outputs": [], "source": [ + "import os\n", + "import sys\n", + "\n", + "from IPython.display import clear_output\n", + "\n", "print(\"Cloning git repo...\")\n", "!git clone --quiet https://github.com/claritychallenge/clarity.git\n", "%cd clarity\n", "%pip install -e .\n", - "import os\n", - "import sys\n", "\n", "sys.path.append(f'{os.getenv(\"NBOOKROOT\")}/clarity')\n", - "from IPython.display import clear_output\n", "\n", "clear_output()\n", "print(\"Repository installed\")" diff --git a/notebooks/03_Running_the_CEC2_baseline_from_python.ipynb b/notebooks/03_Running_the_CEC2_baseline_from_python.ipynb index 7cbab869a..8bac688c5 100644 --- a/notebooks/03_Running_the_CEC2_baseline_from_python.ipynb +++ b/notebooks/03_Running_the_CEC2_baseline_from_python.ipynb @@ -48,13 +48,14 @@ }, "outputs": [], "source": [ + "import os\n", + "import sys\n", + "\n", "print(\"Cloning git repo...\")\n", "\n", "!git clone --quiet https://github.com/claritychallenge/clarity.git\n", "%cd clarity\n", "%pip install -e .\n", - "import os\n", - "import sys\n", "\n", "sys.path.append(os.getcwd())\n", "%cd .." diff --git a/pyproject.toml b/pyproject.toml index 1d24c0dcb..0fb977630 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ "Natural Language :: English", ] keywords = ["hearing", "signal processing", "clarity challenge"] -requires-python = ">=3.8" +requires-python = ">=3.9" dependencies = [ "audioread>=2.1.9", "gdown", @@ -30,10 +30,10 @@ dependencies = [ "importlib-metadata", "librosa>=0.8.1", "matplotlib", - "numba>=0.57.0rc", - "numpy>=1.21.6", + "numba>=0.60", + "numpy>=2", "omegaconf>=2.1.1", - "pandas>=1.3.5", + "pandas>=2.2.2", "pyflac", "pyloudnorm>=0.1.0", "pystoi", @@ -41,9 +41,9 @@ dependencies = [ "resampy", "safetensors>=0.4.3", "scikit-learn>=1.0.2", - "scipy>=1.7.3, <1.13.0", + "scipy>=1.7.3", "SoundFile>=0.10.3.post1", - "soxr", + "soxr>=0.4", "torch>=2", "torchaudio", "tqdm>=4.62.3", diff --git a/recipes/cad1/task1/baseline/config.yaml b/recipes/cad1/task1/baseline/config.yaml index bfac09086..ae51a4099 100644 --- a/recipes/cad1/task1/baseline/config.yaml +++ b/recipes/cad1/task1/baseline/config.yaml @@ -40,4 +40,6 @@ evaluate: # hydra config hydra: run: - dir: ${path.exp_folder} \ No newline at end of file + dir: ${path.exp_folder} + job: + chdir: True \ No newline at end of file diff --git a/recipes/cad1/task1/baseline/enhance.py b/recipes/cad1/task1/baseline/enhance.py index 4b48535f2..fc389054c 100644 --- a/recipes/cad1/task1/baseline/enhance.py +++ b/recipes/cad1/task1/baseline/enhance.py @@ -347,7 +347,7 @@ def save_flac_signal( FlacEncoder().encode(signal, output_sample_rate, filename) -@hydra.main(config_path="", config_name="config") +@hydra.main(config_path="", config_name="config", version_base=None) def enhance(config: DictConfig) -> None: """ Run the music enhancement. diff --git a/recipes/cad1/task1/baseline/evaluate.py b/recipes/cad1/task1/baseline/evaluate.py index 9994ff1d2..b985a9b9e 100644 --- a/recipes/cad1/task1/baseline/evaluate.py +++ b/recipes/cad1/task1/baseline/evaluate.py @@ -152,7 +152,7 @@ def _evaluate_song_listener( return float(combined_score), per_instrument_score -@hydra.main(config_path="", config_name="config") +@hydra.main(config_path="", config_name="config", version_base=None) def run_calculate_aq(config: DictConfig) -> None: """Evaluate the enhanced signals using the HAAQI-RMS metric.""" # Load test songs diff --git a/recipes/cad1/task1/baseline/merge_batches_results.py b/recipes/cad1/task1/baseline/merge_batches_results.py index 451510065..01d3e0a2f 100644 --- a/recipes/cad1/task1/baseline/merge_batches_results.py +++ b/recipes/cad1/task1/baseline/merge_batches_results.py @@ -5,7 +5,7 @@ from omegaconf import DictConfig -@hydra.main(config_path="", config_name="config") +@hydra.main(config_path="", config_name="config", version_base=None) def join_batches(config: DictConfig) -> None: """ Join batches scores into a single file. diff --git a/recipes/cad1/task1/baseline/test.py b/recipes/cad1/task1/baseline/test.py index fe61ebbeb..f976f28e7 100644 --- a/recipes/cad1/task1/baseline/test.py +++ b/recipes/cad1/task1/baseline/test.py @@ -57,7 +57,7 @@ def pack_submission( ) -@hydra.main(config_path="", config_name="config") +@hydra.main(config_path="", config_name="config", version_base=None) def enhance(config: DictConfig) -> None: """ Run the music enhancement. diff --git a/recipes/cad1/task2/baseline/config.yaml b/recipes/cad1/task2/baseline/config.yaml index ff84fe766..720c4d816 100644 --- a/recipes/cad1/task2/baseline/config.yaml +++ b/recipes/cad1/task2/baseline/config.yaml @@ -43,4 +43,6 @@ evaluate: # hydra config hydra: run: - dir: ${path.exp_folder} \ No newline at end of file + dir: ${path.exp_folder} + job: + chdir: True \ No newline at end of file diff --git a/recipes/cad1/task2/baseline/enhance.py b/recipes/cad1/task2/baseline/enhance.py index 263d472f3..d829c1b9b 100644 --- a/recipes/cad1/task2/baseline/enhance.py +++ b/recipes/cad1/task2/baseline/enhance.py @@ -92,7 +92,7 @@ def enhance_song( return out_left, out_right -@hydra.main(config_path="", config_name="config") +@hydra.main(config_path="", config_name="config", version_base=None) def enhance(config: DictConfig) -> None: """ Run the music enhancement. diff --git a/recipes/cad1/task2/baseline/evaluate.py b/recipes/cad1/task2/baseline/evaluate.py index 37af6f484..eb30fa20c 100644 --- a/recipes/cad1/task2/baseline/evaluate.py +++ b/recipes/cad1/task2/baseline/evaluate.py @@ -136,7 +136,7 @@ def evaluate_scene( return aq_score_l, aq_score_r -@hydra.main(config_path="", config_name="config") +@hydra.main(config_path="", config_name="config", version_base=None) def run_calculate_audio_quality(config: DictConfig) -> None: """Evaluate the enhanced signals using the HAAQI metric.""" diff --git a/recipes/cad1/task2/baseline/merge_batches_results.py b/recipes/cad1/task2/baseline/merge_batches_results.py index e5f2e58c3..806a8d71a 100644 --- a/recipes/cad1/task2/baseline/merge_batches_results.py +++ b/recipes/cad1/task2/baseline/merge_batches_results.py @@ -5,7 +5,7 @@ from omegaconf import DictConfig -@hydra.main(config_path="", config_name="config") +@hydra.main(config_path="", config_name="config", version_base=None) def join_batches(config: DictConfig) -> None: """ Join batches scores into a single file. diff --git a/recipes/cad1/task2/baseline/test.py b/recipes/cad1/task2/baseline/test.py index d7be33ebe..bfcb425cf 100644 --- a/recipes/cad1/task2/baseline/test.py +++ b/recipes/cad1/task2/baseline/test.py @@ -39,7 +39,7 @@ def pack_submission( ) -@hydra.main(config_path="", config_name="config") +@hydra.main(config_path="", config_name="config", version_base=None) def enhance(config: DictConfig) -> None: """ Run the music enhancement. diff --git a/recipes/cad1/task2/data_preparation/build_scene_metadata.py b/recipes/cad1/task2/data_preparation/build_scene_metadata.py index a3ec173fd..4221829e8 100644 --- a/recipes/cad1/task2/data_preparation/build_scene_metadata.py +++ b/recipes/cad1/task2/data_preparation/build_scene_metadata.py @@ -83,7 +83,7 @@ def get_random_snr(min_snr, max_snr, round_to=4) -> float: return float(np.round(np.random.uniform(min_snr, max_snr, 1), round_to)) -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def run(cfg: DictConfig) -> None: """Main function to generate metadata for the scenes in the CAD-1 Task-2 challenge. diff --git a/recipes/cad1/task2/data_preparation/config.yaml b/recipes/cad1/task2/data_preparation/config.yaml index b32302149..541f5d09b 100644 --- a/recipes/cad1/task2/data_preparation/config.yaml +++ b/recipes/cad1/task2/data_preparation/config.yaml @@ -19,3 +19,5 @@ valid_seed: 2023 hydra: run: dir: . + job: + chdir: True \ No newline at end of file diff --git a/recipes/cad2/task1/baseline/config.yaml b/recipes/cad2/task1/baseline/config.yaml index 8803f9a8c..047688475 100644 --- a/recipes/cad2/task1/baseline/config.yaml +++ b/recipes/cad2/task1/baseline/config.yaml @@ -10,7 +10,7 @@ path: scene_listeners_file: ${path.metadata_dir}/scene_listeners.valid.json exp_folder: ./exp_${separator.causality} # folder to store enhanced signals and final results -input_sample_rate: 44100 # sample rate of the input mixture +input_sample_rate: 44100 # sample rate of the input mixture remix_sample_rate: 44100 # sample rate for the output remixed signal HAAQI_sample_rate: 24000 # sample rate for computing HAAQI score diff --git a/recipes/cad2/task1/baseline/evaluate.py b/recipes/cad2/task1/baseline/evaluate.py index a7b7b764d..22c4c7bca 100644 --- a/recipes/cad2/task1/baseline/evaluate.py +++ b/recipes/cad2/task1/baseline/evaluate.py @@ -12,7 +12,7 @@ import pyloudnorm as pyln import torch.nn import whisper -from jiwer import compute_measures +from alt_eval import compute_metrics from omegaconf import DictConfig from clarity.enhancer.multiband_compressor import MultibandCompressor @@ -98,6 +98,7 @@ def compute_intelligibility( ear = Ear( equiv_0db_spl=equiv_0db_spl, sample_rate=sample_rate, + verbose=False, ) reference = segment_metadata["text"] @@ -116,7 +117,9 @@ def compute_intelligibility( hypothesis = scorer.transcribe(left_path.as_posix(), fp16=False)["text"] lyrics["hypothesis_left"] = hypothesis - left_results = compute_measures(reference, hypothesis) + left_results = compute_metrics( + [reference], [hypothesis], languages="en", include_other=False + ) # Compute right ear ear.set_audiogram(listener.audiogram_right) @@ -131,7 +134,10 @@ def compute_intelligibility( hypothesis = scorer.transcribe(right_path.as_posix(), fp16=False)["text"] lyrics["hypothesis_right"] = hypothesis - right_results = compute_measures(reference, hypothesis) + right_results = compute_metrics( + [reference], [hypothesis], languages="en", include_other=False + ) + # Compute the average score for both ears total_words = ( @@ -403,7 +409,6 @@ def run_compute_scores(config: DictConfig) -> None: ) max_whisper = np.max([whisper_left, whisper_right]) - mean_haaqi = np.mean(haaqi_scores) results_file.add_result( { "scene": scene_id, @@ -414,12 +419,14 @@ def run_compute_scores(config: DictConfig) -> None: "hypothesis_right": lyrics_text["hypothesis_right"], "haaqi_left": haaqi_scores[0], "haaqi_right": haaqi_scores[1], - "haaqi_avg": mean_haaqi, + "haaqi_avg": np.mean(haaqi_scores), "whisper_left": whisper_left, "whisper_rigth": whisper_right, "whisper_be": max_whisper, "alpha": alpha, - "score": alpha * max_whisper + (1 - alpha) * mean_haaqi, + "score": alpha * max_whisper + (1 - alpha) * np.mean(haaqi_scores), + + } ) diff --git a/recipes/cad2/task1/requirements.txt b/recipes/cad2/task1/requirements.txt index 12ce2cb5e..14359e1fc 100644 --- a/recipes/cad2/task1/requirements.txt +++ b/recipes/cad2/task1/requirements.txt @@ -1,4 +1,4 @@ +alt-eval huggingface-hub -jiwer openai-whisper safetensors diff --git a/recipes/cad2/task2/baseline/config.yaml b/recipes/cad2/task2/baseline/config.yaml index 2e34f1cd4..4311c5cc1 100644 --- a/recipes/cad2/task2/baseline/config.yaml +++ b/recipes/cad2/task2/baseline/config.yaml @@ -1,5 +1,5 @@ path: - root: ?? # Set to the root of the dataset + root: ??? # Set to the root of the dataset metadata_dir: ${path.root}/metadata music_dir: ${path.root}/audio gains_file: ${path.metadata_dir}/gains.json diff --git a/recipes/cad_icassp_2024/baseline/config.yaml b/recipes/cad_icassp_2024/baseline/config.yaml index 6038f16a0..1270ddea9 100644 --- a/recipes/cad_icassp_2024/baseline/config.yaml +++ b/recipes/cad_icassp_2024/baseline/config.yaml @@ -41,4 +41,6 @@ evaluate: # hydra config hydra: run: - dir: ${path.exp_folder} \ No newline at end of file + dir: ${path.exp_folder} + job: + chdir: True \ No newline at end of file diff --git a/recipes/cad_icassp_2024/baseline/enhance.py b/recipes/cad_icassp_2024/baseline/enhance.py index d6fec3144..4ae6aa69d 100644 --- a/recipes/cad_icassp_2024/baseline/enhance.py +++ b/recipes/cad_icassp_2024/baseline/enhance.py @@ -180,7 +180,7 @@ def process_remix_for_listener( return np.stack([left_output, right_output], axis=1) -@hydra.main(config_path="", config_name="config") +@hydra.main(config_path="", config_name="config", version_base=None) def enhance(config: DictConfig) -> None: """ Run the music enhancement. diff --git a/recipes/cad_icassp_2024/baseline/evaluate.py b/recipes/cad_icassp_2024/baseline/evaluate.py index a40e047e7..408da477e 100644 --- a/recipes/cad_icassp_2024/baseline/evaluate.py +++ b/recipes/cad_icassp_2024/baseline/evaluate.py @@ -188,7 +188,7 @@ def load_reference_stems(music_dir: str | Path) -> tuple[dict[str, ndarray], nda return reference_stems, read_signal(Path(music_dir) / "mixture.wav") -@hydra.main(config_path="", config_name="config") +@hydra.main(config_path="", config_name="config", version_base=None) def run_calculate_aq(config: DictConfig) -> None: """Evaluate the enhanced signals using the HAAQI metric.""" diff --git a/recipes/cad_icassp_2024/baseline/merge_batches_results.py b/recipes/cad_icassp_2024/baseline/merge_batches_results.py index ff29762da..20cd873e1 100644 --- a/recipes/cad_icassp_2024/baseline/merge_batches_results.py +++ b/recipes/cad_icassp_2024/baseline/merge_batches_results.py @@ -6,7 +6,7 @@ from omegaconf import DictConfig -@hydra.main(config_path="", config_name="config") +@hydra.main(config_path="", config_name="config", version_base=None) def join_batches(config: DictConfig) -> None: """ Join batches scores into a single file. diff --git a/recipes/cad_icassp_2024/generate_dataset/config.yaml b/recipes/cad_icassp_2024/generate_dataset/config.yaml index 947715efa..98c20ae6e 100644 --- a/recipes/cad_icassp_2024/generate_dataset/config.yaml +++ b/recipes/cad_icassp_2024/generate_dataset/config.yaml @@ -22,3 +22,5 @@ scene_listener: hydra: run: dir: . + job: + chdir: True diff --git a/recipes/cad_icassp_2024/generate_dataset/generate_at_mic_musdb18.py b/recipes/cad_icassp_2024/generate_dataset/generate_at_mic_musdb18.py index 7d3a84a49..09ea35913 100644 --- a/recipes/cad_icassp_2024/generate_dataset/generate_at_mic_musdb18.py +++ b/recipes/cad_icassp_2024/generate_dataset/generate_at_mic_musdb18.py @@ -127,7 +127,7 @@ def find_precreated_samples(source_dir: str | Path) -> list[str]: return previous_tracks -@hydra.main(config_path="", config_name="config") +@hydra.main(config_path="", config_name="config", version_base=None) def run(cfg: DictConfig) -> None: """Main function of the script.""" diff --git a/recipes/cad_icassp_2024/generate_dataset/generate_train_scenes.py b/recipes/cad_icassp_2024/generate_dataset/generate_train_scenes.py index 7f85aa590..173f384a4 100644 --- a/recipes/cad_icassp_2024/generate_dataset/generate_train_scenes.py +++ b/recipes/cad_icassp_2024/generate_dataset/generate_train_scenes.py @@ -105,7 +105,7 @@ def generate_scene_listener(cfg: DictConfig) -> None: json.dump(scene_listeners, f, indent=4) -@hydra.main(config_path="", config_name="config") +@hydra.main(config_path="", config_name="config", version_base=None) def run(cfg: DictConfig) -> None: """Module generates the scenes and scene-listeners metadata files for training.""" generate_scenes(cfg) diff --git a/recipes/cec1/baseline/config.yaml b/recipes/cec1/baseline/config.yaml index bf7982c5e..46c09a647 100644 --- a/recipes/cec1/baseline/config.yaml +++ b/recipes/cec1/baseline/config.yaml @@ -42,3 +42,5 @@ hydra: output_subdir: Null run: dir: . + job: + chdir: True diff --git a/recipes/cec1/baseline/enhance.py b/recipes/cec1/baseline/enhance.py index 595c4833d..4a3043cbf 100644 --- a/recipes/cec1/baseline/enhance.py +++ b/recipes/cec1/baseline/enhance.py @@ -9,7 +9,7 @@ from clarity.utils.audiogram import Listener -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def enhance(cfg: DictConfig) -> None: enhanced_folder = Path(cfg.path.exp_folder) / "enhanced_signals" enhanced_folder.mkdir(parents=True, exist_ok=True) diff --git a/recipes/cec1/baseline/evaluate.py b/recipes/cec1/baseline/evaluate.py index 5e6b66e1c..95ba8a9dc 100644 --- a/recipes/cec1/baseline/evaluate.py +++ b/recipes/cec1/baseline/evaluate.py @@ -40,7 +40,7 @@ def listen(ear, signal: ndarray, listener: Listener): return np.concatenate([out_l, out_r]).T -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def run_HL_processing(cfg: DictConfig) -> None: output_path = Path(cfg.path.exp_folder) / "eval_signals" output_path.mkdir(parents=True, exist_ok=True) diff --git a/recipes/cec1/data_preparation/data_config.yaml b/recipes/cec1/data_preparation/data_config.yaml index df81a3cfa..64c0572a8 100644 --- a/recipes/cec1/data_preparation/data_config.yaml +++ b/recipes/cec1/data_preparation/data_config.yaml @@ -19,4 +19,6 @@ defaults: hydra: output_subdir: Null run: - dir: . \ No newline at end of file + dir: . + job: + chdir: True \ No newline at end of file diff --git a/recipes/cec1/data_preparation/prepare_cec1_data.py b/recipes/cec1/data_preparation/prepare_cec1_data.py index a40df4234..330b8772a 100644 --- a/recipes/cec1/data_preparation/prepare_cec1_data.py +++ b/recipes/cec1/data_preparation/prepare_cec1_data.py @@ -44,7 +44,7 @@ def prepare_data( ) -@hydra.main(config_path=".", config_name="data_config") +@hydra.main(config_path=".", config_name="data_config", version_base=None) def run(cfg: DictConfig) -> None: for dataset in cfg["datasets"]: prepare_data( diff --git a/recipes/cec1/e009_sheffield/config.yaml b/recipes/cec1/e009_sheffield/config.yaml index 830531551..a58213e92 100644 --- a/recipes/cec1/e009_sheffield/config.yaml +++ b/recipes/cec1/e009_sheffield/config.yaml @@ -101,3 +101,5 @@ hydra: # output_subdir: ${path.exp_folder}.hydra run: dir: ${path.exp_folder} + job: + chdir: True diff --git a/recipes/cec1/e009_sheffield/test.py b/recipes/cec1/e009_sheffield/test.py index 5282c8f46..ff4add891 100644 --- a/recipes/cec1/e009_sheffield/test.py +++ b/recipes/cec1/e009_sheffield/test.py @@ -14,7 +14,7 @@ from clarity.enhancer.dsp.filter import AudiometricFIR -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def run(cfg: DictConfig) -> None: exp_folder = Path(cfg.path.exp_folder) output_folder = exp_folder / f"enhanced_{cfg.listener.id}" diff --git a/recipes/cec1/e009_sheffield/train.py b/recipes/cec1/e009_sheffield/train.py index 56b91b636..631c16505 100644 --- a/recipes/cec1/e009_sheffield/train.py +++ b/recipes/cec1/e009_sheffield/train.py @@ -231,7 +231,7 @@ def train_amp(cfg, ear): torch.save(amp_module.model.state_dict(), str(exp_dir / "best_model.pth")) -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def run(cfg: DictConfig) -> None: logger.info("Begin training left ear enhancement module.") train_den(cfg, ear="left") diff --git a/recipes/cec2/baseline/config.yaml b/recipes/cec2/baseline/config.yaml index 0356ee264..5acd5eda9 100644 --- a/recipes/cec2/baseline/config.yaml +++ b/recipes/cec2/baseline/config.yaml @@ -27,3 +27,5 @@ evaluate: hydra: run: dir: ${path.exp_folder} + job: + chdir: True diff --git a/recipes/cec2/baseline/data_generation/additional_data_config.yaml b/recipes/cec2/baseline/data_generation/additional_data_config.yaml index cf3264bdf..1124ed0f4 100644 --- a/recipes/cec2/baseline/data_generation/additional_data_config.yaml +++ b/recipes/cec2/baseline/data_generation/additional_data_config.yaml @@ -74,6 +74,8 @@ render_params: hydra: run: dir: . + job: + chdir: True defaults: - override hydra/launcher: cec2_submitit_local diff --git a/recipes/cec2/baseline/data_generation/build_additional_scenes.py b/recipes/cec2/baseline/data_generation/build_additional_scenes.py index e37889bee..82bc14394 100644 --- a/recipes/cec2/baseline/data_generation/build_additional_scenes.py +++ b/recipes/cec2/baseline/data_generation/build_additional_scenes.py @@ -32,7 +32,7 @@ def instantiate_scenes(cfg): logger.info(f"scenes.{dataset}.json has existed, skip") -@hydra.main(config_path=".", config_name="additional_data_config") +@hydra.main(config_path=".", config_name="additional_data_config", version_base=None) def run(cfg: DictConfig) -> None: logger.info("Instantiating scenes for additional training data") instantiate_scenes(cfg) diff --git a/recipes/cec2/baseline/data_generation/render_additional_scenes.py b/recipes/cec2/baseline/data_generation/render_additional_scenes.py index 4788736b6..18c50090c 100644 --- a/recipes/cec2/baseline/data_generation/render_additional_scenes.py +++ b/recipes/cec2/baseline/data_generation/render_additional_scenes.py @@ -32,7 +32,7 @@ def render_scenes(cfg): scene_renderer.render_scenes(scenes) -@hydra.main(config_path=".", config_name="additional_data_config") +@hydra.main(config_path=".", config_name="additional_data_config", version_base=None) def run(cfg: DictConfig) -> None: logger.info("Rendering scenes") render_scenes(cfg) diff --git a/recipes/cec2/baseline/enhance.py b/recipes/cec2/baseline/enhance.py index 99267916c..9b9c9d374 100644 --- a/recipes/cec2/baseline/enhance.py +++ b/recipes/cec2/baseline/enhance.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def enhance(cfg: DictConfig) -> None: enhanced_folder = Path(cfg.path.exp_folder) / "enhanced_signals" enhanced_folder.mkdir(parents=True, exist_ok=True) diff --git a/recipes/cec2/baseline/evaluate.py b/recipes/cec2/baseline/evaluate.py index db56c2239..82e92b071 100644 --- a/recipes/cec2/baseline/evaluate.py +++ b/recipes/cec2/baseline/evaluate.py @@ -3,7 +3,6 @@ import json import logging from pathlib import Path -from typing import Dict import hydra import numpy as np @@ -17,7 +16,7 @@ logger = logging.getLogger(__name__) -def read_csv_scores(file: Path) -> Dict[str, float]: +def read_csv_scores(file: Path) -> dict[str, float]: score_dict = {} with file.open("r", encoding="utf-8") as fp: reader = csv.reader(fp) @@ -27,7 +26,7 @@ def read_csv_scores(file: Path) -> Dict[str, float]: return score_dict -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def run_calculate_SI(cfg: DictConfig) -> None: with Path(cfg.path.scenes_listeners_file).open("r", encoding="utf-8") as fp: scenes_listeners = json.load(fp) diff --git a/recipes/cec2/data_preparation/build_scenes.py b/recipes/cec2/data_preparation/build_scenes.py index 09581959d..232aae697 100644 --- a/recipes/cec2/data_preparation/build_scenes.py +++ b/recipes/cec2/data_preparation/build_scenes.py @@ -44,7 +44,7 @@ def instantiate_scenes(cfg): logger.info(f"scenes.{dataset}.json exists, skip") -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def run(cfg: DictConfig) -> None: logger.info("Building rooms") build_rooms_from_rpf(cfg) diff --git a/recipes/cec2/data_preparation/config.yaml b/recipes/cec2/data_preparation/config.yaml index a09bd8536..17794cb8e 100644 --- a/recipes/cec2/data_preparation/config.yaml +++ b/recipes/cec2/data_preparation/config.yaml @@ -124,6 +124,8 @@ render_params: hydra: run: dir: . + job: + chdir: True defaults: - override hydra/launcher: cec2_submitit_local diff --git a/recipes/cec2/data_preparation/render_scenes.py b/recipes/cec2/data_preparation/render_scenes.py index faee13442..1b2c63815 100644 --- a/recipes/cec2/data_preparation/render_scenes.py +++ b/recipes/cec2/data_preparation/render_scenes.py @@ -34,7 +34,7 @@ def render_scenes(cfg): scene_renderer.render_scenes(scenes) -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def run(cfg: DictConfig) -> None: logger.info("Rendering scenes") render_scenes(cfg) diff --git a/recipes/cpc1/baseline/compute_scores.py b/recipes/cpc1/baseline/compute_scores.py index d2b6f6835..8c6f0efaa 100644 --- a/recipes/cpc1/baseline/compute_scores.py +++ b/recipes/cpc1/baseline/compute_scores.py @@ -81,7 +81,7 @@ def read_data(pred_csv: Path, label_json: Path): return data -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def run(cfg: DictConfig) -> None: logger.info("Run evaluation on the closed set.") data_tr = read_data( diff --git a/recipes/cpc1/baseline/config.yaml b/recipes/cpc1/baseline/config.yaml index 0883ddf8a..1c9766591 100644 --- a/recipes/cpc1/baseline/config.yaml +++ b/recipes/cpc1/baseline/config.yaml @@ -54,3 +54,5 @@ hydra: output_subdir: Null run: dir: . + job: + chdir: True diff --git a/recipes/cpc1/baseline/run.py b/recipes/cpc1/baseline/run.py index fb084ef0a..9188b304e 100644 --- a/recipes/cpc1/baseline/run.py +++ b/recipes/cpc1/baseline/run.py @@ -153,7 +153,7 @@ def run_calculate_SI(cfg, path) -> None: csv_writer.writerow(line) -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def run(cfg: DictConfig) -> None: logger.info("Prediction with MSGB + MBSTOI for train set") run_HL_processing(cfg, cfg.train_path) diff --git a/recipes/cpc1/e029_sheffield/config.yaml b/recipes/cpc1/e029_sheffield/config.yaml index c003ae151..2c5423fb1 100644 --- a/recipes/cpc1/e029_sheffield/config.yaml +++ b/recipes/cpc1/e029_sheffield/config.yaml @@ -19,3 +19,5 @@ hydra: output_subdir: Null run: dir: . + job: + chdir: True diff --git a/recipes/cpc1/e029_sheffield/evaluate.py b/recipes/cpc1/e029_sheffield/evaluate.py index 2e9152abd..a08527d24 100644 --- a/recipes/cpc1/e029_sheffield/evaluate.py +++ b/recipes/cpc1/e029_sheffield/evaluate.py @@ -84,7 +84,7 @@ def read_data(pred_json: Path, label_json: Path): return np.array(prediction), np.array(label) -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def run(cfg: DictConfig) -> None: if cfg.cpc1_track == "open": track = "_indep" diff --git a/recipes/cpc1/e029_sheffield/infer.py b/recipes/cpc1/e029_sheffield/infer.py index ab97fb4d3..45d6e37b8 100644 --- a/recipes/cpc1/e029_sheffield/infer.py +++ b/recipes/cpc1/e029_sheffield/infer.py @@ -140,7 +140,7 @@ def compute_uncertainty(left_proc_path, asr_model, bos_index, _tokenizer): return conf, neg_ent -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def run(cfg: DictConfig) -> None: if cfg.cpc1_track == "open": track = "_indep" diff --git a/recipes/cpc1/e029_sheffield/prepare_data.py b/recipes/cpc1/e029_sheffield/prepare_data.py index 8acfdf55f..5dc3693da 100644 --- a/recipes/cpc1/e029_sheffield/prepare_data.py +++ b/recipes/cpc1/e029_sheffield/prepare_data.py @@ -162,8 +162,9 @@ def generate_data_split( if if_msbg: wav_file = orig_signal_folder / f"{snt_id}_HL-output.wav" elif if_ref: - wav_file = orig_signal_folder / ( - f"../../scenes/{snt_id.split('_')[0]}_target_anechoic.wav" + wav_file = ( + orig_signal_folder + / f"../../scenes/{snt_id.split('_')[0]}_target_anechoic.wav" ) else: wav_file = orig_signal_folder / f"{snt_id}.wav" @@ -293,7 +294,7 @@ def run_signal_generation_test(cfg, track): ) -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def run(cfg: DictConfig) -> None: if cfg.cpc1_track == "open": track = "_indep" diff --git a/recipes/cpc1/e032_sheffield/config.yaml b/recipes/cpc1/e032_sheffield/config.yaml index 72deff4cd..d1f27e492 100644 --- a/recipes/cpc1/e032_sheffield/config.yaml +++ b/recipes/cpc1/e032_sheffield/config.yaml @@ -19,3 +19,5 @@ hydra: output_subdir: Null run: dir: . + job: + chdir: True diff --git a/recipes/cpc1/e032_sheffield/evaluate.py b/recipes/cpc1/e032_sheffield/evaluate.py index 7edbe944f..1542773bd 100644 --- a/recipes/cpc1/e032_sheffield/evaluate.py +++ b/recipes/cpc1/e032_sheffield/evaluate.py @@ -85,7 +85,7 @@ def read_data(pred_json: Path, label_json: Path): return np.array(prediction), np.array(label) -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def run(cfg: DictConfig) -> None: if cfg.cpc1_track == "open": track = "_indep" diff --git a/recipes/cpc1/e032_sheffield/infer.py b/recipes/cpc1/e032_sheffield/infer.py index 9a8270c63..55511b824 100644 --- a/recipes/cpc1/e032_sheffield/infer.py +++ b/recipes/cpc1/e032_sheffield/infer.py @@ -219,7 +219,7 @@ def compute_similarity(left_proc_path, wrd, asr_model, bos_index, tokenizer): return enc_similarity[0].numpy(), dec_similarity[0].numpy() -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def run(cfg: DictConfig) -> None: if cfg.cpc1_track == "open": track = "_indep" diff --git a/recipes/cpc1/e032_sheffield/prepare_data.py b/recipes/cpc1/e032_sheffield/prepare_data.py index 11fb63b48..1ec7a83b5 100644 --- a/recipes/cpc1/e032_sheffield/prepare_data.py +++ b/recipes/cpc1/e032_sheffield/prepare_data.py @@ -298,7 +298,7 @@ def run_signal_generation_test(cfg, track): ) -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def run(cfg: DictConfig) -> None: if cfg.cpc1_track == "open": track = "_indep" diff --git a/recipes/cpc2/baseline/compute_haspi.py b/recipes/cpc2/baseline/compute_haspi.py index 7f4a312a9..8d530ae20 100644 --- a/recipes/cpc2/baseline/compute_haspi.py +++ b/recipes/cpc2/baseline/compute_haspi.py @@ -76,7 +76,7 @@ def compute_haspi_for_signal(signal_name: str, path: dict) -> float: # pylint: disable = no-value-for-parameter -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def run_calculate_haspi(cfg: DictConfig) -> None: """Run the HASPI score computation.""" # Load the set of signal for which we need to compute scores diff --git a/recipes/cpc2/baseline/config.yaml b/recipes/cpc2/baseline/config.yaml index c444436b6..ec8df013e 100644 --- a/recipes/cpc2/baseline/config.yaml +++ b/recipes/cpc2/baseline/config.yaml @@ -14,3 +14,5 @@ compute_haspi: hydra: run: dir: exp + job: + chdir: True diff --git a/recipes/cpc2/baseline/predict.py b/recipes/cpc2/baseline/predict.py index f94a66707..bfcdfa886 100644 --- a/recipes/cpc2/baseline/predict.py +++ b/recipes/cpc2/baseline/predict.py @@ -71,7 +71,7 @@ def make_disjoint_train_set( # pylint: disable = no-value-for-parameter -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def predict(cfg: DictConfig): """Predict intelligibility from HASPI scores.""" diff --git a/recipes/icassp_2023/baseline/config.yaml b/recipes/icassp_2023/baseline/config.yaml index e1a87ef70..9bc728ec3 100644 --- a/recipes/icassp_2023/baseline/config.yaml +++ b/recipes/icassp_2023/baseline/config.yaml @@ -27,3 +27,5 @@ evaluate: hydra: run: dir: ${path.exp_folder} + job: + chdir: True diff --git a/recipes/icassp_2023/baseline/enhance.py b/recipes/icassp_2023/baseline/enhance.py index 40a353f29..3be7dee1c 100644 --- a/recipes/icassp_2023/baseline/enhance.py +++ b/recipes/icassp_2023/baseline/enhance.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def enhance(cfg: DictConfig) -> None: """Run the dummy enhancement.""" diff --git a/recipes/icassp_2023/baseline/evaluate.py b/recipes/icassp_2023/baseline/evaluate.py index 5c42c3715..eceecc6b0 100644 --- a/recipes/icassp_2023/baseline/evaluate.py +++ b/recipes/icassp_2023/baseline/evaluate.py @@ -94,7 +94,7 @@ def make_scene_listener_list(scenes_listeners, small_test=False): return scene_listener_pairs -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def run_calculate_si(cfg: DictConfig) -> None: """Evaluate the enhanced signals using a combination of HASPI and HASQI metrics""" diff --git a/recipes/icassp_2023/baseline/report_score.py b/recipes/icassp_2023/baseline/report_score.py index 698dba877..4ffd258b2 100644 --- a/recipes/icassp_2023/baseline/report_score.py +++ b/recipes/icassp_2023/baseline/report_score.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="config", version_base=None) def report_score(cfg: DictConfig) -> None: """Run the dummy enhancement.""" diff --git a/tests/conftest.py b/tests/conftest.py index 61bd95dc6..e070ff6ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,6 +37,18 @@ def _random_matrix(seed: int | None = None, size=(100, 100)) -> np.ndarray: return _random_matrix +@pytest.fixture +def abs_tolerance(): + """Fixture for absolute tolerance value.""" + return 1e-7 + + +@pytest.fixture +def rel_tolerance(): + """Fixture for relative tolerance value.""" + return 1e-7 + + def pytest_configure() -> None: """Configure custom variables for pytest. diff --git a/tests/enhancer/multiband_compressor/test_multiband_compresor.py b/tests/enhancer/multiband_compressor/test_multiband_compresor.py index 294ff903b..be323ac82 100644 --- a/tests/enhancer/multiband_compressor/test_multiband_compresor.py +++ b/tests/enhancer/multiband_compressor/test_multiband_compresor.py @@ -22,7 +22,7 @@ def test_compressor_initialization(default_compressor): assert default_compressor.sample_rate == 44100.0 -def test_compressor_signal_processing(default_compressor): +def test_compressor_signal_processing(default_compressor, rel_tolerance, abs_tolerance): """Test the signal processing of the Compressor class.""" input_signal = np.array([[1, 2, 3, 4, 5]]) output_signal = default_compressor(input_signal) @@ -30,7 +30,7 @@ def test_compressor_signal_processing(default_compressor): assert isinstance(output_signal, np.ndarray) assert len(output_signal) == len(input_signal) assert np.sum(output_signal) == pytest.approx( - 15, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance + 15, rel=rel_tolerance, abs=abs_tolerance ) @@ -49,7 +49,9 @@ def test_multiband_compressor_onefreq(): assert multiband_compressor.xover_freqs.shape[0] == 1 -def test_multiband_compressor_initialization(default_multiband_compressor): +def test_multiband_compressor_initialization( + default_multiband_compressor, rel_tolerance, abs_tolerance +): """Test the initialization of the MultibandCompressor class.""" assert np.sum(default_multiband_compressor.xover_freqs) == pytest.approx( np.sum( @@ -64,14 +66,14 @@ def test_multiband_compressor_initialization(default_multiband_compressor): ) * np.sqrt(2) ), - rel=pytest.rel_tolerance, - abs=pytest.abs_tolerance, + rel=rel_tolerance, + abs=abs_tolerance, ) assert np.sum(default_multiband_compressor.xover_freqs) == pytest.approx( np.sum(np.array(default_multiband_compressor.xover_freqs)), - rel=pytest.rel_tolerance, - abs=pytest.abs_tolerance, + rel=rel_tolerance, + abs=abs_tolerance, ) assert default_multiband_compressor.sample_rate == 44100.0 assert default_multiband_compressor.num_compressors == 6 @@ -91,7 +93,7 @@ def test_multiband_compressor_set_compressors(default_multiband_compressor): assert len(default_multiband_compressor.compressor) == 6 -def test_multiband_compressor_call(default_multiband_compressor): +def test_multiband_compressor_call(default_multiband_compressor, rel_tolerance): """Test the __call__ method of the MultibandCompressor class.""" np.random.seed(0) signal = np.random.rand(1000) @@ -106,7 +108,7 @@ def test_multiband_compressor_call(default_multiband_compressor): assert bands.shape[0] == 6 assert np.sum(compressed_signal) == pytest.approx( - 441.09490986, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance + 441.09490986, rel=rel_tolerance, abs=0.0005 ) diff --git a/tests/recipes/cpc2/baseline/test_evaluate_cpc2.py b/tests/recipes/cpc2/baseline/test_evaluate_cpc2.py index 4134e7554..f03acf899 100644 --- a/tests/recipes/cpc2/baseline/test_evaluate_cpc2.py +++ b/tests/recipes/cpc2/baseline/test_evaluate_cpc2.py @@ -23,10 +23,10 @@ @pytest.mark.parametrize( "x, y, expected", [([1, 2, 3], [1, 2, 3], 0), ([0], [1], 1), ([1, 1, 1], [1], 0)] ) -def test_rmse_score_ok(x, y, expected): +def test_rmse_score_ok(x, y, expected, rel_tolerance, abs_tolerance): """Test the function rmse_score valid inputs""" assert rmse_score(np.array(x), np.array(y)) == pytest.approx( - expected, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance + expected, rel=rel_tolerance, abs=abs_tolerance ) @@ -47,10 +47,10 @@ def test_rmse_score_error(x, y, expected): @pytest.mark.parametrize( "x, y, expected", [([1, 2, 3], [1, 2, 3], 1), ([1, -1], [-1, 1], -1)] ) -def test_ncc_score_ok(x, y, expected): +def test_ncc_score_ok(x, y, expected, rel_tolerance, abs_tolerance): """Test the function ncc_score valid inputs""" assert ncc_score(np.array(x), np.array(y)) == pytest.approx( - expected, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance + expected, rel=rel_tolerance, abs=abs_tolerance ) @@ -67,10 +67,10 @@ def test_ncc_score_error(x, y, expected): @pytest.mark.parametrize( "x, y, expected", [([1, 2, 3], [1, 2, 3], 1), ([1, -1], [-1, 1], -1)] ) -def test_kt_score_ok(x, y, expected): +def test_kt_score_ok(x, y, expected, rel_tolerance, abs_tolerance): """Test the function kt_score valid inputs""" assert kt_score(np.array(x), np.array(y)) == pytest.approx( - expected, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance + expected, rel=rel_tolerance, abs=abs_tolerance ) @@ -95,10 +95,10 @@ def test_kt_score_error(x, y, expected): ([1, 2, 3], [11, 12, 13], 0), ], ) -def test_std_err_ok(x, y, expected): +def test_std_err_ok(x, y, expected, rel_tolerance, abs_tolerance): """Test the function std_err valid inputs""" assert std_err(np.array(x), np.array(y)) == pytest.approx( - expected, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance + expected, rel=rel_tolerance, abs=abs_tolerance ) @@ -161,10 +161,7 @@ def test_evaluate(hydra_cfg, capsys): prediction_file = "CEC1.train.sample.predict.csv" score_file = "CEC1.train.sample.evaluate.jsonl" - expected_output = ( - "{'RMSE': 30.256228825071368, 'Std': 4.209845712831399, " - "'NCC': nan, 'KT': nan}\n" - ) + test_data = [ {"signal": "S08547_L0001_E001", "predicted": 0.8}, {"signal": "S08564_L0001_E001", "predicted": 0.8}, @@ -182,9 +179,6 @@ def test_evaluate(hydra_cfg, capsys): warnings.simplefilter("ignore", category=RuntimeWarning) evaluate(hydra_cfg) - captured = capsys.readouterr() - assert captured.out == expected_output - # Check scores scores = read_jsonl(score_file) assert scores[0]["RMSE"] == pytest.approx(30.2562, abs=1e-4) diff --git a/tests/regression/_regtest_outputs/test_engine_losses.test_snr_loss.out b/tests/regression/_regtest_outputs/test_engine_losses.test_snr_loss.out index 43de6f094..0a7e244fb 100644 --- a/tests/regression/_regtest_outputs/test_engine_losses.test_snr_loss.out +++ b/tests/regression/_regtest_outputs/test_engine_losses.test_snr_loss.out @@ -1 +1 @@ -SNR loss 2.907641 +SNR loss 2.90764 diff --git a/tests/regression/_regtest_outputs/test_full_CEC2_pipeline.test_full_cec2_pipeline.out b/tests/regression/_regtest_outputs/test_full_CEC2_pipeline.test_full_cec2_pipeline.out index 41dd0752f..9bf1d36ef 100644 --- a/tests/regression/_regtest_outputs/test_full_CEC2_pipeline.test_full_cec2_pipeline.out +++ b/tests/regression/_regtest_outputs/test_full_CEC2_pipeline.test_full_cec2_pipeline.out @@ -1 +1 @@ -Enhanced audio HASPI score is 0.7491529 +Enhanced audio HASPI score is 0.74915 diff --git a/tests/regression/_regtest_outputs/test_predictors.test_torch_msbg_stoi_non_xeon_e5_2673_cpu.out b/tests/regression/_regtest_outputs/test_predictors.test_torch_msbg_stoi_non_xeon_e5_2673_cpu.out index 751fa7965..df80fc506 100644 --- a/tests/regression/_regtest_outputs/test_predictors.test_torch_msbg_stoi_non_xeon_e5_2673_cpu.out +++ b/tests/regression/_regtest_outputs/test_predictors.test_torch_msbg_stoi_non_xeon_e5_2673_cpu.out @@ -1 +1 @@ -Torch MSBG STOILoss -0.46198, ESTOILoss -0.33000 +Torch MSBG STOILoss -0.46198, ESTOILoss -0.3300 diff --git a/tests/regression/test_engine_losses.py b/tests/regression/test_engine_losses.py index 8ee867cd4..6ca271e15 100644 --- a/tests/regression/test_engine_losses.py +++ b/tests/regression/test_engine_losses.py @@ -24,7 +24,7 @@ def test_snr_loss(regtest): y = torch.randn(10, 1000) loss = snr_loss.forward(x, y) - regtest.write(f"SNR loss {loss:0.6f}\n") + regtest.write(f"SNR loss {loss:0.5f}\n") def test_stoi_loss(regtest): diff --git a/tests/regression/test_full_CEC2_pipeline.py b/tests/regression/test_full_CEC2_pipeline.py index ba30e9880..329f6f53b 100644 --- a/tests/regression/test_full_CEC2_pipeline.py +++ b/tests/regression/test_full_CEC2_pipeline.py @@ -175,6 +175,6 @@ def test_full_cec2_pipeline( listener=listener, ) - regtest.write(f"Enhanced audio HASPI score is {sii_enhanced:0.7f}\n") + regtest.write(f"Enhanced audio HASPI score is {sii_enhanced:0.5f}\n") # Enhanced audio HASPI score is 0.2994066 diff --git a/tests/regression/test_predictors.py b/tests/regression/test_predictors.py index fa06db6b4..5971e250b 100644 --- a/tests/regression/test_predictors.py +++ b/tests/regression/test_predictors.py @@ -32,7 +32,7 @@ def test_torch_msbg_stoi_non_xeon_e5_2673_cpu(regtest): estoi_loss = estoi_loss.forward(x.cpu(), y.cpu()).mean() regtest.write( - f"Torch MSBG STOILoss {stoi_loss:0.5f}, ESTOILoss {estoi_loss:0.5f}\n" + f"Torch MSBG STOILoss {stoi_loss:0.5f}, ESTOILoss {estoi_loss:0.4f}\n" )