From 4a9588bdd5e9160f5043d6653a673c3b6f033e5f Mon Sep 17 00:00:00 2001 From: Sung-Lin Yeh Date: Tue, 19 Sep 2023 17:41:47 -0700 Subject: [PATCH 1/2] set temperature to 0 to make results reproducible --- simuleval/evaluator/scorers/quality_scorer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/simuleval/evaluator/scorers/quality_scorer.py b/simuleval/evaluator/scorers/quality_scorer.py index a5a42ed7..134abf88 100644 --- a/simuleval/evaluator/scorers/quality_scorer.py +++ b/simuleval/evaluator/scorers/quality_scorer.py @@ -313,7 +313,7 @@ def asr_transcribe(self, instances): wav_path = wav_dir / f"{index}_pred.wav" if wav_path.exists(): result = model.transcribe( - wav_path.as_posix(), language=self.target_lang + wav_path.as_posix(), language=self.target_lang, temperature=0.0 ) text = result["text"] assert type(text) == str From bc3e705323be54d07f6a54edd018b30f1b1680c8 Mon Sep 17 00:00:00 2001 From: 30stomercury Date: Wed, 11 Oct 2023 14:26:02 -0700 Subject: [PATCH 2/2] add temperature to args --- simuleval/evaluator/scorers/quality_scorer.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/simuleval/evaluator/scorers/quality_scorer.py b/simuleval/evaluator/scorers/quality_scorer.py index 134abf88..cdb0ed42 100644 --- a/simuleval/evaluator/scorers/quality_scorer.py +++ b/simuleval/evaluator/scorers/quality_scorer.py @@ -266,6 +266,7 @@ def __init__( tokenizer: str = "13a", target_lang: str = "en", model_size: str = "base", + temperature: float = 0.0, lowercase: bool = False, remove_punctuations: bool = False, ) -> None: @@ -274,6 +275,7 @@ def __init__( self.tokenizer = tokenizer self.target_lang = target_lang self.model_size = model_size + self.temperature = temperature self.lowercase = lowercase self.remove_punctuations = remove_punctuations @@ -297,6 +299,7 @@ def asr_transcribe(self, instances): self.logger.info(f"tokenizer = {self.tokenizer}") self.logger.info(f"target_lang = {self.target_lang}") self.logger.info(f"model_size = {self.model_size}") + self.logger.info(f"temperature = {self.temperature}") self.logger.info(f"lowercase = {self.lowercase}") self.logger.info(f"remove_punctuations = {self.remove_punctuations}") try: @@ -313,7 +316,7 @@ def asr_transcribe(self, instances): wav_path = wav_dir / f"{index}_pred.wav" if wav_path.exists(): result = model.transcribe( - wav_path.as_posix(), language=self.target_lang, temperature=0.0 + wav_path.as_posix(), language=self.target_lang, temperature=self.temperature ) text = result["text"] assert type(text) == str @@ -348,6 +351,12 @@ def add_args(parser): default="large", help="The size of whisper asr model", ) + parser.add_argument( + "--whisper-model-temperature", + type=float, + default=0.0, + help="If temperature > 0.0, the decoding will perform sampling", + ) parser.add_argument( "--transcript-lowercase", action="store_true", @@ -365,6 +374,7 @@ def from_args(cls, args): args.sacrebleu_tokenizer, args.target_speech_lang, args.whisper_model_size, + args.whisper_model_temperature, args.transcript_lowercase, args.transcript_non_punctuation, )