Skip to content

Commit

Permalink
Cleaned comments
Browse files Browse the repository at this point in the history
  • Loading branch information
MaanavD committed Jul 10, 2023
1 parent d0ac5ca commit 7b332c9
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions torchbenchmark/models/whisper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Model(BenchmarkModel):

def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)
# Failing on cpu and batch sizes that are too large
if self.device == 'cpu':
return NotImplementedError("CPU test too slow - skipping.")
if batch_size > 72:
Expand All @@ -26,6 +27,7 @@ def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
"""
return NotImplementedError(error_msg)
self.model = load_model("medium", self.device, "./.data", in_memory=True)
# Importing dataset and preprocessing
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
mels=[]
for i in range(self.batch_size):
Expand All @@ -48,10 +50,5 @@ def train(self):
return NotImplementedError(error_msg)

def eval(self):
# self.model.eval()
with torch.no_grad():
return self.model.decode(self.example_inputs, self.model_args)
# return [self.model.transcribe(inp) for inp in self.example_inputs]



return self.model.decode(self.example_inputs, self.model_args)

0 comments on commit 7b332c9

Please sign in to comment.