Better Intelligibility metric for CAD2 Task1 #419

Merged
merged 44 commits into from
Oct 15, 2024
Commits
9bb51c8
[pre-commit.ci] pre-commit-autoupdate
pre-commit-ci[bot] Jul 29, 2024
2cc4c09
Update README.md
groadabike Jul 30, 2024
d5d09a1
Merge pull request #400 from claritychallenge/pre-commit-ci-update-co…
groadabike Jul 30, 2024
8575c4b
[pre-commit.ci] pre-commit-autoupdate
pre-commit-ci[bot] Aug 12, 2024
9fcdf8b
Merge pull request #403 from claritychallenge/pre-commit-ci-update-co…
groadabike Aug 15, 2024
76e732b
Add explicit hydra.job.chdir=True needed for version 1.2
groadabike Aug 21, 2024
d0fa83d
add version_base=None in @hydra.main(), needed for version=1.2
groadabike Aug 21, 2024
a4cfc57
[pre-commit.ci] Fixing issues with pre-commit
pre-commit-ci[bot] Aug 21, 2024
ddb3b22
Place imports at top of cell
groadabike Aug 21, 2024
b383605
Merge pull request #408 from claritychallenge/407-pre-commit-raises-s…
jonbarker68 Sep 9, 2024
c191b4a
Merge pull request #406 from claritychallenge/grd-hydra-set-version-b…
jonbarker68 Sep 9, 2024
783f34f
Merge pull request #402 from claritychallenge/groadabike-patch-1
jonbarker68 Sep 9, 2024
1f3116b
[pre-commit.ci] pre-commit-autoupdate
pre-commit-ci[bot] Sep 9, 2024
11c7769
sample rate in config
groadabike Sep 10, 2024
27d14f2
correct use of whisper
groadabike Sep 10, 2024
73ea121
config
groadabike Sep 10, 2024
7e9a5bd
fix ruff reported errors
groadabike Sep 10, 2024
a07895b
[pre-commit.ci] Fixing issues with pre-commit
pre-commit-ci[bot] Sep 10, 2024
c8d4a22
Merge pull request #412 from claritychallenge/411-pre-commit-errors-w…
jonbarker68 Sep 10, 2024
2be5175
Merge remote-tracking branch 'origin/main' into pre-commit-ci-update-…
jonbarker68 Sep 10, 2024
dade982
Merge pull request #404 from claritychallenge/pre-commit-ci-update-co…
jonbarker68 Sep 10, 2024
c46ad4a
Increment min support python to 3.9
jonbarker68 Sep 11, 2024
c008cf4
Reduced sensitivity of a small number of tests that were failing on t…
jonbarker68 Sep 11, 2024
6c124a1
Fixed to use new scipy window module
jonbarker68 Sep 11, 2024
8b5d63c
Updated min versions in project dependencies
jonbarker68 Sep 11, 2024
d0038b4
Fixed a brittle test that was broken by numpy upgrade
jonbarker68 Sep 11, 2024
666bb10
second attempt to fix a broken test
jonbarker68 Sep 11, 2024
a526a3a
Merge pull request #410 from claritychallenge/cad2-fix-in-main
jonbarker68 Sep 11, 2024
36cc73e
[pre-commit.ci] pre-commit-autoupdate
pre-commit-ci[bot] Sep 16, 2024
649c76c
Merge pull request #413 from claritychallenge/jpb/support-for-python-312
groadabike Sep 17, 2024
1421a8f
Update run_tests.yml
groadabike Sep 20, 2024
8b2be0d
Merge pull request #414 from claritychallenge/pre-commit-ci-update-co…
groadabike Sep 20, 2024
6b66e86
[pre-commit.ci] pre-commit-autoupdate
pre-commit-ci[bot] Oct 8, 2024
745fef9
Merge pull request #415 from claritychallenge/pre-commit-ci-update-co…
groadabike Oct 11, 2024
46af3a4
[pre-commit.ci] pre-commit-autoupdate
pre-commit-ci[bot] Oct 14, 2024
12f0a5f
Merge pull request #417 from claritychallenge/pre-commit-ci-update-co…
groadabike Oct 15, 2024
e019ef9
add alt_eval into requirements
groadabike Oct 15, 2024
f5543bc
replace jiwer for alt-eval
groadabike Oct 15, 2024
820f03c
add verbose into msgb
groadabike Oct 15, 2024
b7cadfb
correct discrepancy in branch
groadabike Oct 15, 2024
adf64af
correct discrepancy
groadabike Oct 15, 2024
13978a3
Merge branch 'main' into alt-eval-cad2-task1
groadabike Oct 15, 2024
9f4df5b
[pre-commit.ci] Fixing issues with pre-commit
pre-commit-ci[bot] Oct 15, 2024
02c0684
Merge branch 'v0.6' into alt-eval-cad2-task1
groadabike Oct 15, 2024
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Expand Up @@ -17,7 +17,7 @@ jobs:
strategy:
matrix:
os: ["ubuntu-latest"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v3
- name: Set up Python
23 changes: 11 additions & 12 deletions .pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0 # Use the ref you want to point at
rev: v5.0.0 # Use the ref you want to point at
hooks:
- id: trailing-whitespace
types: [file, text]
Expand All @@ -20,25 +20,25 @@ repos:
- id: check-toml

- repo: https://github.com/asottile/pyupgrade
rev: v3.16.0
rev: v3.18.0
hooks:
- id: pyupgrade
args: [--py38-plus]
args: [--py39-plus]

- repo: https://github.com/psf/black
rev: 24.4.2
rev: 24.10.0
hooks:
- id: black
types: [python]
additional_dependencies: ["click==8.0.4"]

- repo: https://github.com/DavidAnson/markdownlint-cli2
rev: v0.13.0
rev: v0.14.0
hooks:
- id: markdownlint-cli2

- repo: https://github.com/pycqa/flake8.git
rev: 7.1.0
rev: 7.1.1
hooks:
- id: flake8
additional_dependencies: [flake8-print, Flake8-pyproject]
Expand All @@ -52,7 +52,7 @@ repos:
args: ["--profile", "black"]

- repo: https://github.com/nbQA-dev/nbQA
rev: 1.8.5
rev: 1.8.7
hooks:
- id: nbqa-black
- id: nbqa-flake8
Expand All @@ -65,24 +65,23 @@ repos:
- id: nbstripout

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.10.1
rev: v1.11.2
hooks:
- id: mypy
args:
- --explicit-package-bases
additional_dependencies:
- 'types-PyYAML'

- "types-PyYAML"

- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: "v0.5.1"
rev: "v0.6.9"
hooks:
- id: ruff

# Serious pylint errors that will be enforced by CI
- repo: https://github.com/pycqa/pylint
rev: v3.2.5
rev: v3.3.1
hooks:
- id: pylint
args:
10 changes: 5 additions & 5 deletions README.md
Expand Up @@ -18,6 +18,7 @@
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/claritychallenge/clarity/main.svg)](https://results.pre-commit.ci/latest/github/claritychallenge/clarity/main)
[![Downloads](https://pepy.tech/badge/pyclarity)](https://pepy.tech/project/pyclarity)

[![PyPI](https://img.shields.io/static/v1?label=CAD2%20Challenge%20-%20pypi&message=v0.6.0&color=orange)](https://pypi.org/project/pyclarity/0.6.0/)
[![PyPI](https://img.shields.io/static/v1?label=CEC3%20Challenge%20-%20pypi&message=v0.5.0&color=orange)](https://pypi.org/project/pyclarity/0.5.0/)
[![PyPI](https://img.shields.io/static/v1?label=ICASSP%202024%20Cadenza%20Challenge%20-%20pypi&message=v0.4.1&color=orange)](https://pypi.org/project/pyclarity/0.4.1/)
[![PyPI](https://img.shields.io/static/v1?label=CAD1%20and%20CPC2%20Challenges%20-%20pypi&message=v0.3.4&color=orange)](https://pypi.org/project/pyclarity/0.3.4/)
Expand All @@ -37,14 +38,12 @@ In this repository, you will find code to support all Clarity and Cadenza Challe

## Current Events

- The 2nd Cadenza Challenge is now open :fire::fire:
- Visit the [cadenza website](https://cadenzachallenge.org/docs/cadenza2/intro) for more details.
- Join the [Cadenza Challenge Group](https://groups.google.com/g/cadenza-challenge) to keep up-to-date on developments.
- The 3rd Clarity Enhancement Challenge is now open. :fire::fire:
- Visit the [challenge website](https://claritychallenge.org/docs/cec3/cec3_intro) for more details.
- Join the [Clarity Challenge Group](https://groups.google.com/g/clarity-challenge) to keep up-to-date on developments.
- The ICASSP 2024 Cadenza Challenge (CAD_ICASSP_2024) will be presented at ICASSP 2024.
- Join the [Cadenza Challenge Group](https://groups.google.com/g/cadenza-challenge) to keep up-to-date on developments.
- Visit the Cadenza Challenge [website](https://cadenzachallenge.org/) for more details.
- The first Cadenza Challenge (CAD1) is closed.
- Subjective Evaluation is underway. :new:
- The 2nd Clarity Prediction Challenge (CPC2) is now closed.
- The 4th Clarity Workshop will be held as a satellite event of Interspeech 2023. For details visit the [workshop website](https://claritychallenge.org/clarity2023-workshop/).

Expand Down Expand Up @@ -90,6 +89,7 @@ pip install -e git+https://github.com/claritychallenge/clarity.git@main

Current challenge

- [The 2nd Cadenza Challenge](./recipes/cad2)
- [The 3rd Clarity Enhancement Challenge](./recipes/cec3)

Previous challenges
10 changes: 8 additions & 2 deletions clarity/evaluator/msbg/cochlea.py
Expand Up @@ -224,7 +224,11 @@ class Cochlea:
"""

def __init__(
self, audiogram: Audiogram, catch_up_level: float = 105.0, fs: float = 44100.0
self,
audiogram: Audiogram,
catch_up_level: float = 105.0,
fs: float = 44100.0,
verbose: bool = True,
) -> None:
"""Cochlea constructor.

Expand All @@ -233,6 +237,7 @@ def __init__(
catch_up_level (float, optional): loudness catch-up level in dB
Default is 105 dB
fs (float, optional): sampling frequency
verbose (bool, optional): verbose mode. Default is True

"""
self.fs = fs
Expand All @@ -254,7 +259,8 @@ def __init__(
r_lower, r_upper = HL_PARAMS[severity_level]["smear_params"]
self.smearer = Smearer(r_lower, r_upper, fs)

logging.info("Severity level - %s", severity_level)
if verbose:
logging.info("Severity level - %s", severity_level)

def simulate(self, coch_sig: ndarray, equiv_0dB_file_SPL: float) -> ndarray:
"""Pass a signal through the cochlea.
35 changes: 24 additions & 11 deletions clarity/evaluator/msbg/msbg.py
Expand Up @@ -40,6 +40,7 @@ def __init__(
sample_rate: float = 44100.0,
equiv_0db_spl: float = 100.0,
ahr: float = 20.0,
verbose: bool = True,
) -> None:
"""
Constructor for the Ear class.
Expand All @@ -48,7 +49,9 @@ def __init__(
sample_rate (float): sample frequency.
equiv_0db_spl (): ???
ahr (): ???
verbose (bool): if True, log progress messages. Default is True.
"""
self.verbose = verbose
self.sample_rate = sample_rate
self.src_correction = self.get_src_correction(src_pos)
self.equiv_0db_spl = equiv_0db_spl
Expand All @@ -62,7 +65,7 @@ def set_audiogram(self, audiogram: Audiogram) -> None:
"Impairment too severe: Suggest you limit audiogram max to "
"80-90 dB HL, otherwise things go wrong/unrealistic."
)
self.cochlea = Cochlea(audiogram=audiogram)
self.cochlea = Cochlea(audiogram=audiogram, verbose=self.verbose)

@staticmethod
def get_src_correction(src_pos: str) -> ndarray:
Expand Down Expand Up @@ -92,6 +95,7 @@ def src_to_cochlea_filt(
src_correction: ndarray,
sample_rate: float,
backward: bool = False,
verbose: bool = True,
) -> ndarray:
"""Simulate middle and outer ear transfer functions.

Expand All @@ -109,12 +113,14 @@ def src_to_cochlea_filt(
or ITU
sample_rate (int): sampling frequency
backward (bool, optional): if true then cochlea to src (default: False)
verbose (bool, optional): print verbose output (default: True)

Returns:
np.ndarray: the processed signal

"""
logging.info("performing outer/middle ear corrections")
if verbose:
logging.info("performing outer/middle ear corrections")

# make sure that response goes only up to sample_frequency/2
nyquist = int(sample_rate / 2.0)
Expand Down Expand Up @@ -204,7 +210,8 @@ def process(self, signal: ndarray, add_calibration: bool = False) -> list[ndarra
)
raise ValueError("Invalid sampling frequency, valid value is 44100")

logging.info("Processing {len(chans)} samples")
if self.verbose:
logging.info("Processing %d channels", len(signal))

# Need to know file RMS, and then call that a certain level in SPL:
# needs some form of pre-measuring.
Expand All @@ -219,7 +226,7 @@ def process(self, signal: ndarray, add_calibration: bool = False) -> list[ndarra

# Measure RMS where 3rd arg is dB_rel_rms (how far below)
calculated_rms, idx, _rel_db_thresh, _active = measure_rms(
signal[0], sample_rate, -12
signal[0], sample_rate, -12, verbose=self.verbose
)

# Rescale input data and check level after rescaling
Expand All @@ -229,11 +236,11 @@ def process(self, signal: ndarray, add_calibration: bool = False) -> list[ndarra
new_rms_db = equiv_0db_spl + 10 * np.log10(
np.mean(np.power(signal[0][idx], 2.0))
)
logging.info(
"Rescaling: "
f"leveldBSPL was {level_db_spl:3.1f} dB SPL, now {new_rms_db:3.1f} dB SPL. "
f" Target SPL is {target_spl:3.1f} dB SPL."
)
if self.verbose:
logging.info(
f"Rescaling: leveldBSPL was {level_db_spl:3.1f} dB SPL, now"
f" {new_rms_db:3.1f} dB SPL. Target SPL is {target_spl:3.1f} dB SPL."
)

# Add calibration signal at target SPL dB
if add_calibration is True:
Expand All @@ -247,11 +254,17 @@ def process(self, signal: ndarray, add_calibration: bool = False) -> list[ndarra
signal = np.concatenate((pre_calibration, signal, post_calibration), axis=1)

# Transform from src pos to cochlea, simulate cochlea, transform back to src pos
signal = Ear.src_to_cochlea_filt(signal, self.src_correction, sample_rate)
signal = Ear.src_to_cochlea_filt(
signal, self.src_correction, sample_rate, verbose=self.verbose
)
if self.cochlea is not None:
signal = np.array([self.cochlea.simulate(x, equiv_0db_spl) for x in signal])
signal = Ear.src_to_cochlea_filt(
signal, self.src_correction, sample_rate, backward=True
signal,
self.src_correction,
sample_rate,
backward=True,
verbose=self.verbose,
)

# Implement low-pass filter at top end of audio range: flat to Cutoff freq,
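The msbg.py and cochlea.py hunks above share one pattern: the object stores a `verbose` flag and forwards it to every helper that logs. The sketch below illustrates that pattern in isolation. It is not code from the PR: `Processor`, `measure_level`, `ListHandler` and the logger name are hypothetical names chosen for the example, and the "level" computed is just a mean square.

```python
import logging

# Hypothetical logger name, used only in this sketch.
logger = logging.getLogger("verbose_sketch")


def measure_level(samples: list[float], verbose: bool = False) -> float:
    """Hypothetical helper: mean-square level, logged only when verbose."""
    level = sum(x * x for x in samples) / len(samples)
    if verbose:
        logger.info("measured level %.3f", level)
    return level


class Processor:
    """Hypothetical stand-in for a class that forwards its verbose flag."""

    def __init__(self, verbose: bool = True) -> None:
        self.verbose = verbose

    def process(self, samples: list[float]) -> float:
        if self.verbose:
            logger.info("processing %d samples", len(samples))
        # Forward the flag so the helper stays quiet when the caller asked for it.
        return measure_level(samples, verbose=self.verbose)


class ListHandler(logging.Handler):
    """Collects log messages so the effect of the flag can be inspected."""

    def __init__(self) -> None:
        super().__init__()
        self.messages: list[str] = []

    def emit(self, record: logging.LogRecord) -> None:
        self.messages.append(record.getMessage())


handler = ListHandler()
logger.addHandler(handler)
logger.setLevel(logging.INFO)

quiet_result = Processor(verbose=False).process([1.0, -1.0])
quiet_count = len(handler.messages)  # messages emitted by the quiet run
loud_result = Processor(verbose=True).process([1.0, -1.0])
loud_count = len(handler.messages) - quiet_count  # messages from the verbose run
```

An alternative that avoids threading a flag through every call is to leave the `logging.info` calls unconditional and raise the level of the relevant logger; the PR chose the explicit flag, so the sketch follows that choice.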
20 changes: 13 additions & 7 deletions clarity/evaluator/msbg/msbg_utils.py
Expand Up @@ -10,8 +10,8 @@

import numpy as np
import scipy
import scipy.signal
from numpy import ndarray
from scipy.signal.windows import hamming, kaiser

# measure rms parameters
WIN_SECS: Final = 0.01
Expand Down Expand Up @@ -167,7 +167,7 @@ def firwin2(
order = n_taps - 1

if window_type == "kaiser":
window_shape = scipy.signal.kaiser(n_taps, window_param)
window_shape = kaiser(n_taps, window_param)

if window_shape is None:
filter_coef, _ = fir2(order, frequencies, filter_gains)
Expand Down Expand Up @@ -203,7 +203,7 @@ def fir2(
filter_length += 1

if window_shape is None:
window_shape = scipy.signal.hamming(filter_length)
window_shape = hamming(filter_length)

n_interpolate = (
2 ** np.ceil(math.log(filter_length) / math.log(2.0))
Expand Down Expand Up @@ -358,6 +358,7 @@ def generate_key_percent(
threshold_db: float,
window_length: int,
percent_to_track: float | None = None,
verbose: bool = False,
) -> tuple[ndarray, float]:
"""Generate key percent.
Locates frames above some energy threshold or tracks a certain percentage
Expand All @@ -370,6 +371,7 @@ def generate_key_percent(
window_length (int): length of window in samples.
percent_to_track (float, optional): Track a percentage of frames.
Default is None
verbose (bool, optional): Print verbose output. Default is False.

Raises:
ValueError: percent_to_track is set too high.
Expand All @@ -393,10 +395,11 @@ def generate_key_percent(

expected = threshold_db
# new Dec 2003. Possibly track percentage of frames rather than fixed threshold
if percent_to_track is not None:
logging.info("tracking %s percentage of frames", percent_to_track)
else:
logging.info("tracking fixed threshold")
if verbose:
if percent_to_track is not None:
logging.info("tracking %s percentage of frames", percent_to_track)
else:
logging.info("tracking fixed threshold")

# put floor into histogram distribution
non_zero = np.power(10, (expected - 30) / 10)
Expand Down Expand Up @@ -466,6 +469,7 @@ def measure_rms(
sample_rate: float,
db_rel_rms: float,
percent_to_track: float | None = None,
verbose: bool = False,
) -> tuple[float, ndarray, float, float]:
"""Measure Root Mean Square.

Expand All @@ -481,6 +485,7 @@ def measure_rms(
db_rel_rms (float): threshold for frames to track.
percent_to_track (float, optional): track percentage of frames,
rather than threshold (default: {None})
verbose (bool, optional): Print verbose output. Default is False.
Returns:
(tuple): tuple containing
- rms (float): overall calculated rms (linear)
Expand All @@ -500,6 +505,7 @@ def measure_rms(
key_thr_db,
round(WIN_SECS * sample_rate),
percent_to_track=percent_to_track,
verbose=verbose,
)

idx = key.astype(int) # move into generate_key_percent
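The import change in this file moves the window functions from the `scipy.signal` namespace to `scipy.signal.windows`, which matches the commit "Fixed to use new scipy window module". As a minimal check of the new import path, assuming a SciPy version that provides `scipy.signal.windows`, the following builds the two windows used in `firwin2` and `fir2`. The tap count and the Kaiser `beta` of 0.0 are illustrative values, not values taken from the PR.

```python
import numpy as np
from scipy.signal.windows import hamming, kaiser

n_taps = 5  # illustrative length, not taken from the PR

# Symmetric Hamming window: 0.54 - 0.46 * cos(2 * pi * n / (n_taps - 1))
hamming_window = hamming(n_taps)

# A Kaiser window with beta = 0 reduces to a rectangular window (all ones).
kaiser_window = kaiser(n_taps, 0.0)

# Both are plain NumPy arrays of length n_taps.
window_lengths = (len(hamming_window), len(kaiser_window))
```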
3 changes: 1 addition & 2 deletions docs/conf.py
Expand Up @@ -7,7 +7,6 @@
import os
import sys
from importlib.metadata import version
from typing import Dict

# -*- coding: utf-8 -*-
#
Expand Down Expand Up @@ -157,7 +156,7 @@

# -- Options for LaTeX output ------------------------------------------------

latex_elements: Dict[str, str] = {
latex_elements: dict[str, str] = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
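The `Dict[str, str]` to `dict[str, str]` change goes together with the move to Python 3.9 as the minimum version (and `--py39-plus` for pyupgrade): from 3.9 onward the built-in containers can be subscripted directly in annotations, so the `typing.Dict` import is no longer needed. A small sketch, where `latex_options` and `describe` are hypothetical names used only for illustration:

```python
# Built-in generics (PEP 585) work without importing from typing on Python 3.9+.
latex_options: dict[str, str] = {"papersize": "a4paper"}


def describe(options: dict[str, str]) -> list[str]:
    """Return 'key=value' strings, sorted by key."""
    return [f"{key}={value}" for key, value in sorted(options.items())]


summary = describe(latex_options)
```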