Skip to content

Commit

Permalink
set temperature to 0 to make results reproducible (#77)
Browse files: browse the repository at this point in the history.
* set temperature to 0 to make results reproducible

* add temperature to args

---------

Co-authored-by: Sung-Lin Yeh <slyeh@devfair0457.h2.fair>
  • Loading branch information
30stomercury and Sung-Lin Yeh authored Oct 11, 2023
1 parent 6944d63 commit 411a73d
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion simuleval/evaluator/scorers/quality_scorer.py
Original file line number Diff line number Diff line change
@@ -266,6 +266,7 @@ def __init__(
tokenizer: str = "13a",
target_lang: str = "en",
model_size: str = "base",
temperature: float = 0.0,
lowercase: bool = False,
remove_punctuations: bool = False,
) -> None:
@@ -274,6 +275,7 @@ def __init__(
self.tokenizer = tokenizer
self.target_lang = target_lang
self.model_size = model_size
self.temperature = temperature
self.lowercase = lowercase
self.remove_punctuations = remove_punctuations

@@ -297,6 +299,7 @@ def asr_transcribe(self, instances):
self.logger.info(f"tokenizer = {self.tokenizer}")
self.logger.info(f"target_lang = {self.target_lang}")
self.logger.info(f"model_size = {self.model_size}")
self.logger.info(f"temperature = {self.temperature}")
self.logger.info(f"lowercase = {self.lowercase}")
self.logger.info(f"remove_punctuations = {self.remove_punctuations}")
try:
@@ -313,7 +316,7 @@ def asr_transcribe(self, instances):
wav_path = wav_dir / f"{index}_pred.wav"
if wav_path.exists():
result = model.transcribe(
- wav_path.as_posix(), language=self.target_lang
+ wav_path.as_posix(), language=self.target_lang, temperature=self.temperature
)
text = result["text"]
assert type(text) == str
@@ -348,6 +351,12 @@ def add_args(parser):
default="large",
help="The size of whisper asr model",
)
parser.add_argument(
"--whisper-model-temperature",
type=float,
default=0.0,
help="If temperature > 0.0, the decoding will perform sampling",
)
parser.add_argument(
"--transcript-lowercase",
action="store_true",
@@ -365,6 +374,7 @@ def from_args(cls, args):
args.sacrebleu_tokenizer,
args.target_speech_lang,
args.whisper_model_size,
args.whisper_model_temperature,
args.transcript_lowercase,
args.transcript_non_punctuation,
)

0 comments on commit 411a73d

Please sign in to comment.