Support Alibaba-NLP/gte-Qwen2-7B-instruct embedding Model #1186

Merged · 2 commits · Aug 25, 2024
2 changes: 1 addition & 1 deletion .github/workflows/accuracy-test.yml
@@ -43,4 +43,4 @@ jobs:
run: |
cd test/srt
python3 test_eval_accuracy_large.py
timeout-minutes: 10
timeout-minutes: 20
2 changes: 1 addition & 1 deletion .github/workflows/unit-test.yml
@@ -41,7 +41,7 @@ jobs:
run: |
cd test/srt
python3 run_suite.py --suite minimal
timeout-minutes: 18
timeout-minutes: 20

- name: Test Frontend Language
run: |
15 changes: 15 additions & 0 deletions README.md
@@ -187,6 +187,13 @@ response = client.chat.completions.create(
max_tokens=64,
)
print(response)

# Text embedding
response = client.embeddings.create(
model="default",
input="How are you today",
)
print(response)
```

It supports streaming, vision, and most features of the Chat/Completions/Models/Batch endpoints specified by the [OpenAI API Reference](https://platform.openai.com/docs/api-reference/).
@@ -223,6 +230,8 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct

### Supported Models

**Generative Models**

- Llama / Llama 2 / Llama 3 / Llama 3.1
- Mistral / Mixtral / Mistral NeMo
- Gemma / Gemma 2
@@ -243,6 +252,12 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
- ChatGLM
- InternLM 2

**Embedding Models**

- e5-mistral
- gte-Qwen2
- `python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct --is-embedding`

Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/en/model_support.md).

#### Use Models From ModelScope
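Following up on the README additions above (the `client.embeddings.create` example and the `--is-embedding` launch command for gte-Qwen2), here is a minimal sketch of reading the embedding vector back from the response. It assumes a local sglang server launched in embedding mode and the OpenAI-compatible response shape, where each input gets an entry in `response.data` with an `embedding` field; the base URL and port below are illustrative, not taken from this PR.

```python
import openai

# Illustrative endpoint; adjust to wherever the server was launched, e.g.
# python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct --is-embedding
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.embeddings.create(
    model="default",
    input="How are you today",
)

# One entry per input string; the vector itself is a list of floats.
vector = response.data[0].embedding
print(len(vector))  # dimensionality of the embedding
```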
5 changes: 4 additions & 1 deletion python/sglang/srt/managers/tokenizer_manager.py
@@ -94,7 +94,10 @@ def __init__(
trust_remote_code=server_args.trust_remote_code,
model_overide_args=model_overide_args,
)
self.is_generation = is_generation_model(self.hf_config.architectures)

self.is_generation = is_generation_model(
self.hf_config.architectures, self.server_args.is_embedding
)

if server_args.context_length is not None:
self.context_len = server_args.context_length
1 change: 1 addition & 0 deletions python/sglang/srt/managers/tp_worker.py
@@ -94,6 +94,7 @@ def __init__(
context_length=server_args.context_length,
model_overide_args=model_overide_args,
)

self.model_runner = ModelRunner(
model_config=self.model_config,
mem_fraction_static=server_args.mem_fraction_static,
17 changes: 13 additions & 4 deletions python/sglang/srt/model_executor/model_runner.py
@@ -204,7 +204,7 @@ def load_model(self):
else None
)
self.is_generation = is_generation_model(
self.model_config.hf_config.architectures
self.model_config.hf_config.architectures, self.server_args.is_embedding
)

logger.info(
@@ -522,9 +522,18 @@ def forward_extend(self, batch: ScheduleBatch):
batch,
forward_mode=ForwardMode.EXTEND,
)
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
)
if self.is_generation:
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
)
else:
# Only embedding models accept the get_embedding argument
return self.model.forward(
batch.input_ids,
input_metadata.positions,
input_metadata,
get_embedding=True,
)

@torch.inference_mode()
def forward_extend_multi_modal(self, batch: ScheduleBatch):
4 changes: 4 additions & 0 deletions python/sglang/srt/models/llama_embedding.py
@@ -29,7 +29,11 @@ def forward(
positions: torch.Tensor,
input_metadata: InputMetadata,
input_embeds: torch.Tensor = None,
get_embedding: bool = True,
) -> EmbeddingPoolerOutput:
assert (
get_embedding
), "LlamaEmbeddingModel / MistralModel is only used for embedding"
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.pooler(hidden_states, input_metadata)

12 changes: 9 additions & 3 deletions python/sglang/srt/models/qwen2.py
@@ -38,6 +38,7 @@
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -275,6 +276,7 @@ def __init__(
self.model = Qwen2Model(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

@torch.no_grad()
def forward(
@@ -283,11 +285,15 @@ def forward(
positions: torch.Tensor,
input_metadata: InputMetadata,
input_embeds: torch.Tensor = None,
get_embedding: bool = False,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
else:
return self.pooler(hidden_states, input_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
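The `Pooler(pooling_type=PoolingType.LAST, normalize=True)` added to the Qwen2 model above is what turns hidden states into embeddings when `get_embedding=True`. Below is a minimal sketch of what last-token pooling with L2 normalization is assumed to compute over a packed batch; the actual implementation lives in `sglang.srt.layers.pooler` and may differ in details (for example, how sequence boundaries are read from `input_metadata`).

```python
import torch
import torch.nn.functional as F


def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    """Sketch of PoolingType.LAST + normalize=True (assumed semantics).

    hidden_states: [total_tokens, hidden_size], sequences packed back to back.
    seq_lens: [batch_size], number of tokens in each sequence.
    """
    # Index of the last token of each sequence in the packed layout.
    last_indices = torch.cumsum(seq_lens, dim=0) - 1
    embeddings = hidden_states[last_indices]
    # L2-normalize so cosine similarity between embeddings is a plain dot product.
    return F.normalize(embeddings, p=2, dim=-1)


# Example: two sequences of lengths 3 and 5 packed into 8 rows.
hidden = torch.randn(8, 16)
lens = torch.tensor([3, 5])
print(last_token_pool(hidden, lens).shape)  # torch.Size([2, 16])
```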
3 changes: 3 additions & 0 deletions python/sglang/srt/server.py
@@ -333,11 +333,13 @@ def launch_server(
start_process = start_controller_process_single
else:
start_process = start_controller_process_multi

proc_controller = mp.Process(
target=start_process,
args=(server_args, port_args, pipe_controller_writer, model_overide_args),
)
proc_controller.start()

proc_detoken = mp.Process(
target=start_detokenizer_process,
args=(
@@ -515,6 +517,7 @@ def __init__(

self.pid = None
pipe_reader, pipe_writer = mp.Pipe(duplex=False)

proc = mp.Process(
target=launch_server,
args=(self.server_args, model_overide_args, pipe_writer),
11 changes: 11 additions & 0 deletions python/sglang/srt/server_args.py
@@ -38,6 +38,7 @@ class ServerArgs:
quantization: Optional[str] = None
served_model_name: Optional[str] = None
chat_template: Optional[str] = None
is_embedding: bool = False

# Port
host: str = "127.0.0.1"
@@ -200,6 +201,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
)
parser.add_argument(
"--is-embedding",
action="store_true",
help="Whether to use a CausalLM as an embedding model.",
)
parser.add_argument(
"--context-length",
type=int,
@@ -458,6 +464,11 @@ def check_server_args(self):
assert not (
self.dp_size > 1 and self.node_rank is not None
), "multi-node data parallel is not supported"
if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
logger.info(
"Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
)
self.trust_remote_code = False
if "gemma-2" in self.model_path.lower():
logger.info("When using sliding window in gemma-2, turn on flashinfer.")
self.disable_flashinfer = False
9 changes: 7 additions & 2 deletions python/sglang/srt/utils.py
@@ -224,13 +224,18 @@ def is_multimodal_model(model):
raise ValueError("unrecognized type")


def is_generation_model(model_architectures):
def is_generation_model(model_architectures, is_embedding: bool = False):
# We have two ways to determine whether a model is a generative model.
# 1. Check the model architecture.
# 2. Check the `is_embedding` server argument.

if (
"LlamaEmbeddingModel" in model_architectures
or "MistralModel" in model_architectures
):
return False
return True
else:
return not is_embedding


def decode_video_base64(video_base64):
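To make the dispatch above concrete, the following illustrative checks (not part of the diff) show how the two rules interact: a recognized embedding architecture always forces embedding mode, and otherwise the `--is-embedding` flag decides.

```python
from sglang.srt.utils import is_generation_model

# Rule 1: a recognized embedding architecture forces embedding mode.
assert not is_generation_model(["MistralModel"])
assert not is_generation_model(["LlamaEmbeddingModel"])

# Rule 2: any other architecture is generative unless --is-embedding is set.
assert is_generation_model(["Qwen2ForCausalLM"])
assert not is_generation_model(["Qwen2ForCausalLM"], is_embedding=True)
```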
32 changes: 16 additions & 16 deletions python/sglang/test/runners.py
@@ -14,7 +14,7 @@
"""

import json
import multiprocessing
import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import List, Union
@@ -63,37 +63,35 @@ def __init__(
self,
model_path,
torch_dtype,
is_generation_model,
is_generation,
):
self.in_queue = multiprocessing.Queue()
self.out_queue = multiprocessing.Queue()
self.is_generation = is_generation

self.model_proc = multiprocessing.Process(
self.in_queue = mp.Queue()
self.out_queue = mp.Queue()

self.model_proc = mp.Process(
target=self.start_model_process,
args=(
self.in_queue,
self.out_queue,
model_path,
torch_dtype,
is_generation_model,
),
)
self.model_proc.start()

def start_model_process(
self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
):
def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
self.tokenizer = AutoTokenizer.from_pretrained(
model_path,
torch_dtype=torch_dtype,
)

self.is_generation_model = is_generation_model

if self.is_generation_model:
if self.is_generation:
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
trust_remote_code=False,
low_cpu_mem_usage=True,
).cuda()
else:
@@ -107,7 +105,7 @@ def start_model_process(
while True:
prompts, max_new_tokens = in_queue.get()
if prompts is not None:
if self.is_generation_model:
if self.is_generation:
output_strs = []
prefill_logprobs = []
for p in prompts:
@@ -171,25 +169,27 @@ def __init__(
self,
model_path,
torch_dtype,
is_generation_model,
is_generation,
tp_size=1,
port=5157,
):
self.is_generation_model = is_generation_model
self.is_generation = is_generation
self.runtime = Runtime(
model_path=model_path,
tp_size=tp_size,
dtype=get_dtype_str(torch_dtype),
port=port,
mem_fraction_static=0.7,
trust_remote_code=False,
is_embedding=not self.is_generation,
)

def forward(
self,
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
max_new_tokens=8,
):
if self.is_generation_model:
if self.is_generation:
# the return value contains logprobs from prefill
output_strs = []
top_input_logprobs = []
28 changes: 13 additions & 15 deletions test/srt/models/test_embedding_models.py
@@ -20,7 +20,10 @@
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import get_similarities

MODELS = [("intfloat/e5-mistral-7b-instruct", 1, 0.2)]
MODELS = [
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
("intfloat/e5-mistral-7b-instruct", 1, 1e-5),
]
TORCH_DTYPES = [torch.float16]


@@ -32,22 +35,20 @@ def assert_close_prefill_logits(
model_path,
tp_size,
torch_dtype,
long_context_tolerance,
prefill_tolerance,
) -> None:
with HFRunner(
model_path, torch_dtype=torch_dtype, is_generation_model=False
model_path, torch_dtype=torch_dtype, is_generation=False
) as hf_runner:
hf_outputs = hf_runner.forward(prompts)

with SRTRunner(
model_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
is_generation_model=False,
is_generation=False,
) as srt_runner:
srt_outputs = srt_runner.forward(
prompts,
)
srt_outputs = srt_runner.forward(prompts)

for i in range(len(prompts)):
hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
@@ -57,18 +58,15 @@ def forward(
print("similarity diff", abs(similarity - 1))

if len(prompts[i]) <= 1000:
tolerance = 1e-5
else:
tolerance = long_context_tolerance
assert torch.all(
abs(similarity - 1) < tolerance
), "embeddings are not all close"
assert torch.all(
abs(similarity - 1) < prefill_tolerance
), "embeddings are not all close"

def test_prefill_logits(self):
for model, tp_size, long_context_tolerance in MODELS:
for model, tp_size, prefill_tolerance in MODELS:
for torch_dtype in TORCH_DTYPES:
self.assert_close_prefill_logits(
DEFAULT_PROMPTS, model, tp_size, torch_dtype, long_context_tolerance
DEFAULT_PROMPTS, model, tp_size, torch_dtype, prefill_tolerance
)


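The assertion above compares the HuggingFace and SRT embeddings through `get_similarities` and requires the result to stay within `prefill_tolerance` of 1. Here is a minimal sketch of that comparison, assuming `get_similarities` computes cosine similarity; the real helper is defined in `sglang.test.test_utils` and is not shown in this diff.

```python
import torch
import torch.nn.functional as F


def cosine_similarity_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Assumed behavior of get_similarities: cosine similarity of two embedding vectors.
    return F.cosine_similarity(a.flatten(), b.flatten(), dim=0)


hf_embedding = torch.randn(1024)
srt_embedding = hf_embedding + 1e-6 * torch.randn(1024)  # nearly identical embeddings
similarity = cosine_similarity_sketch(hf_embedding, srt_embedding)
assert abs(similarity - 1) < 1e-5  # mirrors the prefill_tolerance check above
```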