Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Falcon-7B HuggingFace model #1758

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions torchbenchmark/models/hf_Falcon_7b/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from torchbenchmark.tasks import NLP
from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceModel

class Model(HuggingFaceModel):
task = NLP.LANGUAGE_MODELING
# Published training batch size per GPU is 6: see https://huggingface.co/tiiuae/falcon-7b/blob/main/README.md#:~:text=Batch%20size,tokens%20ramp%2Dup
DEFAULT_TRAIN_BSIZE = 6
DEFAULT_EVAL_BSIZE = 1

def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
super().__init__(name="hf_Falcon_7b", test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)

def eval(self):
if (self.device == "cpu"):
raise NotImplementedError("Falcon model is too slow on CPU - skip CPU test.")
super().eval()
14 changes: 14 additions & 0 deletions torchbenchmark/models/hf_Falcon_7b/install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@

import subprocess
import sys
import os
from torchbenchmark.util.framework.huggingface.patch_hf import patch_transformers, cache_model

def pip_install_requirements():
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])

if __name__ == '__main__':
pip_install_requirements()
patch_transformers()
model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
cache_model(model_name, trust_remote_code=True)
10 changes: 10 additions & 0 deletions torchbenchmark/models/hf_Falcon_7b/metadata.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
devices:
NVIDIA A100-SXM4-40GB:
eval_batch_size: 16
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
not_implemented:
- jit: true
train_benchmark: false
train_deterministic: false
2 changes: 2 additions & 0 deletions torchbenchmark/models/hf_Falcon_7b/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
sentencepiece
datasets
6 changes: 5 additions & 1 deletion torchbenchmark/util/framework/huggingface/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
'hf_Bert': (512, 512, 'BertConfig()', 'AutoModelForMaskedLM'),
# see https://huggingface.co/bert-large-cased
'hf_Bert_large': (512, 512, 'BertConfig(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16)', 'AutoModelForMaskedLM'),
'hf_Falcon_7b' : (512, 512, 'AutoConfig.from_pretrained("tiiuae/falcon-7b", trust_remote_code=True)', 'AutoModelForCausalLM'),
}

cpu_input_slice = {
Expand Down Expand Up @@ -77,7 +78,10 @@ def __init__(self, name, test, device, jit=False, batch_size=None, extra_args=[]
# silence "config.num_buckets is not set. Setting config.num_buckets to 128"
config.num_buckets = 128
class_ctor = getattr(transformers, class_models[name][3])
self.model = class_ctor.from_config(config).to(device)
kwargs = {}
if name == "hf_Falcon_7b":
kwargs["trust_remote_code"] = True
self.model = class_ctor.from_config(config, **kwargs).to(device)
self.optimizer = optim.Adam(
self.model.parameters(),
lr=0.001,
Expand Down
4 changes: 2 additions & 2 deletions torchbenchmark/util/framework/huggingface/patch_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@

PATCH_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "patches")

def cache_model(name: str):
def cache_model(name: str, **kwargs):
import transformers
model_config = eval(class_models[name][2])
model_ctor = getattr(transformers, class_models[name][3])
model_ctor.from_config(model_config)
model_ctor.from_config(model_config, **kwargs)

def patch_transformers():
import transformers
Expand Down
Loading