
Commit

Merge pull request #421 from claritychallenge/alt-eval-cad2-task1
set temperature for whisper
groadabike authored Oct 18, 2024
2 parents 85ed818 + ccf73a7 commit ee84d38
Showing 1 changed file with 48 additions and 16 deletions.
recipes/cad2/task1/baseline/evaluate.py
@@ -32,6 +32,10 @@ def set_song_seed(song: str) -> None:
song_md5 = int(song_encoded, 16) % (10**8)
np.random.seed(song_md5)

torch.manual_seed(song_md5)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(song_md5)


def make_scene_listener_list(scenes_listeners: dict, small_test: bool = False) -> list:
"""Make the list of scene-listener pairing to process
@@ -109,12 +113,14 @@ def compute_intelligibility(
enhanced_left = ear.process(enhanced_signal[:, 0])[0]
left_path = Path(f"{path_intermediate.as_posix()}_left.flac")
save_flac_signal(
enhanced_signal,
enhanced_left,
left_path,
44100,
sample_rate,
)
hypothesis = scorer.transcribe(left_path.as_posix(), fp16=False)["text"]
hypothesis = scorer.transcribe(left_path.as_posix(), fp16=False, temperature=0.0)[
"text"
]
lyrics["hypothesis_left"] = hypothesis

left_results = compute_metrics(
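Two quiet bugs are fixed in this hunk: save_flac_signal previously received the full stereo enhanced_signal rather than the ear-processed left channel, and the write sample rate was hardcoded to 44100 instead of the actual sample_rate. A hedged sketch of the corrected per-channel pattern (names are taken from the recipe; the save_flac_signal argument order is inferred from the call above):

    from pathlib import Path

    # ear, enhanced_signal, path_intermediate, sample_rate and
    # save_flac_signal come from the surrounding recipe; this loop
    # just restates the fixed pattern for both channels.
    for channel, side in enumerate(("left", "right")):
        processed = ear.process(enhanced_signal[:, channel])[0]  # mono ear-model output
        out_path = Path(f"{path_intermediate.as_posix()}_{side}.flac")
        save_flac_signal(processed, out_path, sample_rate)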
@@ -126,18 +132,19 @@
enhanced_right = ear.process(enhanced_signal[:, 1])[0]
right_path = Path(f"{path_intermediate.as_posix()}_right.flac")
save_flac_signal(
enhanced_signal,
enhanced_right,
right_path,
44100,
sample_rate,
)
hypothesis = scorer.transcribe(right_path.as_posix(), fp16=False)["text"]
hypothesis = scorer.transcribe(right_path.as_posix(), fp16=False, temperature=0.0)[
"text"
]
lyrics["hypothesis_right"] = hypothesis

right_results = compute_metrics(
[reference], [hypothesis], languages="en", include_other=False
)


# Compute the average score for both ears
total_words = (
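The change this commit is named for is temperature=0.0 on both transcribe calls. Whisper then decodes greedily instead of sampling, and passing a single float also replaces the default fallback schedule of increasing temperatures, so repeated runs over the same flac produce the same transcript. A minimal sketch of the call, assuming the openai-whisper package and a hypothetical input file:

    import whisper

    model = whisper.load_model("base")  # any released Whisper checkpoint
    result = model.transcribe(
        "enhanced_left.flac",  # hypothetical path to a processed channel
        fp16=False,            # compute in fp32 (avoids fp16-on-CPU warnings)
        temperature=0.0,       # greedy decoding: deterministic output
    )
    hypothesis = result["text"]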
@@ -169,9 +176,23 @@ def compute_quality(
reference_signal: np.ndarray,
enhanced_signal: np.ndarray,
listener: Listener,
config: DictConfig,
reference_sample_rate: int,
enhanced_sample_rate: int,
HAAQI_sample_rate: int,
) -> tuple[float, float]:
"""Compute the HAAQI score for the left and right channels"""
"""Compute the HAAQI score for the left and right channels
Args:
reference_signal: The reference signal
enhanced_signal: The enhanced signal
listener: The listener
reference_sample_rate: The sample rate of the reference signal
enhanced_sample_rate: The sample rate of the enhanced signal
HAAQI_sample_rate: The sample rate for the HAAQI computation
Returns:
The HAAQI score for the left and right channels
"""
scores = []

for channel in range(2):
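Replacing config: DictConfig with explicit sample-rate arguments removes the Hydra dependency from compute_quality, so it can be called (and unit-tested) without building a config object. A hedged usage sketch; the 24 kHz HAAQI rate is illustrative, and the real values come from the recipe config:

    # reference, enhanced and listener are prepared earlier in the script.
    haaqi_left, haaqi_right = compute_quality(
        reference_signal=reference,      # shape (n_samples, 2)
        enhanced_signal=enhanced,        # shape (n_samples, 2)
        listener=listener,               # clarity Listener with audiograms
        reference_sample_rate=44100,
        enhanced_sample_rate=44100,
        HAAQI_sample_rate=24000,         # assumed value for illustration
    )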
@@ -181,16 +202,16 @@
s = compute_haaqi(
processed_signal=resample(
enhanced_signal[:, channel],
config.remix_sample_rate,
config.HAAQI_sample_rate,
enhanced_sample_rate,
HAAQI_sample_rate,
),
reference_signal=resample(
reference_signal[:, channel],
config.input_sample_rate,
config.HAAQI_sample_rate,
reference_sample_rate,
HAAQI_sample_rate,
),
processed_sample_rate=config.HAAQI_sample_rate,
reference_sample_rate=config.HAAQI_sample_rate,
processed_sample_rate=HAAQI_sample_rate,
reference_sample_rate=HAAQI_sample_rate,
audiogram=audiogram,
equalisation=2,
level1=65 - 20 * np.log10(compute_rms(reference_signal[:, channel])),
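For context on the unchanged level1 line: 65 - 20 * log10(rms) sets the gain so that HAAQI treats the reference channel as if it were presented at 65 dB SPL. A small self-contained check of the arithmetic:

    import numpy as np

    t = np.arange(44100) / 44100
    signal = np.sin(2 * np.pi * 440 * t)      # 1 s unit-amplitude tone
    rms = np.sqrt(np.mean(signal**2))         # ~0.707
    level1 = 65 - 20 * np.log10(rms)          # ~68.0
    print(round(level1, 1))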
@@ -318,6 +339,11 @@ def run_compute_scores(config: DictConfig) -> None:
sample_rate=config.input_sample_rate,
)

# Configure backend for determinism
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Load the Whisper model
intelligibility_scorer = whisper.load_model(config.evaluate.whisper_version)

# Loop over the scene-listener pairs
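The two backend flags trade speed for run-to-run stability: benchmark = False stops cuDNN from auto-tuning (and possibly switching) convolution algorithms between runs, and deterministic = True restricts it to deterministic kernels. Together with the greedy Whisper decoding above, this makes the evaluation repeatable. The same setup in isolation:

    import torch
    import whisper

    torch.backends.cudnn.benchmark = False     # no algorithm auto-tuning
    torch.backends.cudnn.deterministic = True  # deterministic kernels only

    # "base" is illustrative; the recipe reads the model name
    # from config.evaluate.whisper_version.
    scorer = whisper.load_model("base")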
@@ -395,7 +421,15 @@ def run_compute_scores(config: DictConfig) -> None:
# COMPUTE SCORES

# Compute the HAAQI and Whisper scores
haaqi_scores = compute_quality(reference, enhanced_signal, listener, config)
haaqi_scores = compute_quality(
reference_signal=reference,
enhanced_signal=enhanced_signal,
listener=listener,
reference_sample_rate=config.input_sample_rate,
enhanced_sample_rate=config.remix_sample_rate,
HAAQI_sample_rate=config.HAAQI_sample_rate,
)

whisper_left, whisper_right, lyrics_text = compute_intelligibility(
enhanced_signal=enhanced_signal,
segment_metadata=songs[scene["segment_id"]],
@@ -425,8 +459,6 @@ def run_compute_scores(config: DictConfig) -> None:
"whisper_be": max_whisper,
"alpha": alpha,
"score": alpha * max_whisper + (1 - alpha) * np.mean(haaqi_scores),


}
)
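The reported score is a linear blend of the two metrics: alpha weights the better-ear Whisper correctness against the mean HAAQI of both ears. A quick worked example with made-up numbers:

    alpha = 0.5          # illustrative weight; the recipe takes it from config
    max_whisper = 0.80   # better-ear Whisper correctness (hypothetical)
    mean_haaqi = 0.60    # mean of left/right HAAQI scores (hypothetical)
    score = alpha * max_whisper + (1 - alpha) * mean_haaqi
    print(round(score, 2))  # 0.7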
