From 7b332c9316fe2f50bf48fdcc9a770e83c29f12ea Mon Sep 17 00:00:00 2001
From: MaanavD
Date: Mon, 10 Jul 2023 14:43:47 -0500
Subject: [PATCH] Cleaned comments

---
 torchbenchmark/models/whisper/__init__.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/torchbenchmark/models/whisper/__init__.py b/torchbenchmark/models/whisper/__init__.py
index f84c5603c9..1e53952522 100644
--- a/torchbenchmark/models/whisper/__init__.py
+++ b/torchbenchmark/models/whisper/__init__.py
@@ -18,6 +18,7 @@ class Model(BenchmarkModel):
 
     def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
         super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)
+        # Failing on cpu and batch sizes that are too large
         if self.device == 'cpu':
             return NotImplementedError("CPU test too slow - skipping.")
         if batch_size > 72:
@@ -26,6 +27,7 @@ def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
             """
             return NotImplementedError(error_msg)
         self.model = load_model("medium", self.device, "./.data", in_memory=True)
+        # Importing dataset and preprocessing
         dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
         mels=[]
         for i in range(self.batch_size):
@@ -48,10 +50,5 @@ def train(self):
         return NotImplementedError(error_msg)
 
     def eval(self):
-        # self.model.eval()
         with torch.no_grad():
-            return self.model.decode(self.example_inputs, self.model_args)
-        # return [self.model.transcribe(inp) for inp in self.example_inputs]
-
-
-
\ No newline at end of file
+            return self.model.decode(self.example_inputs, self.model_args)
\ No newline at end of file
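
For context, below is a minimal, self-contained sketch of the preprocessing and decode path this benchmark exercises, assuming the openai-whisper and Hugging Face `datasets` packages. The batch size, the `DecodingOptions` values, and the variable names are illustrative and are not taken from the benchmark itself; `model_args` in the patched file is presumably a `DecodingOptions` instance, which is what the sketch assumes.

```python
# Sketch only: approximates what Model.__init__ prepares (a batch of log-mel
# spectrograms) and what Model.eval() runs (batched decoding with no_grad).
import torch
import whisper  # openai-whisper
from datasets import load_dataset

device = "cuda" if torch.cuda.is_available() else "cpu"
model = whisper.load_model("medium", device=device)

# Same dummy LibriSpeech split the benchmark loads; rows carry 16 kHz waveforms.
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

batch_size = 4  # illustrative; the benchmark rejects batch sizes above 72
mels = []
for i in range(batch_size):
    audio = torch.from_numpy(dataset[i % len(dataset)]["audio"]["array"]).float()
    audio = whisper.pad_or_trim(audio)                # pad/trim to 30 seconds
    mels.append(whisper.log_mel_spectrogram(audio))   # (80, 3000) mel frames
example_inputs = torch.stack(mels).to(device)         # (batch, 80, 3000)

# Assumed stand-in for the benchmark's self.model_args: decoding options.
options = whisper.DecodingOptions(language="en", fp16=(device == "cuda"))

with torch.no_grad():
    results = model.decode(example_inputs, options)   # list of DecodingResult
print([r.text for r in results])
```

The fp16 flag is tied to the device only so the sketch also runs on CPU; the benchmark itself skips the CPU path entirely.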