Added Whisper from Huggingface. (#1769)
Summary:
Instead of making local changes, using the Hugging Face implementation is easier and more maintainable.

Pull Request resolved: #1769

Reviewed By: xuzhao9, cpuhrsch

Differential Revision: D47766556

Pulled By: msaroufim

fbshipit-source-id: 8393776222fc3508bda56c9c71e45d9812e69869
MaanavD authored and xuzhao9 committed Jul 26, 2023
1 parent f8d045e commit 1605423
Showing 6 changed files with 57 additions and 2 deletions.
28 changes: 28 additions & 0 deletions torchbenchmark/models/hf_Whisper/__init__.py
@@ -0,0 +1,28 @@
from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceModel
from torchbenchmark.tasks import SPEECH
import torch

class Model(HuggingFaceModel):
    task = SPEECH.RECOGNITION
    DEFAULT_EVAL_BSIZE = 8
    DEFAULT_EVAL_CUDA_PRECISION = "fp16"

    def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
        super().__init__(name="hf_Whisper", test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)
        self.feature_size = 80
        self.sequence_length = 3000
        self.input_features = torch.randn(size=(self.batch_size, self.feature_size, self.sequence_length), device=self.device)
        self.example_inputs = {"input_features": self.input_features.to(self.device), "input_ids": self.input_features.to(self.device)}
        self.model.to(self.device)

    def train(self):
        raise NotImplementedError("Training is not implemented.")

    def eval(self):
        self.model.eval()
        with torch.no_grad():
            self.model(self.example_inputs["input_ids"])

    def enable_fp16_half(self):
        self.model.half()
        self.example_inputs = {"input_features": self.input_features.half().to(self.device), "input_ids": self.input_features.half().to(self.device)}
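For context, a minimal illustrative sketch of how this benchmark class might be exercised, assuming the standard TorchBench eval flow shown above and an available CUDA device:

from torchbenchmark.models.hf_Whisper import Model

# Build the benchmark in eval mode; batch_size falls back to DEFAULT_EVAL_BSIZE (8).
m = Model(test="eval", device="cuda")
m.eval()  # one forward pass over the random log-mel features, under torch.no_grad()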
13 changes: 13 additions & 0 deletions torchbenchmark/models/hf_Whisper/install.py
@@ -0,0 +1,13 @@
import subprocess
import sys
import os
from torchbenchmark.util.framework.huggingface.patch_hf import patch_transformers, cache_model

def pip_install_requirements():
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])

if __name__ == '__main__':
    pip_install_requirements()
    patch_transformers()
    model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
    cache_model(model_name)
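Run standalone, the script installs the extra requirement (numba), patches the installed transformers package, and caches the model named after its directory. A sketch of the same steps invoked directly (the string "hf_Whisper" simply mirrors the directory name):

from torchbenchmark.util.framework.huggingface.patch_hf import patch_transformers, cache_model

patch_transformers()       # apply TorchBench's local patches to the installed transformers package
cache_model("hf_Whisper")  # pre-build and cache the hf_Whisper model entry ahead of benchmark runs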
11 changes: 11 additions & 0 deletions torchbenchmark/models/hf_Whisper/metadata.yaml
@@ -0,0 +1,11 @@
devices:
  NVIDIA A100-SXM4-40GB:
    eval_batch_size: 8
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
not_implemented:
- jit: true
- device: cpu
train_benchmark: false
train_deterministic: false
1 change: 1 addition & 0 deletions torchbenchmark/models/hf_Whisper/requirements.txt
@@ -0,0 +1 @@
numba
3 changes: 2 additions & 1 deletion torchbenchmark/util/framework/huggingface/model_factory.py
@@ -8,7 +8,7 @@
from torchbenchmark.util.model import BenchmarkModel
from torchbenchmark.tasks import NLP
import transformers
from transformers import AutoConfig, ReformerConfig, BertConfig, LlamaConfig, GenerationConfig
from transformers import AutoConfig, ReformerConfig, BertConfig, GenerationConfig, WhisperConfig, LlamaConfig
from typing import Tuple

class_models = {
@@ -27,6 +27,7 @@
    'hf_Bert': (512, 512, 'BertConfig()', 'AutoModelForMaskedLM'),
    # see https://huggingface.co/bert-large-cased
    'hf_Bert_large': (512, 512, 'BertConfig(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16)', 'AutoModelForMaskedLM'),
    'hf_Whisper': (1024, 1024, 'WhisperConfig()', 'AutoModelForAudioClassification'),
    # default num_hidden_layers=32 but that OOMs, feel free to change this config to something more real
    'llama_v2_7b_16h': (512, 512, 'LlamaConfig(num_hidden_layers=16)', 'AutoModelForCausalLM'),
}
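Illustrative only: one way an entry from this registry can be materialized into a model, assuming the factory evaluates the stored config expression and resolves the auto-model class by name (the actual construction happens inside HuggingFaceModel):

import transformers
from transformers import WhisperConfig  # needed so eval() of the config expression resolves
from torchbenchmark.util.framework.huggingface.model_factory import class_models

_, _, config_expr, auto_cls_name = class_models['hf_Whisper']
config = eval(config_expr)                        # -> WhisperConfig()
model_cls = getattr(transformers, auto_cls_name)  # -> AutoModelForAudioClassification
model = model_cls.from_config(config)             # randomly initialized Whisper audio classifier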
3 changes: 2 additions & 1 deletion torchbenchmark/util/framework/huggingface/patch_hf.py
@@ -5,7 +5,8 @@
import subprocess
import sys
from .model_factory import class_models
from transformers import AutoConfig, ReformerConfig, BigBirdConfig, BertConfig, LlamaConfig
from transformers import AutoConfig, ReformerConfig, BigBirdConfig, BertConfig, WhisperConfig, LlamaConfig


PATCH_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "patches")

