Add MPT-7B-instruct HuggingFace Model (#1773)
Summary:
PR to add https://huggingface.co/mosaicml/mpt-7b-instruct to torchbenchmark.
Running mpt-7b-instruct requires ```trust_remote_code=True``` to be passed when loading the model and its config.
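
As a point of reference, a minimal sketch (not part of this diff) of loading the checkpoint directly with the Hugging Face ```transformers``` API; the prompt string and generation settings are illustrative only. MPT's custom modeling code ships inside the Hub repo, which is why ```trust_remote_code=True``` also has to be threaded through the benchmark harness below.

```python
# Minimal sketch: loading mosaicml/mpt-7b-instruct directly with transformers.
# The modeling code lives in the Hub repo, so trust_remote_code=True is required.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mosaicml/mpt-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

# Illustrative prompt; generation settings are arbitrary.
inputs = tokenizer("What is MPT-7B-Instruct?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```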

Pull Request resolved: #1773

Reviewed By: msaroufim

Differential Revision: D48031826

Pulled By: xuzhao9

fbshipit-source-id: d3cada22894544a88409ef11ccd4b0b101507a2a
apsonawane authored and facebook-github-bot committed Aug 3, 2023
1 parent e7ca300 commit 0b7147f
Showing 5 changed files with 36 additions and 3 deletions.
14 changes: 14 additions & 0 deletions torchbenchmark/models/hf_MPT_7b_instruct/__init__.py
@@ -0,0 +1,14 @@
from torchbenchmark.tasks import NLP
from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceModel

class Model(HuggingFaceModel):
    task = NLP.LANGUAGE_MODELING
    # https://huggingface.co/mosaicml/mpt-7b
    DEFAULT_TRAIN_BSIZE = 4
    DEFAULT_EVAL_BSIZE = 1

    def __init__(self, test, device, batch_size=None, extra_args=[]):
        super().__init__(name="hf_MPT_7b_instruct", test=test, device=device, batch_size=batch_size, extra_args=extra_args)

    def eval(self):
        super().eval()
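
For illustration, a hypothetical usage sketch of the new benchmark entry (not part of this diff), assuming torchbenchmark's standard ```BenchmarkModel``` interface of constructing ```Model(test=..., device=...)``` and calling ```invoke()```:

```python
# Hypothetical sketch: exercising the new hf_MPT_7b_instruct entry through the
# torchbenchmark harness (run install.py first so the model and config are cached).
from torchbenchmark.models.hf_MPT_7b_instruct import Model

m = Model(test="eval", device="cuda")  # picks up DEFAULT_EVAL_BSIZE = 1
m.invoke()  # dispatches to eval(), i.e. one inference pass of MPT-7B-instruct
```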
7 changes: 7 additions & 0 deletions torchbenchmark/models/hf_MPT_7b_instruct/install.py
@@ -0,0 +1,7 @@
import os
from torchbenchmark.util.framework.huggingface.patch_hf import patch_transformers, cache_model

if __name__ == '__main__':
    patch_transformers()
    model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
    cache_model(model_name, trust_remote_code=True)
8 changes: 8 additions & 0 deletions torchbenchmark/models/hf_MPT_7b_instruct/metadata.yaml
@@ -0,0 +1,8 @@
devices:
  NVIDIA A100-SXM4-40GB:
    eval_batch_size: 1
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
6 changes: 5 additions & 1 deletion torchbenchmark/util/framework/huggingface/model_factory.py
@@ -30,6 +30,7 @@
     'hf_Whisper': (1024, 1024, 'WhisperConfig()', 'AutoModelForAudioClassification'),
     # default num_hidden_layers=32 but that OOMs, feel free to change this config to something more real
     'llama_v2_7b_16h' : (512,512, 'LlamaConfig(num_hidden_layers=16)', 'AutoModelForCausalLM'),
+    'hf_MPT_7b_instruct': (512, 512, 'AutoConfig.from_pretrained("mosaicml/mpt-7b-instruct", trust_remote_code=True)', 'AutoModelForCausalLM'),
     'llama_v2_7b' : (512,512, 'AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")', 'AutoModelForCausalLM'),
     'llama_v2_13b' : (512,512, 'AutoConfig.from_pretrained("meta-llama/Llama-2-13b-hf")', 'AutoModelForCausalLM'),
     'llama_v2_70b' : (512, 512, 'AutoConfig.from_pretrained("meta-llama/Llama-2-70b-hf")', 'AutoModelForMaskedLM'),
@@ -83,7 +84,10 @@ def __init__(self, name, test, device, batch_size=None, extra_args=[]):
             # silence "config.num_buckets is not set. Setting config.num_buckets to 128"
             config.num_buckets = 128
         class_ctor = getattr(transformers, class_models[name][3])
-        self.model = class_ctor.from_config(config).to(device)
+        kwargs = {}
+        if name == "hf_Falcon_7b" or name == "hf_MPT_7b_instruct":
+            kwargs["trust_remote_code"] = True
+        self.model = class_ctor.from_config(config, **kwargs).to(device)
         self.optimizer = optim.Adam(
             self.model.parameters(),
             lr=0.001,
4 changes: 2 additions & 2 deletions torchbenchmark/util/framework/huggingface/patch_hf.py
@@ -10,11 +10,11 @@

 PATCH_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "patches")

-def cache_model(name: str):
+def cache_model(name: str, **kwargs):
     import transformers
     model_config = eval(class_models[name][2])
     model_ctor = getattr(transformers, class_models[name][3])
-    model_ctor.from_config(model_config)
+    model_ctor.from_config(model_config, **kwargs)

 def patch_transformers():
     import transformers
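
A short usage sketch of the updated helper, mirroring the new install.py above: any extra keyword arguments (here ```trust_remote_code```) are forwarded unchanged to ```from_config``` when the model is cached.

```python
# Sketch: cache_model now forwards model-specific kwargs to transformers' from_config,
# which lets models that ship remote code (MPT, Falcon) be cached at install time.
from torchbenchmark.util.framework.huggingface.patch_hf import patch_transformers, cache_model

patch_transformers()
cache_model("hf_MPT_7b_instruct", trust_remote_code=True)  # same call the new install.py makes
```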
