-
-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Core][Model] Add simple_model_runner and a new model XLMRobertaForSequenceClassification through multimodal interface
- Loading branch information
zixiao
committed
Aug 19, 2024
1 parent
1a36287
commit 3457cb6
Showing
26 changed files
with
1,704 additions
and
110 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from typing import List, Tuple, Union

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Checkpoint and tokenization settings for the BAAI cross-encoder reranker.
model_name_or_path = "BAAI/bge-reranker-base"
cache_dir = None
max_length = 512

# Query/passage pairs whose relevance the cross-encoder will score.
sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]] = \
    [("hello world", "nice to meet you"), ("head north", "head south")]

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                          cache_dir=cache_dir)
# This checkpoint resolves to an XLMRobertaForSequenceClassification model.
model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path,
                                                           cache_dir=cache_dir)
model = model.to("cuda")
model.eval()

# Batch-encode every pair at once; the tokenizer pads/truncates to max_length.
inputs = tokenizer(
    sentence_pairs,
    padding=True,
    truncation=True,
    return_tensors='pt',
    max_length=max_length,
).to("cuda")

all_scores = []
with torch.no_grad():
    # One relevance logit per pair; flatten to a 1-D score vector.
    logits = model(**inputs, return_dict=True).logits
    flat_scores = logits.view(-1).float()
    all_scores.extend(flat_scores.cpu().numpy().tolist())
print(all_scores)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from typing import List, Tuple, Union

from transformers import AutoTokenizer

from vllm import LLM

# Same reranker checkpoint served through the vLLM engine.
model = "BAAI/bge-reranker-base"
llm = LLM(model=model, tensor_parallel_size=1)

# The prompt text is unused by the scoring path; only the pre-tokenized
# pairs carried in multi_modal_data matter.
prompt = "this is a useless prompt."
sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]] = \
    [("hello world", "nice to meet you"), ("head north", "head south")]

tokenizer = AutoTokenizer.from_pretrained(model, cache_dir=None)
inputs = tokenizer(
    sentence_pairs,
    padding=True,
    truncation=True,
    return_tensors='pt',
    max_length=512,
).to("cuda")

# Hand the tokenized batch to the engine via the multimodal interface.
outputs = llm.process([{
    "prompt": prompt,
    "multi_modal_data": {
        "xlmroberta": inputs,
    }
}],
                      use_tqdm=False)

for output in outputs:
    print(output.outputs.result)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from typing import List, Optional, Tuple, Type, Union | ||
|
||
import pytest | ||
import torch | ||
from transformers import AutoTokenizer | ||
|
||
from ..conftest import HfRunner, VllmRunner | ||
|
||
models = ["BAAI/bge-reranker-base"] | ||
|
||
|
||
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    model: str,
    *,
    dtype: str,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between hf and vllm."""
    # The prompt text is a placeholder; only the tokenized pairs are scored.
    prompt = "this is a useless prompt."
    sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]] = \
        [("hello world", "nice to meet you"), ("head north", "head south")]

    # Tokenize once and share the exact same batch between both backends.
    tokenizer = AutoTokenizer.from_pretrained(model, cache_dir=None)
    encoded = tokenizer(
        sentence_pairs,
        padding=True,
        truncation=True,
        return_tensors='pt',
        max_length=512,
    ).to("cuda")

    # Scores produced by the vLLM engine via the multimodal interface.
    with vllm_runner(model,
                     dtype=dtype,
                     max_model_len=512,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
        vllm_outputs = vllm_model.process([{
            "prompt": prompt,
            "multi_modal_data": {
                "xlmroberta": encoded,
            }
        }])

    # Reference scores straight from the Hugging Face implementation.
    with hf_runner(model, dtype=dtype, is_simple_model=True) as hf_model:
        hf_outputs = hf_model.process(**encoded)

    reference = hf_outputs.logits.view(-1)
    print(vllm_outputs[0].outputs.result, reference)
    assert torch.allclose(vllm_outputs[0].outputs.result, reference)
|
||
|
||
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["float"])
def test_models(hf_runner, vllm_runner, model, dtype: str) -> None:
    """Single-GPU HF-vs-vLLM equivalence check for each reranker model."""
    run_test(hf_runner,
             vllm_runner,
             model,
             dtype=dtype,
             tensor_parallel_size=1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.